mirror of
https://github.com/opencv/opencv.git
synced 2024-11-25 03:30:34 +08:00
Disable fusion to output layers
This commit is contained in:
parent
a0baae8a55
commit
f25a01bb5a
@ -2075,7 +2075,8 @@ Mat Net::forward(const String& outputName)
|
||||
if (layerName.empty())
|
||||
layerName = getLayerNames().back();
|
||||
|
||||
impl->setUpNet();
|
||||
std::vector<LayerPin> pins(1, impl->getPinByAlias(layerName));
|
||||
impl->setUpNet(pins);
|
||||
impl->forwardToLayer(impl->getLayerData(layerName));
|
||||
|
||||
return impl->getBlob(layerName);
|
||||
@ -2085,13 +2086,13 @@ void Net::forward(OutputArrayOfArrays outputBlobs, const String& outputName)
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
|
||||
impl->setUpNet();
|
||||
|
||||
String layerName = outputName;
|
||||
|
||||
if (layerName.empty())
|
||||
layerName = getLayerNames().back();
|
||||
|
||||
std::vector<LayerPin> pins(1, impl->getPinByAlias(layerName));
|
||||
impl->setUpNet(pins);
|
||||
impl->forwardToLayer(impl->getLayerData(layerName));
|
||||
|
||||
LayerPin pin = impl->getPinByAlias(layerName);
|
||||
|
@ -1240,4 +1240,36 @@ INSTANTIATE_TEST_CASE_P(/**/, Layer_Test_ShuffleChannel, Combine(
|
||||
/*group*/ Values(1, 2, 3, 6)
|
||||
));
|
||||
|
||||
// Check if relu is not fused to convolution if we requested it's output
|
||||
TEST(Layer_Test_Convolution, relu_fusion)
|
||||
{
|
||||
Net net;
|
||||
{
|
||||
LayerParams lp;
|
||||
lp.set("kernel_size", 1);
|
||||
lp.set("num_output", 1);
|
||||
lp.set("bias_term", false);
|
||||
lp.type = "Convolution";
|
||||
lp.name = "testConv";
|
||||
|
||||
int weightsShape[] = {1, 1, 1, 1};
|
||||
Mat weights(4, &weightsShape[0], CV_32F, Scalar(1));
|
||||
lp.blobs.push_back(weights);
|
||||
net.addLayerToPrev(lp.name, lp.type, lp);
|
||||
}
|
||||
{
|
||||
LayerParams lp;
|
||||
lp.type = "ReLU";
|
||||
lp.name = "testReLU";
|
||||
net.addLayerToPrev(lp.name, lp.type, lp);
|
||||
}
|
||||
int sz[] = {1, 1, 2, 3};
|
||||
Mat input(4, &sz[0], CV_32F);
|
||||
randu(input, -1.0, -0.1);
|
||||
net.setInput(input);
|
||||
net.setPreferableBackend(DNN_BACKEND_OPENCV);
|
||||
Mat output = net.forward("testConv");
|
||||
normAssert(input, output);
|
||||
}
|
||||
|
||||
}} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user