Fix embedded Torch's nn.ConcatTable

commit 598039c0ed (parent dbcb454917)
@@ -101,6 +101,8 @@ struct TorchImporter
     std::set<int> readedIndexes;
     std::map<int, Mat> storages;
     std::map<int, Mat> tensors;
+    // Stack with numbers of unconnected layers per scope (Sequential, ConcatTable etc.)
+    std::vector<int> numUnconnectedLayers;

     struct Module
     {
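The new field is a stack of counters: each container scope (Sequential, ConcatTable, ...) records how many of its child outputs still have to be wired into the merge layer that follows, and that merge layer consumes exactly that count in LIFO order. Below is a minimal standalone sketch of the bookkeeping under those assumptions; apart from `numUnconnectedLayers` the names are illustrative, not importer API.

#include <cassert>
#include <cstdio>
#include <vector>

// Toy stand-in for the importer's bookkeeping: one counter per open
// container scope (Sequential, ConcatTable, ...), popped in LIFO order.
static std::vector<int> numUnconnectedLayers;

// After the branches of a ConcatTable-like scope are built, remember
// how many outputs the following merge layer has to consume.
static void scopeFilled(int branchCount)
{
    numUnconnectedLayers.push_back(branchCount);
}

// The merge layer (Concat or Eltwise in the diff above) takes exactly
// the count pushed by its own scope, not everything built so far.
static int mergeInputs()
{
    assert(!numUnconnectedLayers.empty());
    const int numInputs = numUnconnectedLayers.back();
    numUnconnectedLayers.pop_back();
    return numInputs;
}

int main()
{
    scopeFilled(2);  // outer ConcatTable: 2 branches
    scopeFilled(3);  // embedded ConcatTable: 3 branches
    printf("inner merge takes %d inputs\n", mergeInputs());  // prints 3
    printf("outer merge takes %d inputs\n", mergeInputs());  // prints 2
    return 0;
}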
@@ -489,15 +491,7 @@ struct TorchImporter
             layerParams.set("inputDimension", scalarParams.get<int>("inputDimension"));
             layerParams.set("outputDimension", scalarParams.get<int>("outputDimension"));
         }
-        if (nnName == "Concat")
-        {
-            layerParams.set("dimension", scalarParams.get<int>("dimension"));
-        }
-        if (nnName == "JoinTable")
-        {
-            layerParams.set("dimension", scalarParams.get<int>("dimension"));
-        }
-        if (nnName == "DepthConcat")
+        else if (nnName == "Concat" || nnName == "JoinTable" || nnName == "DepthConcat")
         {
             layerParams.set("dimension", scalarParams.get<int>("dimension"));
         }
@@ -1096,6 +1090,7 @@ struct TorchImporter
             {
                 newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
             }
+            numUnconnectedLayers.push_back(module->modules.size());
             return newId;
         }
         else if (module->thName == "JoinTable") {
@@ -1108,9 +1103,14 @@ struct TorchImporter
             mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams);
             addedModules.push_back(std::make_pair(mergeId, module));

-            for (int i = 0; i < ids.size(); i++)
+            // Connect to the last number of unconnected layers.
+            CV_Assert(!numUnconnectedLayers.empty());
+            const int numInputs = numUnconnectedLayers.back();
+            numUnconnectedLayers.pop_back();
+            CV_Assert(numInputs <= ids.size());
+            for (int i = 0; i < numInputs; i++)
             {
-                net.connect(ids[i], 0, mergeId, i);
+                net.connect(ids[ids.size() - numInputs + i], 0, mergeId, i);
             }

             return mergeId;
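The revised loop indexes `ids` from the tail: for `i` in `[0, numInputs)`, `ids[ids.size() - numInputs + i]` walks the last `numInputs` entries in order, so earlier layers from enclosing scopes are left alone. A quick self-contained check of that arithmetic with hypothetical values (nothing here is OpenCV API):

#include <cstdio>
#include <vector>

int main()
{
    std::vector<int> ids = {10, 20, 30, 40, 50};  // all layer ids created so far
    const int numInputs = 2;                      // e.g. popped from numUnconnectedLayers
    for (int i = 0; i < numInputs; i++)
        printf("merge input %d <- ids[%zu] = %d\n",
               i, ids.size() - numInputs + i, ids[ids.size() - numInputs + i]);
    // prints: merge input 0 <- ids[3] = 40
    //         merge input 1 <- ids[4] = 50
    return 0;
}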
@@ -1124,9 +1124,14 @@ struct TorchImporter

             int id = net.addLayer(name, "Eltwise", params);

-            for (int i = 0; i < ids.size(); i++)
+            // Connect to the last number of unconnected layers.
+            CV_Assert(!numUnconnectedLayers.empty());
+            const int numInputs = numUnconnectedLayers.back();
+            numUnconnectedLayers.pop_back();
+            CV_Assert(numInputs <= ids.size());
+            for (int i = 0; i < numInputs; i++)
             {
-                net.connect(ids[i], 0, id, i);
+                net.connect(ids[ids.size() - numInputs + i], 0, id, i);
             }

             addedModules.push_back(std::make_pair(id, module));
@@ -320,4 +320,9 @@ TEST(Torch_Importer, DISABLED_run_paralel)
     runTorchNet("net_parallel", DNN_TARGET_OPENCL, "l5_torchMerge");
 }

+TEST(Torch_Importer, net_residual)
+{
+    runTorchNet("net_residual", DNN_TARGET_CPU, "", false, true);
+}
+
 }
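The new `net_residual` case exercises the kind of construct the commit targets: in Torch a residual block is typically an nn.ConcatTable (an identity branch plus a transform branch) followed by an nn.CAddTable that sums the two branch outputs, and when such a block is embedded in a larger Sequential the sum must take exactly its own two branch outputs. A schematic of that dataflow in plain C++, illustrative only and not the test's actual network:

#include <cstdio>
#include <vector>

// Element-wise sum of two equal-size blobs, the operation an Eltwise "sum"
// merge performs on the outputs of the two ConcatTable branches.
static std::vector<float> addTable(const std::vector<float>& a,
                                   const std::vector<float>& b)
{
    std::vector<float> out(a.size());
    for (size_t i = 0; i < a.size(); i++)
        out[i] = a[i] + b[i];
    return out;
}

int main()
{
    std::vector<float> x = {1.f, 2.f, 3.f};

    // ConcatTable: both branches receive the same input x.
    std::vector<float> identity = x;                     // branch 1: identity
    std::vector<float> transform = {0.5f, 0.5f, 0.5f};   // branch 2: stand-in for the transform branch's output

    // CAddTable: merge exactly the two branch outputs.
    std::vector<float> y = addTable(identity, transform);
    for (size_t i = 0; i < y.size(); i++)
        printf("%.1f ", y[i]);                           // 1.5 2.5 3.5
    printf("\n");
    return 0;
}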