Merge pull request #24298 from WanliZhong:extend_perf_net_test

Extend performance test models #24298

**Merged With https://github.com/opencv/opencv_extra/pull/1095**

This PR aims to extend the performance tests. 

- **YOLOv5** for object detection
- **YOLOv8** for object detection
- **EfficientNet** for classification

Models from OpenCV Zoo:

- **YOLOX** for object detection
- **YuNet** for face detection
- **SFace** for face recognization
- **MPPalm** for palm detection
- **MPHand** for hand landmark
- **MPPose** for pose estimation
- **ViTTrack** for object tracking
- **PPOCRv3** for text detection
- **CRNN** for text recognization
- **PPHumanSeg** for human segmentation

If other models should be added, **please leave some comments**. Thanks!



Build opencv with script:
```shell
-DBUILD_opencv_python2=OFF
-DBUILD_opencv_python3=OFF
-DBUILD_opencv_gapi=OFF
-DINSTALL_PYTHON_EXAMPLES=OFF
-DINSTALL_C_EXAMPLES=OFF
-DBUILD_DOCS=OFF
-DBUILD_EXAMPLES=OFF
-DBUILD_ZLIB=OFF
-DWITH_FFMPEG=OFF
```



Performance Test on **Apple M2 CPU**
```shell
MacOS 14.0
8 threads
```

**1 thread:**
| Name of Test | 4.5.5-1th | 4.6.0-1th | 4.7.0-1th | 4.8.0-1th | 4.8.1-1th |
|--------------|:---------:|:---------:|:---------:|:---------:|:---------:|
| CRNN         |  76.244   |  76.611   |  62.534   |  57.678   |  57.238   |
| EfficientNet |    ---    |    ---    |  109.224  |  130.753  |  109.076  |
| MPHand       |    ---    |    ---    |  19.289   |  22.727   |  27.593   |
| MPPalm       |  47.150   |  47.061   |  41.064   |  65.598   |  40.109   |
| MPPose       |    ---    |    ---    |  26.592   |  32.022   |  26.956   |
| PPHumanSeg   |  41.672   |  41.790   |  27.819   |  27.212   |  30.461   |
| PPOCRv3      |    ---    |    ---    |  140.371  |  187.922  |  170.026  |
| SFace        |  43.830   |  43.834   |  27.575   |  30.653   |  26.387   |
| ViTTrack     |    ---    |    ---    |    ---    |  14.617   |  15.028   |
| YOLOX        | 1060.507  | 1061.361  |  495.816  |  533.309  |  549.713  |
| YOLOv5       |    ---    |    ---    |    ---    |  191.350  |  193.261  |
| YOLOv8       |    ---    |    ---    |  198.893  |  218.733  |  223.142  |
| YuNet        |  27.084   |  27.095   |  26.238   |  30.512   |  34.439   |
| MobileNet_SSD_Caffe         |  44.742   |  44.565   |  33.005   |  29.421   |  29.286   |
| MobileNet_SSD_v1_TensorFlow |  49.352   |  49.274   |  35.163   |  32.134   |  31.904   |
| MobileNet_SSD_v2_TensorFlow |  83.537   |  83.379   |  56.403   |  42.947   |  42.148   |
| ResNet_50                   |  148.872  |  148.817  |  77.331   |  67.682   |  67.760   |


