Merge pull request #10221 from dkurt:non_spatial_torch_layers

pull/10236/head
Vadim Pisarevsky 7 years ago
commit 2c4d3d92c7
  1. 5
      modules/dnn/src/layers/batch_norm_layer.cpp
  2. 20
      modules/dnn/src/torch/torch_importer.cpp
  3. 5
      modules/dnn/test/test_torch_importer.cpp

@ -119,8 +119,9 @@ public:
CV_Assert(inputs.size() == 1); CV_Assert(inputs.size() == 1);
Mat &inpBlob = *inputs[0]; Mat &inpBlob = *inputs[0];
int rows = inpBlob.size[2]; CV_Assert(inpBlob.dims == 2 || inpBlob.dims == 4);
int cols = inpBlob.size[3]; int rows = inpBlob.dims > 2 ? inpBlob.size[2] : 1;
int cols = inpBlob.dims > 2 ? inpBlob.size[3] : 1;
for (size_t ii = 0; ii < outputs.size(); ii++) for (size_t ii = 0; ii < outputs.size(); ii++)
{ {

@ -617,7 +617,8 @@ struct TorchImporter : public ::cv::dnn::Importer
curModule->modules.push_back(cv::Ptr<Module>(new Module(nnName, "Sigmoid"))); curModule->modules.push_back(cv::Ptr<Module>(new Module(nnName, "Sigmoid")));
readObject(); readObject();
} }
else if (nnName == "SpatialBatchNormalization" || nnName == "InstanceNormalization") else if (nnName == "SpatialBatchNormalization" || nnName == "InstanceNormalization" ||
nnName == "BatchNormalization")
{ {
newModule->apiType = "BatchNorm"; newModule->apiType = "BatchNorm";
readTorchTable(scalarParams, tensorParams); readTorchTable(scalarParams, tensorParams);
@ -700,17 +701,24 @@ struct TorchImporter : public ::cv::dnn::Importer
curModule->modules.push_back(newModule); curModule->modules.push_back(newModule);
} }
else if (nnName == "SpatialDropout") else if (nnName == "SpatialDropout" || nnName == "Dropout")
{ {
readTorchTable(scalarParams, tensorParams); readTorchTable(scalarParams, tensorParams);
CV_Assert(scalarParams.has("p")); CV_Assert(scalarParams.has("p"));
float scale = 1 - scalarParams.get<double>("p"); if (scalarParams.has("v2") && scalarParams.get<bool>("v2"))
{
newModule->apiType = "Identity";
}
else
{
float scale = 1 - scalarParams.get<double>("p");
CV_Assert(scale > 0); CV_Assert(scale > 0);
newModule->apiType = "Power"; newModule->apiType = "Power";
layerParams.set("scale", scale); layerParams.set("scale", scale);
}
curModule->modules.push_back(newModule); curModule->modules.push_back(newModule);
} }
// TotalVariation layer is from fast-neural-style project: https://github.com/jcjohnson/fast-neural-style // TotalVariation layer is from fast-neural-style project: https://github.com/jcjohnson/fast-neural-style

@ -234,6 +234,11 @@ TEST(Torch_Importer, net_padding)
runTorchNet("net_spatial_reflection_padding", DNN_TARGET_CPU, "", false, true); runTorchNet("net_spatial_reflection_padding", DNN_TARGET_CPU, "", false, true);
} }
TEST(Torch_Importer, net_non_spatial)
{
runTorchNet("net_non_spatial", DNN_TARGET_CPU, "", false, true);
}
TEST(Torch_Importer, ENet_accuracy) TEST(Torch_Importer, ENet_accuracy)
{ {
Net net; Net net;

Loading…
Cancel
Save