Import SoftMax, LogSoftMax layers from Torch

pull/1212/head
dkurt 8 years ago
parent 9638b1454a
commit 78ff9d931a
  1. 2
      modules/dnn/include/opencv2/dnn/all_layers.hpp
  2. 1
      modules/dnn/include/opencv2/dnn/dnn.hpp
  3. 9
      modules/dnn/src/layers/softmax_layer.cpp
  4. 11
      modules/dnn/src/torch/torch_importer.cpp
  5. 12
      modules/dnn/test/test_torch_importer.cpp

@ -251,6 +251,8 @@ namespace dnn
class CV_EXPORTS SoftmaxLayer : public Layer
{
public:
bool logSoftMax;
static Ptr<SoftmaxLayer> create(const LayerParams& params);
};

@ -436,6 +436,7 @@ namespace dnn //! This namespace is used for dnn module functionlaity.
* - nn.SpatialMaxPooling, nn.SpatialAveragePooling
* - nn.ReLU, nn.TanH, nn.Sigmoid
* - nn.Reshape
* - nn.SoftMax, nn.LogSoftMax
*
* Also some equivalents of these classes from cunn, cudnn, and fbcunn may be successfully imported.
*/

@ -57,6 +57,7 @@ public:
SoftMaxLayerImpl(const LayerParams& params)
{
axisRaw = params.get<int>("axis", 1);
logSoftMax = params.get<int>("log_softmax", false);
setParamsFrom(params);
}
@ -143,6 +144,14 @@ public:
for (size_t i = 0; i < innerSize; i++)
dstPtr[srcOffset + cnDim * cnStep + i] /= bufPtr[bufOffset + i];
}
if (logSoftMax)
{
for (size_t cnDim = 0; cnDim < channels; cnDim++)
{
for (size_t i = 0; i < innerSize; i++)
dstPtr[srcOffset + cnDim * cnStep + i] = log(dstPtr[srcOffset + cnDim * cnStep + i]);
}
}
}
}

@ -741,6 +741,17 @@ struct TorchImporter : public ::cv::dnn::Importer
layerParams.set("indices_blob_id", tensorParams["indices"].first);
curModule->modules.push_back(newModule);
}
else if (nnName == "SoftMax")
{
newModule->apiType = "SoftMax";
curModule->modules.push_back(newModule);
}
else if (nnName == "LogSoftMax")
{
newModule->apiType = "SoftMax";
layerParams.set("log_softmax", true);
curModule->modules.push_back(newModule);
}
else
{
CV_Error(Error::StsNotImplemented, "Unknown nn class \"" + className + "\"");

@ -159,6 +159,18 @@ TEST(Torch_Importer, net_cadd_table)
runTorchNet("net_cadd_table");
}
TEST(Torch_Importer, net_softmax)
{
runTorchNet("net_softmax");
runTorchNet("net_softmax_spatial");
}
TEST(Torch_Importer, net_logsoftmax)
{
runTorchNet("net_logsoftmax");
runTorchNet("net_logsoftmax_spatial");
}
TEST(Torch_Importer, ENet_accuracy)
{
Net net;

Loading…
Cancel
Save