Merge pull request #25463 from WanliZhong:ocvface2YuNet

Change opencv_face_detector related tests and samples from caffe to onnx #25463

Part of https://github.com/opencv/opencv/issues/25314

This PR aims to change the tests related to opencv_face_detector from caffe framework to onnx. Tests in `test_int8_layer.cpp` and `test_caffe_importer.cpp` will be removed in https://github.com/opencv/opencv/pull/25323

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [ ] There is a reference to the original bug report and related work
- [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [ ] The feature is well documented and sample code can be built with the project CMake
pull/25409/head
Wanli 7 months ago committed by GitHub
parent 4422bc9a7f
commit b637e3a66e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 18
      modules/dnn/misc/face_detector_accuracy.py
  2. 49
      modules/dnn/misc/python/test/test_dnn.py
  3. 1
      modules/dnn/misc/quantize_face_detector.py
  4. 5
      modules/dnn/perf/perf_net.cpp
  5. 8
      modules/dnn/test/test_backends.cpp
  6. 3
      modules/dnn/test/test_misc.cpp
  7. 14
      samples/dnn/models.yml

@ -15,8 +15,8 @@ from pycocotools.cocoeval import COCOeval
parser = argparse.ArgumentParser(
description='Evaluate OpenCV face detection algorithms '
'using COCO evaluation tool, http://cocodataset.org/#detections-eval')
parser.add_argument('--proto', help='Path to .prototxt of Caffe model or .pbtxt of TensorFlow graph')
parser.add_argument('--model', help='Path to .caffemodel trained in Caffe or .pb from TensorFlow')
parser.add_argument('--proto', help='Path to .pbtxt of TensorFlow graph')
parser.add_argument('--model', help='Path to .onnx of ONNX model or .pb from TensorFlow')
parser.add_argument('--cascade', help='Optional path to trained Haar cascade as '
'an additional model for evaluation')
parser.add_argument('--ann', help='Path to text file with ground truth annotations')
@ -139,7 +139,7 @@ with open('annotations.json', 'wt') as f:
### Obtain detections ##########################################################
detections = []
if args.proto and args.model:
if args.proto and args.model and args.model.endswith('.pb'):
net = cv.dnn.readNet(args.proto, args.model)
def detect(img, imageId):
@ -162,6 +162,18 @@ if args.proto and args.model:
addDetection(detections, imageId, x, y, w, h, score=confidence)
elif args.model and args.model.endswith('.onnx'):
net = cv.FaceDetectorYN.create(args.model, "", (320, 320), 0.3, 0.45, 5000)
def detect(img, imageId):
net.setInputSize((img.shape[1], img.shape[0]))
faces = net.detect(img)
if faces[1] is not None:
for idx, face in enumerate(faces[1]):
left, top, width, height = face[0], face[1], face[2], face[3]
addDetection(detections, imageId, left, top, width, height, score=face[-1])
elif args.cascade:
cascade = cv.CascadeClassifier(args.cascade)

@ -286,41 +286,42 @@ class dnn_test(NewOpenCVTests):
def test_face_detection(self):
proto = self.find_dnn_file('dnn/opencv_face_detector.prototxt')
model = self.find_dnn_file('dnn/opencv_face_detector.caffemodel', required=False)
if proto is None or model is None:
raise unittest.SkipTest("Missing DNN test files (dnn/opencv_face_detector.{prototxt/caffemodel}). Verify OPENCV_DNN_TEST_DATA_PATH configuration parameter.")
model = self.find_dnn_file('dnn/onnx/models/yunet-202303.onnx', required=False)
img = self.get_sample('gpu/lbpcascade/er.png')
blob = cv.dnn.blobFromImage(img, mean=(104, 177, 123), swapRB=False, crop=False)
ref = [[0, 1, 0.99520785, 0.80997437, 0.16379407, 0.87996572, 0.26685631],
[0, 1, 0.9934696, 0.2831718, 0.50738752, 0.345781, 0.5985168],
[0, 1, 0.99096733, 0.13629119, 0.24892329, 0.19756334, 0.3310290],
[0, 1, 0.98977017, 0.23901358, 0.09084064, 0.29902688, 0.1769477],
[0, 1, 0.97203469, 0.67965847, 0.06876482, 0.73999709, 0.1513494],
[0, 1, 0.95097077, 0.51901293, 0.45863652, 0.5777427, 0.5347801]]
ref = [[1, 339.62445, 35.32416, 30.754604, 40.202126, 0.9302596],
[1, 140.63962, 255.55545, 32.832615, 41.767395, 0.916015],
[1, 68.39314, 126.74046, 30.29324, 39.14823, 0.90639645],
[1, 119.57139, 48.482178, 30.600697, 40.485996, 0.906021],
[1, 259.0921, 229.30713, 31.088186, 39.74022, 0.90490955],
[1, 405.69778, 87.28158, 33.393406, 42.96226, 0.8996978]]
print('\n')
for backend, target in self.dnnBackendsAndTargets:
printParams(backend, target)
net = cv.dnn.readNet(proto, model)
net.setPreferableBackend(backend)
net.setPreferableTarget(target)
net.setInput(blob)
out = net.forward().reshape(-1, 7)
scoresDiff = 4e-3 if target in [cv.dnn.DNN_TARGET_OPENCL_FP16, cv.dnn.DNN_TARGET_MYRIAD] else 1e-5
iouDiff = 2e-2 if target in [cv.dnn.DNN_TARGET_OPENCL_FP16, cv.dnn.DNN_TARGET_MYRIAD] else 1e-4
net = cv.FaceDetectorYN.create(
model=model,
config="",
input_size=img.shape[:2],
score_threshold=0.3,
nms_threshold=0.45,
top_k=5000,
backend_id=backend,
target_id=target
)
out = net.detect(img)
out = out[1]
out = out.reshape(-1, 15)
ref = np.array(ref, np.float32)
refClassIds, testClassIds = ref[:, 1], out[:, 1]
refScores, testScores = ref[:, 2], out[:, 2]
refBoxes, testBoxes = ref[:, 3:], out[:, 3:]
refClassIds, testClassIds = ref[:, 0], np.ones(out.shape[0], np.float32)
refScores, testScores = ref[:, -1], out[:, -1]
refBoxes, testBoxes = ref[:, 1:5], out[:, 0:4]
normAssertDetections(self, refClassIds, refScores, refBoxes, testClassIds,
testScores, testBoxes, 0.5, scoresDiff, iouDiff)
testScores, testBoxes, 0.5)
def test_async(self):
timeout = 10*1000*10**6 # in nanoseconds (10 sec)