**n threads:**
| Name of Test | 4.5.5-nth | 4.6.0-nth | 4.7.0-nth | 4.8.0-nth | 4.8.1-nth |
|--------------|:---------:|:---------:|:---------:|:---------:|:---------:|
| CRNN         |  44.262   |  44.408   |  41.540   |  40.731   |  41.151   |
| EfficientNet |    ---    |    ---    |  28.683   |  42.676   |  38.204   |
| MPHand       |    ---    |    ---    |   6.738   |  13.126   |   8.155   |
| MPPalm       |  16.613   |  16.588   |  12.477   |  31.370   |  17.048   |
| MPPose       |    ---    |    ---    |  12.985   |  19.700   |  16.537   |
| PPHumanSeg   |  14.993   |  15.133   |  13.438   |  15.269   |  15.252   |
| PPOCRv3      |    ---    |    ---    |  63.752   |  85.469   |  76.190   |
| SFace        |  10.685   |  10.822   |   8.127   |   8.318   |   7.934   |
| ViTTrack     |    ---    |    ---    |    ---    |  10.079   |   9.579   |
| YOLOX        |  417.358  |  422.977  |  230.036  |  234.662  |  228.555  |
| YOLOv5       |    ---    |    ---    |    ---    |  74.249   |  75.480   |
| YOLOv8       |    ---    |    ---    |  63.762   |  88.770   |  70.927   |
| YuNet        |   8.589   |   8.731   |  11.269   |  16.466   |  14.513   |
| MobileNet_SSD_Caffe         |  12.575   |  12.636   |  11.529   |  12.114   |  12.236   |
| MobileNet_SSD_v1_TensorFlow |  13.922   |  14.160   |  13.078   |  12.124   |  13.298   |
| MobileNet_SSD_v2_TensorFlow |  25.096   |  24.836   |  22.823   |  20.238   |  20.319   |
| ResNet_50                   |  41.561   |  41.296   |  29.092   |  30.412   |  29.339   |


