diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 91110cb15e..b7d289d202 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -786,37 +786,42 @@ void ONNXImporter::populateNet(Net dstNet) } replaceLayerParam(layerParams, "mode", "interpolation"); } + else if (layer_type == "LogSoftmax") + { + layerParams.type = "Softmax"; + layerParams.set("log_softmax", true); + } else { for (int j = 0; j < node_proto.input_size(); j++) { if (layer_id.find(node_proto.input(j)) == layer_id.end()) layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j)); } - } - - int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams); - layer_id.insert(std::make_pair(layerParams.name, LayerInfo(id, 0))); - - - std::vector layerInpShapes, layerOutShapes, layerInternalShapes; - for (int j = 0; j < node_proto.input_size(); j++) { - layerId = layer_id.find(node_proto.input(j)); - if (layerId != layer_id.end()) { - dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, j); - // Collect input shapes. - shapeIt = outShapes.find(node_proto.input(j)); - CV_Assert(shapeIt != outShapes.end()); - layerInpShapes.push_back(shapeIt->second); - } - } - - // Compute shape of output blob for this layer. - Ptr layer = dstNet.getLayer(id); - layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes); - CV_Assert(!layerOutShapes.empty()); - outShapes[layerParams.name] = layerOutShapes[0]; - } - } + } + + int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams); + layer_id.insert(std::make_pair(layerParams.name, LayerInfo(id, 0))); + + + std::vector layerInpShapes, layerOutShapes, layerInternalShapes; + for (int j = 0; j < node_proto.input_size(); j++) { + layerId = layer_id.find(node_proto.input(j)); + if (layerId != layer_id.end()) { + dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, j); + // Collect input shapes. + shapeIt = outShapes.find(node_proto.input(j)); + CV_Assert(shapeIt != outShapes.end()); + layerInpShapes.push_back(shapeIt->second); + } + } + + // Compute shape of output blob for this layer. + Ptr layer = dstNet.getLayer(id); + layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes); + CV_Assert(!layerOutShapes.empty()); + outShapes[layerParams.name] = layerOutShapes[0]; + } +} Net readNetFromONNX(const String& onnxFile) { diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index bf9f25d214..eb306283bb 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -245,6 +245,12 @@ TEST_P(Test_ONNX_layers, Reshape) testONNXModels("unsqueeze"); } +TEST_P(Test_ONNX_layers, Softmax) +{ + testONNXModels("softmax"); + testONNXModels("log_softmax"); +} + INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets()); class Test_ONNX_nets : public Test_ONNX_layers {};