From 78ff9d931ac1498a2b51854746f7c477c971f0f8 Mon Sep 17 00:00:00 2001 From: dkurt Date: Mon, 5 Jun 2017 10:54:17 +0300 Subject: [PATCH] Import SoftMax, LogSoftMax layers from Torch --- modules/dnn/include/opencv2/dnn/all_layers.hpp | 2 ++ modules/dnn/include/opencv2/dnn/dnn.hpp | 1 + modules/dnn/src/layers/softmax_layer.cpp | 9 +++++++++ modules/dnn/src/torch/torch_importer.cpp | 11 +++++++++++ modules/dnn/test/test_torch_importer.cpp | 12 ++++++++++++ 5 files changed, 35 insertions(+) diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp index 47784d614..3954107df 100644 --- a/modules/dnn/include/opencv2/dnn/all_layers.hpp +++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp @@ -251,6 +251,8 @@ namespace dnn class CV_EXPORTS SoftmaxLayer : public Layer { public: + bool logSoftMax; + static Ptr create(const LayerParams& params); }; diff --git a/modules/dnn/include/opencv2/dnn/dnn.hpp b/modules/dnn/include/opencv2/dnn/dnn.hpp index 83d9a3976..ce671a8b7 100644 --- a/modules/dnn/include/opencv2/dnn/dnn.hpp +++ b/modules/dnn/include/opencv2/dnn/dnn.hpp @@ -436,6 +436,7 @@ namespace dnn //! This namespace is used for dnn module functionlaity. * - nn.SpatialMaxPooling, nn.SpatialAveragePooling * - nn.ReLU, nn.TanH, nn.Sigmoid * - nn.Reshape + * - nn.SoftMax, nn.LogSoftMax * * Also some equivalents of these classes from cunn, cudnn, and fbcunn may be successfully imported. */ diff --git a/modules/dnn/src/layers/softmax_layer.cpp b/modules/dnn/src/layers/softmax_layer.cpp index a46894476..00dbc2672 100644 --- a/modules/dnn/src/layers/softmax_layer.cpp +++ b/modules/dnn/src/layers/softmax_layer.cpp @@ -57,6 +57,7 @@ public: SoftMaxLayerImpl(const LayerParams& params) { axisRaw = params.get("axis", 1); + logSoftMax = params.get("log_softmax", false); setParamsFrom(params); } @@ -143,6 +144,14 @@ public: for (size_t i = 0; i < innerSize; i++) dstPtr[srcOffset + cnDim * cnStep + i] /= bufPtr[bufOffset + i]; } + if (logSoftMax) + { + for (size_t cnDim = 0; cnDim < channels; cnDim++) + { + for (size_t i = 0; i < innerSize; i++) + dstPtr[srcOffset + cnDim * cnStep + i] = log(dstPtr[srcOffset + cnDim * cnStep + i]); + } + } } } diff --git a/modules/dnn/src/torch/torch_importer.cpp b/modules/dnn/src/torch/torch_importer.cpp index d1112de32..c1500e6cd 100644 --- a/modules/dnn/src/torch/torch_importer.cpp +++ b/modules/dnn/src/torch/torch_importer.cpp @@ -741,6 +741,17 @@ struct TorchImporter : public ::cv::dnn::Importer layerParams.set("indices_blob_id", tensorParams["indices"].first); curModule->modules.push_back(newModule); } + else if (nnName == "SoftMax") + { + newModule->apiType = "SoftMax"; + curModule->modules.push_back(newModule); + } + else if (nnName == "LogSoftMax") + { + newModule->apiType = "SoftMax"; + layerParams.set("log_softmax", true); + curModule->modules.push_back(newModule); + } else { CV_Error(Error::StsNotImplemented, "Unknown nn class \"" + className + "\""); diff --git a/modules/dnn/test/test_torch_importer.cpp b/modules/dnn/test/test_torch_importer.cpp index e9af98b15..8e40aac1c 100644 --- a/modules/dnn/test/test_torch_importer.cpp +++ b/modules/dnn/test/test_torch_importer.cpp @@ -159,6 +159,18 @@ TEST(Torch_Importer, net_cadd_table) runTorchNet("net_cadd_table"); } +TEST(Torch_Importer, net_softmax) +{ + runTorchNet("net_softmax"); + runTorchNet("net_softmax_spatial"); +} + +TEST(Torch_Importer, net_logsoftmax) +{ + runTorchNet("net_logsoftmax"); + runTorchNet("net_logsoftmax_spatial"); +} + TEST(Torch_Importer, ENet_accuracy) { Net net;