diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index d65f155a55..115738999a 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -3220,19 +3220,44 @@ void ONNXImporter::parseSimpleLayers(LayerParams& layerParams, const opencv_onnx addLayer(layerParams, node_proto); } - void ONNXImporter::parseEinsum(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) { std::vector einsumInpShapes; for (int j = 0; j < node_proto.input_size(); j++) { - const auto& inputLayerName = node_proto.input(j); - auto it = outShapes.find(inputLayerName); - if (it != outShapes.end()) - { - einsumInpShapes.emplace_back(it->second); + // create Const layer for constants and mark its shape + std::vector input_shape; + if (layer_id.find(node_proto.input(j)) == layer_id.end()) { + Mat blob = getBlob(node_proto, j); + + LayerParams const_params; + const_params.name = node_proto.input(j); + const_params.type = "Const"; + const_params.blobs.push_back(blob); + + opencv_onnx::NodeProto proto; + proto.add_output(const_params.name); + addLayer(const_params, proto); + + input_shape.resize(blob.dims); + for (size_t i = 0; i < input_shape.size(); i++) { + input_shape[i] = blob.size[i]; + } + } + + // also try getting shape from inferred shapes + if (input_shape.empty()) { + const auto& inputLayerName = node_proto.input(j); + auto it = outShapes.find(inputLayerName); + if (it != outShapes.end()) { + input_shape = it->second; + } + } + + if (input_shape.empty()) { + CV_Error(Error::StsAssert, format("ERROR input shape of %s not found", node_proto.input(j).c_str())); } else { - CV_Error(Error::StsAssert, "ERROR input shape not found"); + einsumInpShapes.emplace_back(input_shape); } } diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index ad6efbe77a..080b07243c 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -1514,6 +1514,10 @@ TEST_P(Test_ONNX_layers, Einsum_transpose) testONNXModels("einsum_transpose", npy, 0, 0, false, false, 1); } +TEST_P(Test_ONNX_layers, Einsum_const_inputs) { + testONNXModels("einsum_const_inputs", npy, 0, 0, false, false, 1); +} + TEST_P(Test_ONNX_layers, Pad2d_Unfused) { testONNXModels("ReflectionPad2d");