diff --git a/modules/dnn/src/layers/reshape_layer.cpp b/modules/dnn/src/layers/reshape_layer.cpp index a98e4e962..ba5d9c8ce 100644 --- a/modules/dnn/src/layers/reshape_layer.cpp +++ b/modules/dnn/src/layers/reshape_layer.cpp @@ -55,6 +55,7 @@ static void computeShapeByReshapeMask(const MatShape &srcShape, { int srcShapeSize = (int)srcShape.size(); int maskShapeSize = (int)maskShape.size(); + int maskTotal = abs(total(maskShape)); // Mask might have negative ones. if (srcRange == Range::all()) srcRange = Range(0, srcShapeSize); @@ -65,6 +66,19 @@ static void computeShapeByReshapeMask(const MatShape &srcShape, srcRange.end = srcRange.end == INT_MAX ? srcShapeSize : srcRange.start + sz; } + if (maskTotal != 0) + { + for (int i = srcRange.start + 1; i < srcRange.end; ++i) + { + if (total(srcShape, i, srcRange.end) != maskTotal) + { + srcRange.start = i - 1; + break; + } + } + CV_Assert(total(srcShape, srcRange.start, srcRange.end) == maskTotal); + } + CV_Assert(0 <= srcRange.start && srcRange.start <= srcRange.end && srcRange.end <= srcShapeSize); int dstShapeSize = srcShapeSize - srcRange.size() + maskShapeSize; dstShape.resize(dstShapeSize); diff --git a/modules/dnn/test/test_torch_importer.cpp b/modules/dnn/test/test_torch_importer.cpp index e9af98b15..450e8b0d0 100644 --- a/modules/dnn/test/test_torch_importer.cpp +++ b/modules/dnn/test/test_torch_importer.cpp @@ -122,6 +122,7 @@ TEST(Torch_Importer, run_reshape) { runTorchNet("net_reshape"); runTorchNet("net_reshape_batch"); + runTorchNet("net_reshape_single_sample"); } TEST(Torch_Importer, run_linear)