Merge pull request #10221 from dkurt:non_spatial_torch_layers

pull/10236/head
Vadim Pisarevsky 7 years ago
commit 2c4d3d92c7
  1. 5
      modules/dnn/src/layers/batch_norm_layer.cpp
  2. 20
      modules/dnn/src/torch/torch_importer.cpp
  3. 5
      modules/dnn/test/test_torch_importer.cpp

@ -119,8 +119,9 @@ public:
CV_Assert(inputs.size() == 1);
Mat &inpBlob = *inputs[0];
int rows = inpBlob.size[2];
int cols = inpBlob.size[3];
CV_Assert(inpBlob.dims == 2 || inpBlob.dims == 4);
int rows = inpBlob.dims > 2 ? inpBlob.size[2] : 1;
int cols = inpBlob.dims > 2 ? inpBlob.size[3] : 1;
for (size_t ii = 0; ii < outputs.size(); ii++)
{

@ -617,7 +617,8 @@ struct TorchImporter : public ::cv::dnn::Importer
curModule->modules.push_back(cv::Ptr<Module>(new Module(nnName, "Sigmoid")));
readObject();
}
else if (nnName == "SpatialBatchNormalization" || nnName == "InstanceNormalization")
else if (nnName == "SpatialBatchNormalization" || nnName == "InstanceNormalization" ||
nnName == "BatchNormalization")
{
newModule->apiType = "BatchNorm";
readTorchTable(scalarParams, tensorParams);
@ -700,17 +701,24 @@ struct TorchImporter : public ::cv::dnn::Importer
curModule->modules.push_back(newModule);
}
else if (nnName == "SpatialDropout")
else if (nnName == "SpatialDropout" || nnName == "Dropout")
{
readTorchTable(scalarParams, tensorParams);
CV_Assert(scalarParams.has("p"));
float scale = 1 - scalarParams.get<double>("p");
if (scalarParams.has("v2") && scalarParams.get<bool>("v2"))
{
newModule->apiType = "Identity";
}
else
{
float scale = 1 - scalarParams.get<double>("p");
CV_Assert(scale > 0);
CV_Assert(scale > 0);
newModule->apiType = "Power";
layerParams.set("scale", scale);
newModule->apiType = "Power";
layerParams.set("scale", scale);
}
curModule->modules.push_back(newModule);
}
// TotalVariation layer is from fast-neural-style project: https://github.com/jcjohnson/fast-neural-style

@ -234,6 +234,11 @@ TEST(Torch_Importer, net_padding)
runTorchNet("net_spatial_reflection_padding", DNN_TARGET_CPU, "", false, true);
}
TEST(Torch_Importer, net_non_spatial)
{
runTorchNet("net_non_spatial", DNN_TARGET_CPU, "", false, true);
}
TEST(Torch_Importer, ENet_accuracy)
{
Net net;

Loading…
Cancel
Save