|
|
|
@ -8,7 +8,7 @@ |
|
|
|
|
namespace cv { |
|
|
|
|
namespace dnn { |
|
|
|
|
|
|
|
|
|
#if ENABLE_TORCH_IMPORTER || 1 |
|
|
|
|
#if defined(ENABLE_TORCH_IMPORTER) && ENABLE_TORCH_IMPORTER |
|
|
|
|
#include "THDiskFile.h" |
|
|
|
|
|
|
|
|
|
enum LuaType |
|
|
|
@ -575,32 +575,44 @@ struct TorchImporter : public ::cv::dnn::Importer |
|
|
|
|
} |
|
|
|
|
return prevLayerId; |
|
|
|
|
} |
|
|
|
|
else if (module->thName == "Parallel" || module->thName == "Concat") |
|
|
|
|
else if (module->thName == "Concat") |
|
|
|
|
{ |
|
|
|
|
int splitId, mergeId, newId; |
|
|
|
|
int newId, splitId, mergeId; |
|
|
|
|
LayerParams mergeParams, splitParams; |
|
|
|
|
mergeParams.set("axis", module->params.get<int>("dimension") - 1); |
|
|
|
|
|
|
|
|
|
splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams); |
|
|
|
|
mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams); |
|
|
|
|
net.connect(prevLayerId, prevOutNum, splitId, 0); |
|
|
|
|
|
|
|
|
|
String splitType; |
|
|
|
|
LayerParams splitParams, mergeParams; |
|
|
|
|
if (module->thName == "Parallel") |
|
|
|
|
for (int i = 0; i < (int)module->modules.size(); i++) |
|
|
|
|
{ |
|
|
|
|
splitType = "Slice"; |
|
|
|
|
splitParams.set("axis", module->params.get<int>("inputDimension") - 1); |
|
|
|
|
mergeParams.set("axis", module->params.get<int>("outputDimension") - 1); |
|
|
|
|
newId = fill(module->modules[i], splitId, i); |
|
|
|
|
net.connect(newId, 0, mergeId, i); |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
|
splitType = "Split"; |
|
|
|
|
mergeParams.set("axis", module->params.get<int>("dimension") - 1); |
|
|
|
|
|
|
|
|
|
return mergeId; |
|
|
|
|
} |
|
|
|
|
else if (module->thName == "Parallel") |
|
|
|
|
{ |
|
|
|
|
int newId, splitId, mergeId, reshapeId; |
|
|
|
|
|
|
|
|
|
splitId = net.addLayer(generateLayerName("torchSplit"), splitType, splitParams); |
|
|
|
|
LayerParams splitParams, mergeParams, reshapeParams; |
|
|
|
|
splitParams.set("axis", module->params.get<int>("inputDimension") - 1); |
|
|
|
|
mergeParams.set("axis", module->params.get<int>("outputDimension") - 1); |
|
|
|
|
reshapeParams.set("axis", splitParams.get<int>("axis")); |
|
|
|
|
reshapeParams.set("num_axes", 1); |
|
|
|
|
|
|
|
|
|
splitId = net.addLayer(generateLayerName("torchSplit"), "Slice", splitParams); |
|
|
|
|
mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams); |
|
|
|
|
reshapeId = net.addLayer(generateLayerName("torchReshape"), "Reshape", reshapeParams); |
|
|
|
|
net.connect(prevLayerId, prevOutNum, splitId, 0); |
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < module->modules.size(); i++) |
|
|
|
|
for (int i = 0; i < (int)module->modules.size(); i++) |
|
|
|
|
{ |
|
|
|
|
newId = fill(module->modules[i], splitId, (int)i); |
|
|
|
|
net.connect(newId, 0, mergeId, (int)i); |
|
|
|
|
net.connect(splitId, i, reshapeId, i); |
|
|
|
|
newId = fill(module->modules[i], reshapeId, i); |
|
|
|
|
net.connect(newId, 0, mergeId, i); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return mergeId; |
|
|
|
|