mirror of
https://github.com/opencv/opencv.git
synced 2025-06-12 20:42:53 +08:00
Merge pull request #9384 from dkurt:torch_split
This commit is contained in:
commit
6bf8fe815d
@ -75,7 +75,7 @@ public:
|
|||||||
|
|
||||||
Layer::getMemoryShapes(inputs, max(1, outputsCount >= 0 ? outputsCount : requiredOutputs),
|
Layer::getMemoryShapes(inputs, max(1, outputsCount >= 0 ? outputsCount : requiredOutputs),
|
||||||
outputs, internals);
|
outputs, internals);
|
||||||
return true;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
|
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
|
||||||
@ -86,8 +86,7 @@ public:
|
|||||||
for (size_t i = 0; i < outputs.size(); i++)
|
for (size_t i = 0; i < outputs.size(); i++)
|
||||||
{
|
{
|
||||||
CV_Assert(inputs[0]->total() == outputs[i].total());
|
CV_Assert(inputs[0]->total() == outputs[i].total());
|
||||||
if (outputs[i].data != inputs[0]->data)
|
inputs[0]->copyTo(outputs[i]);
|
||||||
inputs[0]->copyTo(outputs[i]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -934,20 +934,18 @@ struct TorchImporter : public ::cv::dnn::Importer
|
|||||||
}
|
}
|
||||||
else if (module->thName == "Concat")
|
else if (module->thName == "Concat")
|
||||||
{
|
{
|
||||||
int newId, splitId, mergeId;
|
int newId, mergeId;
|
||||||
LayerParams mergeParams, splitParams;
|
LayerParams mergeParams;
|
||||||
mergeParams.set("axis", module->params.get<int>("dimension") - 1);
|
mergeParams.set("axis", module->params.get<int>("dimension") - 1);
|
||||||
|
|
||||||
splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams);
|
|
||||||
net.connect(prevLayerId, prevOutNum, splitId, 0);
|
|
||||||
|
|
||||||
std::vector<int> branchIds;
|
std::vector<int> branchIds;
|
||||||
for (int i = 0; i < (int)module->modules.size(); i++)
|
for (int i = 0; i < (int)module->modules.size(); i++)
|
||||||
{
|
{
|
||||||
newId = fill(module->modules[i], addedModules, splitId, i);
|
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
|
||||||
branchIds.push_back(newId);
|
branchIds.push_back(newId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384.
|
||||||
mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams);
|
mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams);
|
||||||
|
|
||||||
for (int i = 0; i < branchIds.size(); i++)
|
for (int i = 0; i < branchIds.size(); i++)
|
||||||
@ -1015,19 +1013,12 @@ struct TorchImporter : public ::cv::dnn::Importer
|
|||||||
return mergeId;
|
return mergeId;
|
||||||
}
|
}
|
||||||
else if (module->thName == "ConcatTable") {
|
else if (module->thName == "ConcatTable") {
|
||||||
int newId = -1, splitId;
|
int newId = -1;
|
||||||
LayerParams splitParams;
|
moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384.
|
||||||
|
|
||||||
splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams);
|
|
||||||
net.connect(prevLayerId, prevOutNum, splitId, 0);
|
|
||||||
|
|
||||||
addedModules.push_back(std::make_pair(splitId, module));
|
|
||||||
|
|
||||||
for (int i = 0; i < (int)module->modules.size(); i++)
|
for (int i = 0; i < (int)module->modules.size(); i++)
|
||||||
{
|
{
|
||||||
newId = fill(module->modules[i], addedModules, splitId, i);
|
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
|
||||||
}
|
}
|
||||||
|
|
||||||
return newId;
|
return newId;
|
||||||
}
|
}
|
||||||
else if (module->thName == "JoinTable") {
|
else if (module->thName == "JoinTable") {
|
||||||
|
Loading…
Reference in New Issue
Block a user