mirror of
https://github.com/opencv/opencv.git
synced 2025-08-06 06:26:29 +08:00
Merge pull request #11198 from dkurt:torch_embedded_residuals
This commit is contained in:
commit
099a16bd86
@ -101,6 +101,8 @@ struct TorchImporter
|
|||||||
std::set<int> readedIndexes;
|
std::set<int> readedIndexes;
|
||||||
std::map<int, Mat> storages;
|
std::map<int, Mat> storages;
|
||||||
std::map<int, Mat> tensors;
|
std::map<int, Mat> tensors;
|
||||||
|
// Stack with numbers of unconnected layers per scope (Sequential, ConcatTable etc.)
|
||||||
|
std::vector<int> numUnconnectedLayers;
|
||||||
|
|
||||||
struct Module
|
struct Module
|
||||||
{
|
{
|
||||||
@ -489,15 +491,7 @@ struct TorchImporter
|
|||||||
layerParams.set("inputDimension", scalarParams.get<int>("inputDimension"));
|
layerParams.set("inputDimension", scalarParams.get<int>("inputDimension"));
|
||||||
layerParams.set("outputDimension", scalarParams.get<int>("outputDimension"));
|
layerParams.set("outputDimension", scalarParams.get<int>("outputDimension"));
|
||||||
}
|
}
|
||||||
if (nnName == "Concat")
|
else if (nnName == "Concat" || nnName == "JoinTable" || nnName == "DepthConcat")
|
||||||
{
|
|
||||||
layerParams.set("dimension", scalarParams.get<int>("dimension"));
|
|
||||||
}
|
|
||||||
if (nnName == "JoinTable")
|
|
||||||
{
|
|
||||||
layerParams.set("dimension", scalarParams.get<int>("dimension"));
|
|
||||||
}
|
|
||||||
if (nnName == "DepthConcat")
|
|
||||||
{
|
{
|
||||||
layerParams.set("dimension", scalarParams.get<int>("dimension"));
|
layerParams.set("dimension", scalarParams.get<int>("dimension"));
|
||||||
}
|
}
|
||||||
@ -1096,6 +1090,7 @@ struct TorchImporter
|
|||||||
{
|
{
|
||||||
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
|
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
|
||||||
}
|
}
|
||||||
|
numUnconnectedLayers.push_back(module->modules.size());
|
||||||
return newId;
|
return newId;
|
||||||
}
|
}
|
||||||
else if (module->thName == "JoinTable") {
|
else if (module->thName == "JoinTable") {
|
||||||
@ -1108,9 +1103,14 @@ struct TorchImporter
|
|||||||
mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams);
|
mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams);
|
||||||
addedModules.push_back(std::make_pair(mergeId, module));
|
addedModules.push_back(std::make_pair(mergeId, module));
|
||||||
|
|
||||||
for (int i = 0; i < ids.size(); i++)
|
// Connect to the last number of unconnected layers.
|
||||||
|
CV_Assert(!numUnconnectedLayers.empty());
|
||||||
|
const int numInputs = numUnconnectedLayers.back();
|
||||||
|
numUnconnectedLayers.pop_back();
|
||||||
|
CV_Assert(numInputs <= ids.size());
|
||||||
|
for (int i = 0; i < numInputs; i++)
|
||||||
{
|
{
|
||||||
net.connect(ids[i], 0, mergeId, i);
|
net.connect(ids[ids.size() - numInputs + i], 0, mergeId, i);
|
||||||
}
|
}
|
||||||
|
|
||||||
return mergeId;
|
return mergeId;
|
||||||
@ -1124,9 +1124,14 @@ struct TorchImporter
|
|||||||
|
|
||||||
int id = net.addLayer(name, "Eltwise", params);
|
int id = net.addLayer(name, "Eltwise", params);
|
||||||
|
|
||||||
for (int i = 0; i < ids.size(); i++)
|
// Connect to the last number of unconnected layers.
|
||||||
|
CV_Assert(!numUnconnectedLayers.empty());
|
||||||
|
const int numInputs = numUnconnectedLayers.back();
|
||||||
|
numUnconnectedLayers.pop_back();
|
||||||
|
CV_Assert(numInputs <= ids.size());
|
||||||
|
for (int i = 0; i < numInputs; i++)
|
||||||
{
|
{
|
||||||
net.connect(ids[i], 0, id, i);
|
net.connect(ids[ids.size() - numInputs + i], 0, id, i);
|
||||||
}
|
}
|
||||||
|
|
||||||
addedModules.push_back(std::make_pair(id, module));
|
addedModules.push_back(std::make_pair(id, module));
|
||||||
|
@ -320,4 +320,9 @@ TEST(Torch_Importer, DISABLED_run_paralel)
|
|||||||
runTorchNet("net_parallel", DNN_TARGET_OPENCL, "l5_torchMerge");
|
runTorchNet("net_parallel", DNN_TARGET_OPENCL, "l5_torchMerge");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Torch_Importer, net_residual)
|
||||||
|
{
|
||||||
|
runTorchNet("net_residual", DNN_TARGET_CPU, "", false, true);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user