mirror of
https://github.com/opencv/opencv.git
synced 2025-06-08 18:13:13 +08:00
Merge pull request #11835 from dkurt:dnn_tf_two_inputs
This commit is contained in:
commit
e87425f047
@ -376,6 +376,8 @@ private:
|
|||||||
// and may be used to build the network using binary format only as a weights storage.
|
// and may be used to build the network using binary format only as a weights storage.
|
||||||
// This approach is similar to Caffe's `.prorotxt` and `.caffemodel`.
|
// This approach is similar to Caffe's `.prorotxt` and `.caffemodel`.
|
||||||
tensorflow::GraphDef netTxt;
|
tensorflow::GraphDef netTxt;
|
||||||
|
|
||||||
|
std::vector<String> netInputsNames;
|
||||||
};
|
};
|
||||||
|
|
||||||
TFImporter::TFImporter(const char *model, const char *config)
|
TFImporter::TFImporter(const char *model, const char *config)
|
||||||
@ -443,7 +445,14 @@ void TFImporter::connect(const std::map<String, int>& layers_name_id_map, Net& n
|
|||||||
std::map<String, int>::const_iterator it = layers_name_id_map.find(outPin.name);
|
std::map<String, int>::const_iterator it = layers_name_id_map.find(outPin.name);
|
||||||
if (it == layers_name_id_map.end())
|
if (it == layers_name_id_map.end())
|
||||||
CV_Error(Error::StsError, "Input layer not found: " + outPin.name);
|
CV_Error(Error::StsError, "Input layer not found: " + outPin.name);
|
||||||
network.connect(it->second, outPin.blobIndex, input_layer_id, input_blob_id);
|
|
||||||
|
std::vector<String>::iterator inpNameIt = std::find(netInputsNames.begin(), netInputsNames.end(), outPin.name);
|
||||||
|
int blobIndex;
|
||||||
|
if (inpNameIt == netInputsNames.end())
|
||||||
|
blobIndex = outPin.blobIndex;
|
||||||
|
else
|
||||||
|
blobIndex = inpNameIt - netInputsNames.begin();
|
||||||
|
network.connect(it->second, blobIndex, input_layer_id, input_blob_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFImporter::connectToAllBlobs(const std::map<String, int>& layer_id, Net& network, const Pin& outPin,
|
void TFImporter::connectToAllBlobs(const std::map<String, int>& layer_id, Net& network, const Pin& outPin,
|
||||||
@ -845,7 +854,7 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
Pin inp = parsePin(layer.input(ii));
|
Pin inp = parsePin(layer.input(ii));
|
||||||
if (layer_id.find(inp.name) == layer_id.end())
|
if (layer_id.find(inp.name) == layer_id.end())
|
||||||
CV_Error(Error::StsError, "Input layer not found: " + inp.name);
|
CV_Error(Error::StsError, "Input layer not found: " + inp.name);
|
||||||
dstNet.connect(layer_id.at(inp.name), inp.blobIndex, id, ii);
|
connect(layer_id, dstNet, inp, id, ii);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1099,7 +1108,7 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
Pin inp = parsePin(layer.input(ii));
|
Pin inp = parsePin(layer.input(ii));
|
||||||
if (layer_id.find(inp.name) == layer_id.end())
|
if (layer_id.find(inp.name) == layer_id.end())
|
||||||
CV_Error(Error::StsError, "Input layer not found: " + inp.name);
|
CV_Error(Error::StsError, "Input layer not found: " + inp.name);
|
||||||
dstNet.connect(layer_id.at(inp.name), inp.blobIndex, id, ii - from);
|
connect(layer_id, dstNet, inp, id, ii - from);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if (type == "MaxPool")
|
else if (type == "MaxPool")
|
||||||
@ -1131,10 +1140,12 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
}
|
}
|
||||||
else if (type == "Placeholder")
|
else if (type == "Placeholder")
|
||||||
{
|
{
|
||||||
std::vector<String> netInputs(1);
|
if (!hasLayerAttr(layer, "dtype") ||
|
||||||
netInputs[0] = name;
|
getLayerAttr(layer, "dtype").type() != tensorflow::DT_BOOL) // If input is not a train/test flag.
|
||||||
|
{
|
||||||
|
netInputsNames.push_back(name);
|
||||||
layer_id[name] = 0;
|
layer_id[name] = 0;
|
||||||
dstNet.setInputsNames(netInputs);
|
}
|
||||||
}
|
}
|
||||||
else if (type == "Split") {
|
else if (type == "Split") {
|
||||||
// TODO: determining axis index remapping by input dimensions order of input blob
|
// TODO: determining axis index remapping by input dimensions order of input blob
|
||||||
@ -1272,7 +1283,7 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
Pin inp = parsePin(layer.input(ii));
|
Pin inp = parsePin(layer.input(ii));
|
||||||
if (layer_id.find(inp.name) == layer_id.end())
|
if (layer_id.find(inp.name) == layer_id.end())
|
||||||
CV_Error(Error::StsError, "Input layer not found: " + inp.name);
|
CV_Error(Error::StsError, "Input layer not found: " + inp.name);
|
||||||
dstNet.connect(layer_id.at(inp.name), inp.blobIndex, id, ii);
|
connect(layer_id, dstNet, inp, id, ii);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1790,6 +1801,7 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
dstNet.setInputsNames(netInputsNames);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -441,4 +441,20 @@ TEST(Test_TensorFlow, resize_bilinear)
|
|||||||
runTensorFlowNet("resize_bilinear_factor");
|
runTensorFlowNet("resize_bilinear_factor");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Test_TensorFlow, two_inputs)
|
||||||
|
{
|
||||||
|
Net net = readNet(path("two_inputs_net.pbtxt"));
|
||||||
|
net.setPreferableBackend(DNN_BACKEND_OPENCV);
|
||||||
|
|
||||||
|
Mat firstInput(2, 3, CV_32FC1), secondInput(2, 3, CV_32FC1);
|
||||||
|
randu(firstInput, -1, 1);
|
||||||
|
randu(secondInput, -1, 1);
|
||||||
|
|
||||||
|
net.setInput(firstInput, "first_input");
|
||||||
|
net.setInput(secondInput, "second_input");
|
||||||
|
Mat out = net.forward();
|
||||||
|
|
||||||
|
normAssert(out, firstInput + secondInput);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user