Merge pull request #9384 from dkurt:torch_split

pull/9591/head
Vadim Pisarevsky 7 years ago
commit 6bf8fe815d
  1. 5
      modules/dnn/src/layers/split_layer.cpp
  2. 23
      modules/dnn/src/torch/torch_importer.cpp

@ -75,7 +75,7 @@ public:
Layer::getMemoryShapes(inputs, max(1, outputsCount >= 0 ? outputsCount : requiredOutputs),
outputs, internals);
return true;
return false;
}
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
@ -86,8 +86,7 @@ public:
for (size_t i = 0; i < outputs.size(); i++)
{
CV_Assert(inputs[0]->total() == outputs[i].total());
if (outputs[i].data != inputs[0]->data)
inputs[0]->copyTo(outputs[i]);
inputs[0]->copyTo(outputs[i]);
}
}
};

@ -934,20 +934,18 @@ struct TorchImporter : public ::cv::dnn::Importer
}
else if (module->thName == "Concat")
{
int newId, splitId, mergeId;
LayerParams mergeParams, splitParams;
int newId, mergeId;
LayerParams mergeParams;
mergeParams.set("axis", module->params.get<int>("dimension") - 1);
splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams);
net.connect(prevLayerId, prevOutNum, splitId, 0);
std::vector<int> branchIds;
for (int i = 0; i < (int)module->modules.size(); i++)
{
newId = fill(module->modules[i], addedModules, splitId, i);
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
branchIds.push_back(newId);
}
moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384.
mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams);
for (int i = 0; i < branchIds.size(); i++)
@ -1015,19 +1013,12 @@ struct TorchImporter : public ::cv::dnn::Importer
return mergeId;
}
else if (module->thName == "ConcatTable") {
int newId = -1, splitId;
LayerParams splitParams;
splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams);
net.connect(prevLayerId, prevOutNum, splitId, 0);
addedModules.push_back(std::make_pair(splitId, module));
int newId = -1;
moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384.
for (int i = 0; i < (int)module->modules.size(); i++)
{
newId = fill(module->modules[i], addedModules, splitId, i);
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
}
return newId;
}
else if (module->thName == "JoinTable") {

Loading…
Cancel
Save