mirror of https://github.com/opencv/opencv.git
Merge pull request #25297 from CNOCycle:tflite/transpose
Support Transpose op in TFlite #25297 **Merge with extra**: https://github.com/opencv/opencv_extra/pull/1168 The purpose of this PR is to introduce support for the Transpose op in TFlite format and to add a shape comparison between the output tensors and the references. In some occasional cases, the shape of the output tensor is `[1,4,1,1]`, while the shape of the reference tensor is `[1,4]`. Consequently, the norm check incorrectly reports that the test has passed, as the residual is zero. Below is a Python script for generating testing data. The generated data can be integrated into the repo `opencv_extra`. ```python import numpy as np import tensorflow as tf PREFIX_TFL = '/path/to/opencv_extra/testdata/dnn/tflite/' def generator(input_tensor, model, saved_name): # convert keras model to .tflite format converter = tf.lite.TFLiteConverter.from_keras_model(model) #converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.optimizations = [None] tflite_model = converter.convert() with open(f'{PREFIX_TFL}/{saved_name}.tflite', 'wb') as f: f.write(tflite_model) # save the input tensor to .npy if input_tensor.ndim == 4: opencv_tensor = np.transpose(input_tensor, (0,3,1,2)) else: opencv_tensor = input_tensor opencv_tensor = np.copy(opencv_tensor, order='C').astype(np.float32) np.save(f'{PREFIX_TFL}/{saved_name}_inp.npy', opencv_tensor) # generate output tenosr and save it to .npy mat_out = model(input_tensor).numpy() mat_out = np.copy(mat_out, order='C').astype(np.float32) if mat_out.ndim == 4: mat_out = np.transpose(mat_out, (0,3,1,2)) interpreter = tf.lite.Interpreter(model_content=tflite_model) out_name = interpreter.get_output_details()[0]['name'] np.save(f'{PREFIX_TFL}/{saved_name}_out_{out_name}.npy', mat_out) def build_transpose(): model_name = "keras_permute" mat_in = np.array([[[1,2,3], [4,5,6]]], dtype=np.float32) model = tf.keras.Sequential() model.add(tf.keras.Input(shape=(2,3))) model.add(tf.keras.layers.Permute((2,1))) model.summary() generator(mat_in, model, model_name) if __name__ == '__main__': build_transpose() ``` ### Pull Request Readiness Checklist - [x] I agree to contribute to the project under Apache 2 License. - [X] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [X] The PR is proposed to the proper branch - [ ] There is a reference to the original bug report and related work - [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [X] The feature is well documented and sample code can be built with the project CMakepull/25598/head
parent
76d9f7aaeb
commit
7713c84465
2 changed files with 54 additions and 0 deletions
Loading…
Reference in new issue