mirror of
https://github.com/opencv/opencv.git
synced 2025-01-18 22:44:02 +08:00
Import ClipByValue from Keras
This commit is contained in:
parent
8f24db048c
commit
693a7663e7
@ -1641,6 +1641,27 @@ void TFImporter::populateNet(Net dstNet)
|
||||
connect(layer_id, dstNet, Pin(name), flattenId, 0);
|
||||
}
|
||||
}
|
||||
else if (type == "ClipByValue")
|
||||
{
|
||||
// op: "ClipByValue"
|
||||
// input: "input"
|
||||
// input: "mix"
|
||||
// input: "max"
|
||||
CV_Assert(layer.input_size() == 3);
|
||||
|
||||
Mat minValue = getTensorContent(getConstBlob(layer, value_id, 1));
|
||||
Mat maxValue = getTensorContent(getConstBlob(layer, value_id, 2));
|
||||
CV_Assert(minValue.total() == 1, minValue.type() == CV_32F,
|
||||
maxValue.total() == 1, maxValue.type() == CV_32F);
|
||||
|
||||
layerParams.set("min_value", minValue.at<float>(0));
|
||||
layerParams.set("max_value", maxValue.at<float>(0));
|
||||
|
||||
int id = dstNet.addLayer(name, "ReLU6", layerParams);
|
||||
layer_id[name] = id;
|
||||
|
||||
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
|
||||
}
|
||||
else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
|
||||
type == "Relu" || type == "Elu" ||
|
||||
type == "Identity" || type == "Relu6")
|
||||
|
@ -415,6 +415,7 @@ TEST(Test_TensorFlow, softmax)
|
||||
TEST(Test_TensorFlow, relu6)
|
||||
{
|
||||
runTensorFlowNet("keras_relu6");
|
||||
runTensorFlowNet("keras_relu6", DNN_TARGET_CPU, /*hasText*/ true);
|
||||
}
|
||||
|
||||
TEST(Test_TensorFlow, keras_mobilenet_head)
|
||||
|
Loading…
Reference in New Issue
Block a user