mirror of
https://github.com/opencv/opencv.git
synced 2025-06-11 20:09:23 +08:00
Torch's Concat and ConcatTable doesn't use Split layer
This commit is contained in:
parent
8ffa29473f
commit
0ce7c33bc8
@ -75,7 +75,7 @@ public:
|
||||
|
||||
Layer::getMemoryShapes(inputs, max(1, outputsCount >= 0 ? outputsCount : requiredOutputs),
|
||||
outputs, internals);
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
|
||||
@ -86,8 +86,7 @@ public:
|
||||
for (size_t i = 0; i < outputs.size(); i++)
|
||||
{
|
||||
CV_Assert(inputs[0]->total() == outputs[i].total());
|
||||
if (outputs[i].data != inputs[0]->data)
|
||||
inputs[0]->copyTo(outputs[i]);
|
||||
inputs[0]->copyTo(outputs[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -827,20 +827,18 @@ struct TorchImporter : public ::cv::dnn::Importer
|
||||
}
|
||||
else if (module->thName == "Concat")
|
||||
{
|
||||
int newId, splitId, mergeId;
|
||||
LayerParams mergeParams, splitParams;
|
||||
int newId, mergeId;
|
||||
LayerParams mergeParams;
|
||||
mergeParams.set("axis", module->params.get<int>("dimension") - 1);
|
||||
|
||||
splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams);
|
||||
net.connect(prevLayerId, prevOutNum, splitId, 0);
|
||||
|
||||
std::vector<int> branchIds;
|
||||
for (int i = 0; i < (int)module->modules.size(); i++)
|
||||
{
|
||||
newId = fill(module->modules[i], addedModules, splitId, i);
|
||||
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
|
||||
branchIds.push_back(newId);
|
||||
}
|
||||
|
||||
moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384.
|
||||
mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams);
|
||||
|
||||
for (int i = 0; i < branchIds.size(); i++)
|
||||
@ -884,19 +882,12 @@ struct TorchImporter : public ::cv::dnn::Importer
|
||||
return mergeId;
|
||||
}
|
||||
else if (module->thName == "ConcatTable") {
|
||||
int newId = -1, splitId;
|
||||
LayerParams splitParams;
|
||||
|
||||
splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams);
|
||||
net.connect(prevLayerId, prevOutNum, splitId, 0);
|
||||
|
||||
addedModules.push_back(std::make_pair(splitId, module));
|
||||
|
||||
int newId = -1;
|
||||
moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384.
|
||||
for (int i = 0; i < (int)module->modules.size(); i++)
|
||||
{
|
||||
newId = fill(module->modules[i], addedModules, splitId, i);
|
||||
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
|
||||
}
|
||||
|
||||
return newId;
|
||||
}
|
||||
else if (module->thName == "JoinTable") {
|
||||
|
Loading…
Reference in New Issue
Block a user