diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 027326c69e..651d1b1571 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -2435,12 +2435,18 @@ void ONNXImporter::parseExpand(LayerParams& layerParams, const opencv_onnx::Node } else { - inpShape = shape(getBlob(input0)); + Mat blob = getBlob(input0); + if (constBlobsExtraInfo.find(node_proto.input(0)) != constBlobsExtraInfo.end() && + getBlobExtraInfo(node_proto, 0).real_ndims == 1) { + inpShape = {(int)blob.total()}; + } else { + inpShape = shape(blob); + } } String srcName = input0; // Unsqueeze and repeat along new axis - if (targetShape.size() == inpShape.size() + 1) + if (targetShape.size() > inpShape.size()) { inpShape.insert(inpShape.begin(), targetShape.size() - inpShape.size(), 1); for (int i = 0; i < targetShape.size(); i++) @@ -2486,7 +2492,7 @@ void ONNXImporter::parseExpand(LayerParams& layerParams, const opencv_onnx::Node { if (broadcast_axes.empty()) { - addConstant(output_name, getBlob(node_proto, 0)); + addConstant(output_name, getBlob(node_proto, 0).reshape(1, targetShape)); return; } @@ -2719,7 +2725,8 @@ void ONNXImporter::parseGather(LayerParams& layerParams, const opencv_onnx::Node runLayer(layerParams, inputs, output); output.back().convertTo(output.back(), type); - output.back().dims = std::max(input_real_ndims - real_ndims, 1); + if (real_ndims < 2) // In case of scalars or 1D vectors, OpenCV initializes 2D cv::Mat + output.back().dims = std::max(input_real_ndims - real_ndims, 1); addConstant(node_proto.output(0), output.back()); return; } diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index b5a97770b1..e566acd827 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -2487,6 +2487,11 @@ TEST_P(Test_ONNX_layers, Gelu) testONNXModels("gelu_approximation"); } +TEST_P(Test_ONNX_layers, OpenAI_CLIP_head) +{ + testONNXModels("clip-vit-base-head"); +} + INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets()); }} // namespace