diff --git a/modules/dnn/src/tflite/tflite_importer.cpp b/modules/dnn/src/tflite/tflite_importer.cpp index 8850cd9ad2..1c048ad9d0 100644 --- a/modules/dnn/src/tflite/tflite_importer.cpp +++ b/modules/dnn/src/tflite/tflite_importer.cpp @@ -70,6 +70,7 @@ private: void parseFullyConnected(const Operator& op, const std::string& opcode, LayerParams& layerParams); void parseSoftmax(const Operator& op, const std::string& opcode, LayerParams& layerParams); void parseCast(const Operator& op, const std::string& opcode, LayerParams& layerParams); + void parseTranspose(const Operator& op, const std::string& opcode, LayerParams& layerParams); void parseFusedActivation(const Operator& op, ActivationFunctionType activ); void parseActivation(const Operator& op, const std::string& opcode, LayerParams& layerParams, bool isFused); @@ -284,6 +285,7 @@ TFLiteImporter::DispatchMap TFLiteImporter::buildDispatchMap() dispatch["SOFTMAX"] = &TFLiteImporter::parseSoftmax; dispatch["CAST"] = &TFLiteImporter::parseCast; dispatch["TFLite_Detection_PostProcess"] = &TFLiteImporter::parseDetectionPostProcess; + dispatch["TRANSPOSE"] = &TFLiteImporter::parseTranspose; return dispatch; } @@ -719,6 +721,49 @@ void TFLiteImporter::parseResize(const Operator& op, const std::string& opcode, addLayer(layerParams, op); } +void TFLiteImporter::parseTranspose(const Operator& op, const std::string& opcode, LayerParams& layerParams) +{ + layerParams.type = "Permute"; + std::vector perm = allTensors[op.inputs()->Get(1)]; + + DataLayout inpLayout = layouts[op.inputs()->Get(0)]; + if (inpLayout == DNN_LAYOUT_NHWC && perm.size() == 4) { + + // OpenCV operates under the assumption that NCHW format, whereas TFLite defaults to NHWC. + // Therfore, to align these layouts, the axes of the permutation vector should be adjusted accordingly. + // For implementation details, please refer to the disscusion: + // https://github.com/opencv/opencv/pull/25297#issuecomment-2049762298 + + if (perm[0] != 0) { + CV_Error(Error::StsParseError, "The first axis should not be permuted."); + } + if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3) { + std::vector orderLP = {0, 1, 2, 3}; + layerParams.set("order", DictValue::arrayInt(orderLP.data(), orderLP.size())); + layouts[op.outputs()->Get(0)] = DNN_LAYOUT_NCHW; + } + else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 2) { + std::vector orderLP = {0, 3, 2, 1}; + layerParams.set("order", DictValue::arrayInt(orderLP.data(), orderLP.size())); + } + else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3) { + std::vector orderLP = {0, 1, 3, 2}; + layerParams.set("order", DictValue::arrayInt(orderLP.data(), orderLP.size())); + layouts[op.outputs()->Get(0)] = DNN_LAYOUT_NCHW; + } + else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 1) { + std::vector orderLP = {0, 2, 3, 1}; + layerParams.set("order", DictValue::arrayInt(orderLP.data(), orderLP.size())); + } + + } + else { + layerParams.set("order", DictValue::arrayInt(perm.data(), perm.size())); + } + + addLayer(layerParams, op); +} + int TFLiteImporter::addPermuteLayer(const std::vector& order, const std::string& permName, const std::pair& inpId, int dtype) { diff --git a/modules/dnn/test/test_tflite_importer.cpp b/modules/dnn/test/test_tflite_importer.cpp index 7ad62bf308..7621b44ff5 100644 --- a/modules/dnn/test/test_tflite_importer.cpp +++ b/modules/dnn/test/test_tflite_importer.cpp @@ -251,6 +251,15 @@ TEST_P(Test_TFLite, fully_connected) { testLayer("fully_connected"); } +TEST_P(Test_TFLite, permute) { + testLayer("permutation_3d"); + // Temporarily disabled as TFLiteConverter produces a incorrect graph in this case + //testLayer("permutation_4d_0123"); + testLayer("permutation_4d_0132"); + testLayer("permutation_4d_0213"); + testLayer("permutation_4d_0231"); +} + INSTANTIATE_TEST_CASE_P(/**/, Test_TFLite, dnnBackendsAndTargets()); }} // namespace