Merge pull request #21268 from pccvlab:tf_Arg

add argmax and argmin parsing for tensorflow

* add argmax and argmin for tf

* remove whitespace

* remove whitespace

* remove static_cast

Signed-off-by: Crayon-new <1349159541@qq.com>
This commit is contained in:
Gruhuang 2021-12-17 01:06:02 +08:00 committed by GitHub
parent f7aa91e660
commit b4bb98ea60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 0 deletions

View File

@ -599,6 +599,8 @@ private:
void parseActivation (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
void parseExpandDims (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
void parseSquare (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
void parseArg (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
void parseCustomLayer (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
};
@ -677,6 +679,7 @@ const TFImporter::DispatchMap TFImporter::buildDispatchMap()
dispatch["Elu"] = dispatch["Exp"] = dispatch["Identity"] = dispatch["Relu6"] = &TFImporter::parseActivation;
dispatch["ExpandDims"] = &TFImporter::parseExpandDims;
dispatch["Square"] = &TFImporter::parseSquare;
dispatch["ArgMax"] = dispatch["ArgMin"] = &TFImporter::parseArg;
return dispatch;
}
@ -2624,6 +2627,22 @@ void TFImporter::parseActivation(tensorflow::GraphDef& net, const tensorflow::No
connectToAllBlobs(layer_id, dstNet, parsePin(layer.input(0)), id, num_inputs);
}
void TFImporter::parseArg(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
{
const std::string& name = layer.name();
const std::string& type = layer.op();
Mat dimension = getTensorContent(getConstBlob(layer, value_id, 1));
CV_Assert(dimension.total() == 1 && dimension.type() == CV_32SC1);
layerParams.set("axis", *dimension.ptr<int>());
layerParams.set("op", type == "ArgMax" ? "max" : "min");
layerParams.set("keepdims", false); //tensorflow doesn't have this atrr, the output's dims minus one(default);
int id = dstNet.addLayer(name, "Arg", layerParams);
layer_id[name] = id;
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
}
void TFImporter::parseCustomLayer(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
{
// Importer does not know how to map this TensorFlow's operation onto OpenCV's layer.

View File

@ -185,6 +185,14 @@ TEST_P(Test_TensorFlow_layers, reduce_sum_channel_keep_dims)
runTensorFlowNet("reduce_sum_channel", false, 0.0, 0.0, false, "_keep_dims");
}
TEST_P(Test_TensorFlow_layers, ArgLayer)
{
if (backend != DNN_BACKEND_OPENCV || target != DNN_TARGET_CPU)
throw SkipTestException("Only CPU is supported"); // FIXIT use tags
runTensorFlowNet("argmax");
runTensorFlowNet("argmin");
}
TEST_P(Test_TensorFlow_layers, conv_single_conv)
{
runTensorFlowNet("single_conv");