Fix embedded Torch's nn.ConcatTable

This commit is contained in:
Dmitry Kurtaev 2018-03-31 11:11:10 +03:00
parent dbcb454917
commit 598039c0ed
2 changed files with 23 additions and 13 deletions

View File

@@ -101,6 +101,8 @@ struct TorchImporter
std::set<int> readedIndexes;
std::map<int, Mat> storages;
std::map<int, Mat> tensors;
// Stack with numbers of unconnected layers per scope (Sequential, ConcatTable etc.)
std::vector<int> numUnconnectedLayers;
struct Module
{
@@ -489,15 +491,7 @@ struct TorchImporter
layerParams.set("inputDimension", scalarParams.get<int>("inputDimension"));
layerParams.set("outputDimension", scalarParams.get<int>("outputDimension"));
}
if (nnName == "Concat")
{
layerParams.set("dimension", scalarParams.get<int>("dimension"));
}
if (nnName == "JoinTable")
{
layerParams.set("dimension", scalarParams.get<int>("dimension"));
}
if (nnName == "DepthConcat")
else if (nnName == "Concat" || nnName == "JoinTable" || nnName == "DepthConcat")
{
layerParams.set("dimension", scalarParams.get<int>("dimension"));
}
@@ -1096,6 +1090,7 @@ struct TorchImporter
{
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
}
numUnconnectedLayers.push_back(module->modules.size());
return newId;
}
else if (module->thName == "JoinTable") {
@@ -1108,9 +1103,14 @@ struct TorchImporter
mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams);
addedModules.push_back(std::make_pair(mergeId, module));
for (int i = 0; i < ids.size(); i++)
// Connect to the last number of unconnected layers.
CV_Assert(!numUnconnectedLayers.empty());
const int numInputs = numUnconnectedLayers.back();
numUnconnectedLayers.pop_back();
CV_Assert(numInputs <= ids.size());
for (int i = 0; i < numInputs; i++)
{
net.connect(ids[i], 0, mergeId, i);
net.connect(ids[ids.size() - numInputs + i], 0, mergeId, i);
}
return mergeId;
@@ -1124,9 +1124,14 @@ struct TorchImporter
int id = net.addLayer(name, "Eltwise", params);
for (int i = 0; i < ids.size(); i++)
// Connect to the last number of unconnected layers.
CV_Assert(!numUnconnectedLayers.empty());
const int numInputs = numUnconnectedLayers.back();
numUnconnectedLayers.pop_back();
CV_Assert(numInputs <= ids.size());
for (int i = 0; i < numInputs; i++)
{
net.connect(ids[i], 0, id, i);
net.connect(ids[ids.size() - numInputs + i], 0, id, i);
}
addedModules.push_back(std::make_pair(id, module));

View File

@@ -320,4 +320,9 @@ TEST(Torch_Importer, DISABLED_run_paralel)
runTorchNet("net_parallel", DNN_TARGET_OPENCL, "l5_torchMerge");
}
TEST(Torch_Importer, net_residual)
{
runTorchNet("net_residual", DNN_TARGET_CPU, "", false, true);
}
}