@ -2,6 +2,7 @@ from __future__ import print_function
import sys
import argparse
import cv2 as cv
assert cv.__version__ < "5.0", "Caffe importer is deprecated and removed from OpenCV 5.0"
import tensorflow as tf
import numpy as np
import struct

@ -158,11 +158,6 @@ PERF_TEST_P_(DNNTestNetwork, OpenPose_pose_mpi_faster_4_stages)
processNet("dnn/openpose_pose_mpi.caffemodel", "dnn/openpose_pose_mpi_faster_4_stages.prototxt", cv::Size(368, 368));
}
PERF_TEST_P_(DNNTestNetwork, opencv_face_detector)
{
processNet("dnn/opencv_face_detector.caffemodel", "dnn/opencv_face_detector.prototxt", cv::Size(300, 300));
}
PERF_TEST_P_(DNNTestNetwork, Inception_v2_SSD_TensorFlow)
{
applyTestTag(CV_TEST_TAG_DEBUG_VERYLONG);

@ -372,12 +372,12 @@ TEST_P(DNNTestNetwork, OpenPose_pose_mpi_faster_4_stages)
expectNoFallbacksFromCUDA(net);
}
TEST_P(DNNTestNetwork, opencv_face_detector)
TEST_P(DNNTestNetwork, YuNet)
{
Mat img = imread(findDataFile("gpu/lbpcascade/er.png"));
Mat inp = blobFromImage(img, 1.0, Size(), Scalar(104.0, 177.0, 123.0), false, false);
processNet("dnn/opencv_face_detector.caffemodel", "dnn/opencv_face_detector.prototxt",
inp, "detection_out");
resize(img, img, Size(320, 320));
Mat inp = blobFromImage(img);
processNet("dnn/onnx/models/yunet-202303.onnx", "", inp);
expectNoFallbacksFromIE(net);
}

@ -247,9 +247,6 @@ TEST(readNet, Regression)
Net net = readNet(findDataFile("dnn/squeezenet_v1.1.prototxt"),
findDataFile("dnn/squeezenet_v1.1.caffemodel", false));
EXPECT_FALSE(net.empty());
net = readNet(findDataFile("dnn/opencv_face_detector.caffemodel", false),
findDataFile("dnn/opencv_face_detector.prototxt"));
EXPECT_FALSE(net.empty());
net = readNet(findDataFile("dnn/tiny-yolo-voc.cfg"),
findDataFile("dnn/tiny-yolo-voc.weights", false));
EXPECT_FALSE(net.empty());

@ -4,20 +4,6 @@
# Object detection models.
################################################################################
# OpenCV's face detection network
opencv_fd:
load_info:
url: "https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel"
sha1: "15aa726b4d46d9f023526d85537db81cbc8dd566"
model: "opencv_face_detector.caffemodel"
config: "opencv_face_detector.prototxt"
mean: [104, 177, 123]
scale: 1.0
width: 300
height: 300
rgb: false
sample: "object_detection"
# YOLOv8 object detection family from ultralytics (https://github.com/ultralytics/ultralytics)
# Might be used for all YOLOv8n YOLOv8s YOLOv8m YOLOv8l and YOLOv8x
yolov8x:

Loading…
Cancel
Save