mirror of
https://github.com/opencv/opencv.git
synced 2025-06-13 04:52:53 +08:00
Merge pull request #17386 from l-bat:tf_clamp_subgraph
* Added ClipByValue subgraph * Return const nodes
This commit is contained in:
parent
9e09828cc3
commit
ba3cf47600
@ -725,6 +725,21 @@ private:
|
|||||||
bool negativeScales;
|
bool negativeScales;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ClipByValueSubgraph : public TFSubgraph
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
ClipByValueSubgraph()
|
||||||
|
{
|
||||||
|
int input = addNodeToMatch("");
|
||||||
|
int maxValue = addNodeToMatch("Const");
|
||||||
|
int minimum = addNodeToMatch("Minimum", input, maxValue);
|
||||||
|
int minValue = addNodeToMatch("Const");
|
||||||
|
addNodeToMatch("Maximum", minimum, minValue);
|
||||||
|
|
||||||
|
setFusedNode("ClipByValue", input, minValue, maxValue);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void simplifySubgraphs(tensorflow::GraphDef& net)
|
void simplifySubgraphs(tensorflow::GraphDef& net)
|
||||||
{
|
{
|
||||||
std::vector<Ptr<Subgraph> > subgraphs;
|
std::vector<Ptr<Subgraph> > subgraphs;
|
||||||
@ -749,6 +764,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
|
|||||||
subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(false)));
|
subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(false)));
|
||||||
subgraphs.push_back(Ptr<Subgraph>(new FlattenProdSubgraph()));
|
subgraphs.push_back(Ptr<Subgraph>(new FlattenProdSubgraph()));
|
||||||
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraphDown()));
|
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraphDown()));
|
||||||
|
subgraphs.push_back(Ptr<Subgraph>(new ClipByValueSubgraph()));
|
||||||
|
|
||||||
for (int i = 0; i < net.node_size(); ++i)
|
for (int i = 0; i < net.node_size(); ++i)
|
||||||
{
|
{
|
||||||
|
@ -977,6 +977,11 @@ TEST_P(Test_TensorFlow_layers, tf2_dense)
|
|||||||
runTensorFlowNet("tf2_dense");
|
runTensorFlowNet("tf2_dense");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(Test_TensorFlow_layers, clip_by_value)
|
||||||
|
{
|
||||||
|
runTensorFlowNet("clip_by_value");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(Test_TensorFlow_layers, tf2_prelu)
|
TEST_P(Test_TensorFlow_layers, tf2_prelu)
|
||||||
{
|
{
|
||||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
|
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
|
||||||
|
Loading…
Reference in New Issue
Block a user