mirror of
https://github.com/opencv/opencv.git
synced 2025-08-06 14:36:36 +08:00
Merge pull request #24266 from alexlyulkov:al/tf-argmax-default-dim
Added default dimension value to tensorflow ArgMax and ArgMin layers #24266 Added default dimension value to tensorflow ArgMax and ArgMin layers. Added exception when accessing layer's input with out of range index. Fixes https://bugs.chromium.org/p/oss-fuzz/issues/detail?id=48452
This commit is contained in:
parent
4790a3732e
commit
1e54e56579
@ -2665,14 +2665,20 @@ void TFImporter::parseActivation(tensorflow::GraphDef& net, const tensorflow::No
|
||||
connectToAllBlobs(layer_id, dstNet, parsePin(layer.input(0)), id, num_inputs);
|
||||
}
|
||||
|
||||
// ArgMin or ArgMax node
|
||||
void TFImporter::parseArg(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
|
||||
{
|
||||
const std::string& name = layer.name();
|
||||
const std::string& type = layer.op();
|
||||
|
||||
Mat dimension = getTensorContent(getConstBlob(layer, value_id, 1));
|
||||
CV_Assert(dimension.total() == 1 && dimension.type() == CV_32SC1);
|
||||
layerParams.set("axis", *dimension.ptr<int>());
|
||||
if (layer.input_size() < 2)
|
||||
layerParams.set("axis", 0); // default dimension is 0
|
||||
else
|
||||
{
|
||||
Mat dimension = getTensorContent(getConstBlob(layer, value_id, 1));
|
||||
CV_Assert(dimension.total() == 1 && dimension.type() == CV_32SC1);
|
||||
layerParams.set("axis", dimension.at<int>(0));
|
||||
}
|
||||
layerParams.set("op", type == "ArgMax" ? "max" : "min");
|
||||
layerParams.set("keepdims", false); //tensorflow doesn't have this atrr, the output's dims minus one(default);
|
||||
|
||||
@ -2866,6 +2872,7 @@ const tensorflow::TensorProto& TFImporter::getConstBlob(const tensorflow::NodeDe
|
||||
|
||||
if (input_blob_index == -1)
|
||||
CV_Error(Error::StsError, "Const input blob for weights not found");
|
||||
CV_CheckLT(input_blob_index, layer.input_size(), "Input index is out of range");
|
||||
|
||||
Pin kernel_inp = parsePin(layer.input(input_blob_index));
|
||||
if (const_layers.find(kernel_inp.name) == const_layers.end())
|
||||
|
Loading…
Reference in New Issue
Block a user