|
|
|
@ -934,20 +934,18 @@ struct TorchImporter : public ::cv::dnn::Importer |
|
|
|
|
} |
|
|
|
|
else if (module->thName == "Concat") |
|
|
|
|
{ |
|
|
|
|
int newId, splitId, mergeId; |
|
|
|
|
LayerParams mergeParams, splitParams; |
|
|
|
|
int newId, mergeId; |
|
|
|
|
LayerParams mergeParams; |
|
|
|
|
mergeParams.set("axis", module->params.get<int>("dimension") - 1); |
|
|
|
|
|
|
|
|
|
splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams); |
|
|
|
|
net.connect(prevLayerId, prevOutNum, splitId, 0); |
|
|
|
|
|
|
|
|
|
std::vector<int> branchIds; |
|
|
|
|
for (int i = 0; i < (int)module->modules.size(); i++) |
|
|
|
|
{ |
|
|
|
|
newId = fill(module->modules[i], addedModules, splitId, i); |
|
|
|
|
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum); |
|
|
|
|
branchIds.push_back(newId); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384.
|
|
|
|
|
mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams); |
|
|
|
|
|
|
|
|
|
for (int i = 0; i < branchIds.size(); i++) |
|
|
|
@ -1015,19 +1013,12 @@ struct TorchImporter : public ::cv::dnn::Importer |
|
|
|
|
return mergeId; |
|
|
|
|
} |
|
|
|
|
else if (module->thName == "ConcatTable") { |
|
|
|
|
int newId = -1, splitId; |
|
|
|
|
LayerParams splitParams; |
|
|
|
|
|
|
|
|
|
splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams); |
|
|
|
|
net.connect(prevLayerId, prevOutNum, splitId, 0); |
|
|
|
|
|
|
|
|
|
addedModules.push_back(std::make_pair(splitId, module)); |
|
|
|
|
|
|
|
|
|
int newId = -1; |
|
|
|
|
moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384.
|
|
|
|
|
for (int i = 0; i < (int)module->modules.size(); i++) |
|
|
|
|
{ |
|
|
|
|
newId = fill(module->modules[i], addedModules, splitId, i); |
|
|
|
|
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return newId; |
|
|
|
|
} |
|
|
|
|
else if (module->thName == "JoinTable") { |
|
|
|
|