Performance Test on [Intel Core i7-12700K](https://www.intel.com/content/www/us/en/products/sku/134594/intel-core-i712700k-processor-25m-cache-up-to-5-00-ghz/specifications.html)
```shell
Ubuntu 22.04.2 LTS
8 Performance-cores (3.60 GHz, turbo up to 4.90 GHz)
4 Efficient-cores (2.70 GHz, turbo up to 3.80 GHz)
20 threads
```


**1 thread:**
| Name of Test | 4.5.5-1th | 4.6.0-1th | 4.7.0-1th | 4.8.0-1th | 4.8.1-1th |
|--------------|:---------:|:---------:|:---------:|:---------:|:---------:|
| CRNN         |  16.752   |  16.851   |  16.840   |  16.625   |  16.663   |
| EfficientNet |    ---    |    ---    |  61.107   |  76.037   |  53.890   |
| MPHand       |    ---    |    ---    |   8.906   |   9.969   |   8.403   |
| MPPalm       |  24.243   |  24.638   |  18.104   |  35.140   |  18.387   |
| MPPose       |    ---    |    ---    |  12.322   |  16.515   |  12.355   |
| PPHumanSeg   |  15.249   |  15.303   |  10.203   |  10.298   |  10.353   |
| PPOCRv3      |    ---    |    ---    |  87.788   |  144.253  |  90.648   |
| SFace        |  15.583   |  15.884   |  13.957   |  13.298   |  13.284   |
| ViTTrack     |    ---    |    ---    |    ---    |  11.760   |  11.710   |
| YOLOX        |  324.927  |  325.173  |  235.986  |  253.653  |  254.472  |
| YOLOv5       |    ---    |    ---    |    ---    |  102.163  |  102.621  |
| YOLOv8       |    ---    |    ---    |  87.013   |  103.182  |  103.146  |
| YuNet        |  12.806   |  12.645   |  10.515   |  12.647   |  12.711   |
| MobileNet_SSD_Caffe         |  23.556   |  23.768   |  24.304   |  22.569   |  22.602   |
| MobileNet_SSD_v1_TensorFlow |  26.136   |  26.276   |  26.854   |  24.828   |  24.961   |
| MobileNet_SSD_v2_TensorFlow |  43.521   |  43.614   |  46.892   |  44.044   |  44.682   |
| ResNet_50                   |  73.588   |  73.501   |  75.191   |  66.893   |  65.144   |


**n thread:**
| Name of Test | 4.5.5-nth | 4.6.0-nth | 4.7.0-nth | 4.8.0-nth | 4.8.1-nth | 
|--------------|:---------:|:---------:|:---------:|:---------:|:---------:|
| CRNN         |   8.665   |   8.827   |  10.643   |   7.703   |   7.743   | 
| EfficientNet |    ---    |    ---    |  16.591   |  12.715   |   9.022   |   
| MPHand       |    ---    |    ---    |   2.678   |   2.785   |   1.680   |           
| MPPalm       |   5.309   |   5.319   |   3.822   |  10.568   |   4.467   |       
| MPPose       |    ---    |    ---    |   3.644   |   6.088   |   4.608   |        
| PPHumanSeg   |   4.756   |   4.865   |   5.084   |   5.179   |   5.148   |        
| PPOCRv3      |    ---    |    ---    |  32.023   |  50.591   |  32.414   |      
| SFace        |   3.838   |   3.980   |   4.629   |   3.145   |   3.155   |       
| ViTTrack     |    ---    |    ---    |    ---    |  10.335   |  10.357   |   
| YOLOX        |  68.314   |  68.081   |  82.801   |  74.219   |  73.970   |      
| YOLOv5       |    ---    |    ---    |    ---    |  47.150   |  47.523   |    
| YOLOv8       |    ---    |    ---    |  32.195   |  30.359   |  30.267   |    
| YuNet        |   2.604   |   2.644   |   2.622   |   3.278   |   3.349   |    
| MobileNet_SSD_Caffe         |  13.005   |   5.935   |   8.586   |   4.629   |   4.713   |
| MobileNet_SSD_v1_TensorFlow |   7.002   |   7.129   |   9.314   |   5.271   |   5.213   |
| MobileNet_SSD_v2_TensorFlow |  11.939   |  12.111   |  22.688   |  12.038   |  12.086   |
| ResNet_50                   |  18.227   |  18.600   |  26.150   |  15.584   |  15.706   |
pull/24359/head
Wanli 2 years ago committed by GitHub
parent 670c52f75e
commit 62b5470b78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 177
      modules/dnn/perf/perf_net.cpp

@ -29,10 +29,7 @@ public:
}
void processNet(std::string weights, std::string proto, std::string halide_scheduler,
const Mat& input, const std::string& outputLayer = "")
{
randu(input, 0.0f, 1.0f);
const std::vector<std::tuple<Mat, std::string>>& inputs, const std::string& outputLayer = ""){
weights = findDataFile(weights, false);
if (!proto.empty())
proto = findDataFile(proto);
@ -44,7 +41,11 @@ public:
halide_scheduler = findDataFile(std::string("dnn/halide_scheduler_") + (target == DNN_TARGET_OPENCL ? "opencl_" : "") + halide_scheduler, true);
}
net = readNet(proto, weights);
net.setInput(blobFromImage(input, 1.0, Size(), Scalar(), false));
// Set multiple inputs
for(auto &inp: inputs){
net.setInput(std::get<0>(inp), std::get<1>(inp));
}
net.setPreferableBackend(backend);
net.setPreferableTarget(target);
if (backend == DNN_BACKEND_HALIDE)
@ -52,10 +53,14 @@ public:
net.setHalideScheduler(halide_scheduler);
}
MatShape netInputShape = shape(1, 3, input.rows, input.cols);
// Calculate multiple inputs memory consumption
std::vector<MatShape> netMatShapes;
for(auto &inp: inputs){
netMatShapes.push_back(shape(std::get<0>(inp)));
}
size_t weightsMemory = 0, blobsMemory = 0;
net.getMemoryConsumption(netInputShape, weightsMemory, blobsMemory);
int64 flops = net.getFLOPS(netInputShape);
net.getMemoryConsumption(netMatShapes, weightsMemory, blobsMemory);
int64 flops = net.getFLOPS(netMatShapes);
CV_Assert(flops > 0);
net.forward(outputLayer); // warmup
@ -71,31 +76,46 @@ public:
SANITY_CHECK_NOTHING();
}
void processNet(std::string weights, std::string proto, std::string halide_scheduler,
Mat &input, const std::string& outputLayer = "")
{
processNet(weights, proto, halide_scheduler, {std::make_tuple(input, "")}, outputLayer);
}
void processNet(std::string weights, std::string proto, std::string halide_scheduler,
Size inpSize, const std::string& outputLayer = "")
{
Mat input_data(inpSize, CV_32FC3);
randu(input_data, 0.0f, 1.0f);
Mat input = blobFromImage(input_data, 1.0, Size(), Scalar(), false);
processNet(weights, proto, halide_scheduler, input, outputLayer);
}
};
PERF_TEST_P_(DNNTestNetwork, AlexNet)
{
processNet("dnn/bvlc_alexnet.caffemodel", "dnn/bvlc_alexnet.prototxt",
"alexnet.yml", Mat(cv::Size(227, 227), CV_32FC3));
"alexnet.yml", cv::Size(227, 227));
}
PERF_TEST_P_(DNNTestNetwork, GoogLeNet)
{
processNet("dnn/bvlc_googlenet.caffemodel", "dnn/bvlc_googlenet.prototxt",
"", Mat(cv::Size(224, 224), CV_32FC3));
"", cv::Size(224, 224));
}
PERF_TEST_P_(DNNTestNetwork, ResNet_50)
{
processNet("dnn/ResNet-50-model.caffemodel", "dnn/ResNet-50-deploy.prototxt",
"resnet_50.yml", Mat(cv::Size(224, 224), CV_32FC3));
"resnet_50.yml", cv::Size(224, 224));
}
PERF_TEST_P_(DNNTestNetwork, SqueezeNet_v1_1)
{
processNet("dnn/squeezenet_v1.1.caffemodel", "dnn/squeezenet_v1.1.prototxt",
"squeezenet_v1_1.yml", Mat(cv::Size(227, 227), CV_32FC3));
"squeezenet_v1_1.yml", cv::Size(227, 227));
}
PERF_TEST_P_(DNNTestNetwork, Inception_5h)
@ -103,7 +123,7 @@ PERF_TEST_P_(DNNTestNetwork, Inception_5h)
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019) throw SkipTestException("");
processNet("dnn/tensorflow_inception_graph.pb", "",
"inception_5h.yml",
Mat(cv::Size(224, 224), CV_32FC3), "softmax2");
cv::Size(224, 224), "softmax2");
}
PERF_TEST_P_(DNNTestNetwork, ENet)
@ -116,13 +136,13 @@ PERF_TEST_P_(DNNTestNetwork, ENet)
throw SkipTestException("");
#endif
processNet("dnn/Enet-model-best.net", "", "enet.yml",
Mat(cv::Size(512, 256), CV_32FC3));
cv::Size(512, 256));
}
PERF_TEST_P_(DNNTestNetwork, SSD)
{
processNet("dnn/VGG_ILSVRC2016_SSD_300x300_iter_440000.caffemodel", "dnn/ssd_vgg16.prototxt", "disabled",
Mat(cv::Size(300, 300), CV_32FC3));
cv::Size(300, 300));
}
PERF_TEST_P_(DNNTestNetwork, OpenFace)
@ -134,7 +154,7 @@ PERF_TEST_P_(DNNTestNetwork, OpenFace)
throw SkipTestException("");
#endif
processNet("dnn/openface_nn4.small2.v1.t7", "", "",
Mat(cv::Size(96, 96), CV_32FC3));
cv::Size(96, 96));
}
PERF_TEST_P_(DNNTestNetwork, MobileNet_SSD_Caffe)
@ -142,7 +162,7 @@ PERF_TEST_P_(DNNTestNetwork, MobileNet_SSD_Caffe)
if (backend == DNN_BACKEND_HALIDE)
throw SkipTestException("");
processNet("dnn/MobileNetSSD_deploy_19e3ec3.caffemodel", "dnn/MobileNetSSD_deploy_19e3ec3.prototxt", "",
Mat(cv::Size(300, 300), CV_32FC3));
cv::Size(300, 300));
}
PERF_TEST_P_(DNNTestNetwork, MobileNet_SSD_v1_TensorFlow)
@ -150,7 +170,7 @@ PERF_TEST_P_(DNNTestNetwork, MobileNet_SSD_v1_TensorFlow)
if (backend == DNN_BACKEND_HALIDE)
throw SkipTestException("");
processNet("dnn/ssd_mobilenet_v1_coco_2017_11_17.pb", "ssd_mobilenet_v1_coco_2017_11_17.pbtxt", "",
Mat(cv::Size(300, 300), CV_32FC3));
cv::Size(300, 300));
}
PERF_TEST_P_(DNNTestNetwork, MobileNet_SSD_v2_TensorFlow)
@ -158,7 +178,7 @@ PERF_TEST_P_(DNNTestNetwork, MobileNet_SSD_v2_TensorFlow)
if (backend == DNN_BACKEND_HALIDE)
throw SkipTestException("");
processNet("dnn/ssd_mobilenet_v2_coco_2018_03_29.pb", "ssd_mobilenet_v2_coco_2018_03_29.pbtxt", "",
Mat(cv::Size(300, 300), CV_32FC3));
cv::Size(300, 300));
}
PERF_TEST_P_(DNNTestNetwork, DenseNet_121)
@ -166,7 +186,7 @@ PERF_TEST_P_(DNNTestNetwork, DenseNet_121)
if (backend == DNN_BACKEND_HALIDE)
throw SkipTestException("");
processNet("dnn/DenseNet_121.caffemodel", "dnn/DenseNet_121.prototxt", "",
Mat(cv::Size(224, 224), CV_32FC3));
cv::Size(224, 224));
}
PERF_TEST_P_(DNNTestNetwork, OpenPose_pose_mpi_faster_4_stages)
@ -177,7 +197,7 @@ PERF_TEST_P_(DNNTestNetwork, OpenPose_pose_mpi_faster_4_stages)
// The same .caffemodel but modified .prototxt
// See https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/pose/poseParameters.cpp
processNet("dnn/openpose_pose_mpi.caffemodel", "dnn/openpose_pose_mpi_faster_4_stages.prototxt", "",
Mat(cv::Size(368, 368), CV_32FC3));
cv::Size(368, 368));
}
PERF_TEST_P_(DNNTestNetwork, opencv_face_detector)
@ -185,7 +205,7 @@ PERF_TEST_P_(DNNTestNetwork, opencv_face_detector)
if (backend == DNN_BACKEND_HALIDE)
throw SkipTestException("");
processNet("dnn/opencv_face_detector.caffemodel", "dnn/opencv_face_detector.prototxt", "",
Mat(cv::Size(300, 300), CV_32FC3));
cv::Size(300, 300));
}
PERF_TEST_P_(DNNTestNetwork, Inception_v2_SSD_TensorFlow)
@ -193,7 +213,7 @@ PERF_TEST_P_(DNNTestNetwork, Inception_v2_SSD_TensorFlow)
if (backend == DNN_BACKEND_HALIDE)
throw SkipTestException("");
processNet("dnn/ssd_inception_v2_coco_2017_11_17.pb", "ssd_inception_v2_coco_2017_11_17.pbtxt", "",
Mat(cv::Size(300, 300), CV_32FC3));
cv::Size(300, 300));
}
PERF_TEST_P_(DNNTestNetwork, YOLOv3)
@ -213,9 +233,7 @@ PERF_TEST_P_(DNNTestNetwork, YOLOv3)
#endif
Mat sample = imread(findDataFile("dnn/dog416.png"));
cvtColor(sample, sample, COLOR_BGR2RGB);
Mat inp;
sample.convertTo(inp, CV_32FC3, 1.0f / 255, 0);
Mat inp = blobFromImage(sample, 1.0 / 255.0, Size(), Scalar(), true);
processNet("dnn/yolov3.weights", "dnn/yolov3.cfg", "", inp);
}
@ -233,9 +251,7 @@ PERF_TEST_P_(DNNTestNetwork, YOLOv4)
throw SkipTestException("Test is disabled in OpenVINO 2020.4");
#endif
Mat sample = imread(findDataFile("dnn/dog416.png"));
cvtColor(sample, sample, COLOR_BGR2RGB);
Mat inp;
sample.convertTo(inp, CV_32FC3, 1.0f / 255, 0);
Mat inp = blobFromImage(sample, 1.0 / 255.0, Size(), Scalar(), true);
processNet("dnn/yolov4.weights", "dnn/yolov4.cfg", "", inp);
}
@ -248,24 +264,43 @@ PERF_TEST_P_(DNNTestNetwork, YOLOv4_tiny)
throw SkipTestException("");
#endif
Mat sample = imread(findDataFile("dnn/dog416.png"));
cvtColor(sample, sample, COLOR_BGR2RGB);
Mat inp;
sample.convertTo(inp, CV_32FC3, 1.0f / 255, 0);
Mat inp = blobFromImage(sample, 1.0 / 255.0, Size(), Scalar(), true);
processNet("dnn/yolov4-tiny-2020-12.weights", "dnn/yolov4-tiny-2020-12.cfg", "", inp);
}
PERF_TEST_P_(DNNTestNetwork, YOLOv5) {
applyTestTag(CV_TEST_TAG_MEMORY_512MB);
Mat sample = imread(findDataFile("dnn/dog416.png"));
Mat inp = blobFromImage(sample, 1.0 / 255.0, Size(640, 640), Scalar(), true);
processNet("", "dnn/yolov5n.onnx", "", inp);
}
PERF_TEST_P_(DNNTestNetwork, YOLOv8) {
applyTestTag(CV_TEST_TAG_MEMORY_512MB);
Mat sample = imread(findDataFile("dnn/dog416.png"));
Mat inp = blobFromImage(sample, 1.0 / 255.0, Size(640, 640), Scalar(), true);
processNet("", "dnn/yolov8n.onnx", "", inp);
}
PERF_TEST_P_(DNNTestNetwork, YOLOX) {
applyTestTag(CV_TEST_TAG_MEMORY_512MB);
Mat sample = imread(findDataFile("dnn/dog416.png"));
Mat inp = blobFromImage(sample, 1.0 / 255.0, Size(640, 640), Scalar(), true);
processNet("", "dnn/yolox_s.onnx", "", inp);
}
PERF_TEST_P_(DNNTestNetwork, EAST_text_detection)
{
if (backend == DNN_BACKEND_HALIDE)
throw SkipTestException("");
processNet("dnn/frozen_east_text_detection.pb", "", "", Mat(cv::Size(320, 320), CV_32FC3));
processNet("dnn/frozen_east_text_detection.pb", "", "", cv::Size(320, 320));
}
PERF_TEST_P_(DNNTestNetwork, FastNeuralStyle_eccv16)
{
if (backend == DNN_BACKEND_HALIDE)
throw SkipTestException("");
processNet("dnn/fast_neural_style_eccv16_starry_night.t7", "", "", Mat(cv::Size(320, 240), CV_32FC3));
processNet("dnn/fast_neural_style_eccv16_starry_night.t7", "", "", cv::Size(320, 240));
}
PERF_TEST_P_(DNNTestNetwork, Inception_v2_Faster_RCNN)
@ -288,7 +323,7 @@ PERF_TEST_P_(DNNTestNetwork, Inception_v2_Faster_RCNN)
throw SkipTestException("");
processNet("dnn/faster_rcnn_inception_v2_coco_2018_01_28.pb",
"dnn/faster_rcnn_inception_v2_coco_2018_01_28.pbtxt", "",
Mat(cv::Size(800, 600), CV_32FC3));
cv::Size(800, 600));
}
PERF_TEST_P_(DNNTestNetwork, EfficientDet)
@ -296,12 +331,76 @@ PERF_TEST_P_(DNNTestNetwork, EfficientDet)
if (backend == DNN_BACKEND_HALIDE || target != DNN_TARGET_CPU)
throw SkipTestException("");
Mat sample = imread(findDataFile("dnn/dog416.png"));
resize(sample, sample, Size(512, 512));
Mat inp;
sample.convertTo(inp, CV_32FC3, 1.0/255);
Mat inp = blobFromImage(sample, 1.0 / 255.0, Size(512, 512), Scalar(), true);
processNet("dnn/efficientdet-d0.pb", "dnn/efficientdet-d0.pbtxt", "", inp);
}
PERF_TEST_P_(DNNTestNetwork, EfficientNet)
{
Mat sample = imread(findDataFile("dnn/dog416.png"));
Mat inp = blobFromImage(sample, 1.0 / 255.0, Size(224, 224), Scalar(), true);
transposeND(inp, {0, 2, 3, 1}, inp);
processNet("", "dnn/efficientnet-lite4.onnx", "", inp);
}
PERF_TEST_P_(DNNTestNetwork, YuNet) {
processNet("", "dnn/onnx/models/yunet-202303.onnx", "", cv::Size(640, 640));
}
PERF_TEST_P_(DNNTestNetwork, SFace) {
processNet("", "dnn/face_recognition_sface_2021dec.onnx", "", cv::Size(112, 112));
}
PERF_TEST_P_(DNNTestNetwork, MPPalm) {
Mat inp(cv::Size(192, 192), CV_32FC3);
randu(inp, 0.0f, 1.0f);
inp = blobFromImage(inp, 1.0, Size(), Scalar(), false);
transposeND(inp, {0, 2, 3, 1}, inp);
processNet("", "dnn/palm_detection_mediapipe_2023feb.onnx", "", inp);
}
PERF_TEST_P_(DNNTestNetwork, MPHand) {
Mat inp(cv::Size(224, 224), CV_32FC3);
randu(inp, 0.0f, 1.0f);
inp = blobFromImage(inp, 1.0, Size(), Scalar(), false);
transposeND(inp, {0, 2, 3, 1}, inp);
processNet("", "dnn/handpose_estimation_mediapipe_2023feb.onnx", "", inp);
}
PERF_TEST_P_(DNNTestNetwork, MPPose) {
Mat inp(cv::Size(256, 256), CV_32FC3);
randu(inp, 0.0f, 1.0f);
inp = blobFromImage(inp, 1.0, Size(), Scalar(), false);
transposeND(inp, {0, 2, 3, 1}, inp);
processNet("", "dnn/pose_estimation_mediapipe_2023mar.onnx", "", inp);
}
PERF_TEST_P_(DNNTestNetwork, PPOCRv3) {
applyTestTag(CV_TEST_TAG_MEMORY_512MB);
processNet("", "dnn/onnx/models/PP_OCRv3_DB_text_det.onnx", "", cv::Size(736, 736));
}
PERF_TEST_P_(DNNTestNetwork, PPHumanSeg) {
processNet("", "dnn/human_segmentation_pphumanseg_2023mar.onnx", "", cv::Size(192, 192));
}
PERF_TEST_P_(DNNTestNetwork, CRNN) {
Mat inp(cv::Size(100, 32), CV_32FC1);
randu(inp, 0.0f, 1.0f);
inp = blobFromImage(inp, 1.0, Size(), Scalar(), false);
processNet("", "dnn/text_recognition_CRNN_EN_2021sep.onnx", "", inp);
}
PERF_TEST_P_(DNNTestNetwork, ViTTrack) {
Mat inp1(cv::Size(128, 128), CV_32FC3);
Mat inp2(cv::Size(256, 256), CV_32FC3);
randu(inp1, 0.0f, 1.0f);
randu(inp2, 0.0f, 1.0f);
inp1 = blobFromImage(inp1, 1.0, Size(), Scalar(), false);
inp2 = blobFromImage(inp2, 1.0, Size(), Scalar(), false);
processNet("", "dnn/onnx/models/vitTracker.onnx", "", {std::make_tuple(inp1, "template"), std::make_tuple(inp2, "search")});
}
PERF_TEST_P_(DNNTestNetwork, EfficientDet_int8)
{
@ -310,7 +409,7 @@ PERF_TEST_P_(DNNTestNetwork, EfficientDet_int8)
throw SkipTestException("");
}
Mat inp = imread(findDataFile("dnn/dog416.png"));
resize(inp, inp, Size(320, 320));
inp = blobFromImage(inp, 1.0 / 255.0, Size(320, 320), Scalar(), true);
processNet("", "dnn/tflite/coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite", "", inp);
}

Loading…
Cancel
Save