mirror of
https://github.com/opencv/opencv.git
synced 2025-01-18 22:44:02 +08:00
Merge pull request #21268 from pccvlab:tf_Arg
add argmax and argmin parsing for tensorflow * add argmax and argmin for tf * remove whitespace * remove whitespace * remove static_cast Signed-off-by: Crayon-new <1349159541@qq.com>
This commit is contained in:
parent
f7aa91e660
commit
b4bb98ea60
@ -599,6 +599,8 @@ private:
|
||||
void parseActivation (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
|
||||
void parseExpandDims (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
|
||||
void parseSquare (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
|
||||
void parseArg (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
|
||||
|
||||
void parseCustomLayer (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
|
||||
};
|
||||
|
||||
@ -677,6 +679,7 @@ const TFImporter::DispatchMap TFImporter::buildDispatchMap()
|
||||
dispatch["Elu"] = dispatch["Exp"] = dispatch["Identity"] = dispatch["Relu6"] = &TFImporter::parseActivation;
|
||||
dispatch["ExpandDims"] = &TFImporter::parseExpandDims;
|
||||
dispatch["Square"] = &TFImporter::parseSquare;
|
||||
dispatch["ArgMax"] = dispatch["ArgMin"] = &TFImporter::parseArg;
|
||||
|
||||
return dispatch;
|
||||
}
|
||||
@ -2624,6 +2627,22 @@ void TFImporter::parseActivation(tensorflow::GraphDef& net, const tensorflow::No
|
||||
connectToAllBlobs(layer_id, dstNet, parsePin(layer.input(0)), id, num_inputs);
|
||||
}
|
||||
|
||||
void TFImporter::parseArg(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
|
||||
{
|
||||
const std::string& name = layer.name();
|
||||
const std::string& type = layer.op();
|
||||
|
||||
Mat dimension = getTensorContent(getConstBlob(layer, value_id, 1));
|
||||
CV_Assert(dimension.total() == 1 && dimension.type() == CV_32SC1);
|
||||
layerParams.set("axis", *dimension.ptr<int>());
|
||||
layerParams.set("op", type == "ArgMax" ? "max" : "min");
|
||||
layerParams.set("keepdims", false); //tensorflow doesn't have this atrr, the output's dims minus one(default);
|
||||
|
||||
int id = dstNet.addLayer(name, "Arg", layerParams);
|
||||
layer_id[name] = id;
|
||||
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
|
||||
}
|
||||
|
||||
void TFImporter::parseCustomLayer(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
|
||||
{
|
||||
// Importer does not know how to map this TensorFlow's operation onto OpenCV's layer.
|
||||
|
@ -185,6 +185,14 @@ TEST_P(Test_TensorFlow_layers, reduce_sum_channel_keep_dims)
|
||||
runTensorFlowNet("reduce_sum_channel", false, 0.0, 0.0, false, "_keep_dims");
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, ArgLayer)
|
||||
{
|
||||
if (backend != DNN_BACKEND_OPENCV || target != DNN_TARGET_CPU)
|
||||
throw SkipTestException("Only CPU is supported"); // FIXIT use tags
|
||||
runTensorFlowNet("argmax");
|
||||
runTensorFlowNet("argmin");
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, conv_single_conv)
|
||||
{
|
||||
runTensorFlowNet("single_conv");
|
||||
|
Loading…
Reference in New Issue
Block a user