diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp index 868a8f06d6..75cba09981 100644 --- a/modules/dnn/include/opencv2/dnn/all_layers.hpp +++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp @@ -366,6 +366,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN */ std::vector > sliceRanges; int axis; + int num_split; static Ptr create(const LayerParams ¶ms); }; diff --git a/modules/dnn/src/layers/slice_layer.cpp b/modules/dnn/src/layers/slice_layer.cpp index 73d6a301ae..7640d4637e 100644 --- a/modules/dnn/src/layers/slice_layer.cpp +++ b/modules/dnn/src/layers/slice_layer.cpp @@ -61,6 +61,7 @@ public: { setParamsFrom(params); axis = params.get("axis", 1); + num_split = params.get("num_split", 0); if (params.has("slice_point")) { CV_Assert(!params.has("begin") && !params.has("size") && !params.has("end")); @@ -141,9 +142,10 @@ public: else // Divide input blob on equal parts by axis. { CV_Assert(0 <= axis && axis < inpShape.size()); - CV_Assert(requiredOutputs > 0 && inpShape[axis] % requiredOutputs == 0); - inpShape[axis] /= requiredOutputs; - outputs.resize(requiredOutputs, inpShape); + int splits = num_split ? num_split : requiredOutputs; + CV_Assert(splits > 0 && inpShape[axis] % splits == 0); + inpShape[axis] /= splits; + outputs.resize(splits, inpShape); } return false; } diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index c38b250c67..e546d9e1da 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -1410,6 +1410,9 @@ void TFImporter::populateNet(Net dstNet) axis = toNCHW(axis); layerParams.set("axis", axis); + if (hasLayerAttr(layer, "num_split")) + layerParams.set("num_split", getLayerAttr(layer, "num_split").i()); + int id = dstNet.addLayer(name, "Slice", layerParams); layer_id[name] = id; diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index 2dae678403..0357b8ecc5 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -350,6 +350,11 @@ TEST_P(Test_TensorFlow_layers, l2_normalize_3d) runTensorFlowNet("l2_normalize_3d"); } +TEST_P(Test_TensorFlow_layers, Split) +{ + runTensorFlowNet("split"); +} + class Test_TensorFlow_nets : public DNNTestLayer {}; TEST_P(Test_TensorFlow_nets, MobileNet_SSD)