@ -30,9 +30,11 @@
# include "libavutil/time.h"
# include "libavutil/time.h"
# include "libavutil/avstring.h"
# include "libavutil/avstring.h"
# include "libavutil/detection_bbox.h"
# include "libavutil/detection_bbox.h"
# include "libavutil/fifo.h"
typedef enum {
typedef enum {
DDMT_SSD
DDMT_SSD ,
DDMT_YOLOV1V2 ,
} DNNDetectionModelType ;
} DNNDetectionModelType ;
typedef struct DnnDetectContext {
typedef struct DnnDetectContext {
@ -43,6 +45,15 @@ typedef struct DnnDetectContext {
char * * labels ;
char * * labels ;
int label_count ;
int label_count ;
DNNDetectionModelType model_type ;
DNNDetectionModelType model_type ;
int cell_w ;
int cell_h ;
int nb_classes ;
AVFifo * bboxes_fifo ;
int scale_width ;
int scale_height ;
char * anchors_str ;
float * anchors ;
int nb_anchor ;
} DnnDetectContext ;
} DnnDetectContext ;
# define OFFSET(x) offsetof(DnnDetectContext, dnnctx.x)
# define OFFSET(x) offsetof(DnnDetectContext, dnnctx.x)
@ -61,11 +72,218 @@ static const AVOption dnn_detect_options[] = {
{ " labels " , " path to labels file " , OFFSET2 ( labels_filename ) , AV_OPT_TYPE_STRING , { . str = NULL } , 0 , 0 , FLAGS } ,
{ " labels " , " path to labels file " , OFFSET2 ( labels_filename ) , AV_OPT_TYPE_STRING , { . str = NULL } , 0 , 0 , FLAGS } ,
{ " model_type " , " DNN detection model type " , OFFSET2 ( model_type ) , AV_OPT_TYPE_INT , { . i64 = DDMT_SSD } , INT_MIN , INT_MAX , FLAGS , " model_type " } ,
{ " model_type " , " DNN detection model type " , OFFSET2 ( model_type ) , AV_OPT_TYPE_INT , { . i64 = DDMT_SSD } , INT_MIN , INT_MAX , FLAGS , " model_type " } ,
{ " ssd " , " output shape [1, 1, N, 7] " , 0 , AV_OPT_TYPE_CONST , { . i64 = DDMT_SSD } , 0 , 0 , FLAGS , " model_type " } ,
{ " ssd " , " output shape [1, 1, N, 7] " , 0 , AV_OPT_TYPE_CONST , { . i64 = DDMT_SSD } , 0 , 0 , FLAGS , " model_type " } ,
{ " yolo " , " output shape [1, N*Cx*Cy*DetectionBox] " , 0 , AV_OPT_TYPE_CONST , { . i64 = DDMT_YOLOV1V2 } , 0 , 0 , FLAGS , " model_type " } ,
{ " cell_w " , " cell width " , OFFSET2 ( cell_w ) , AV_OPT_TYPE_INT , { . i64 = 0 } , 0 , INTMAX_MAX , FLAGS } ,
{ " cell_h " , " cell height " , OFFSET2 ( cell_h ) , AV_OPT_TYPE_INT , { . i64 = 0 } , 0 , INTMAX_MAX , FLAGS } ,
{ " nb_classes " , " The number of class " , OFFSET2 ( nb_classes ) , AV_OPT_TYPE_INT , { . i64 = 0 } , 0 , INTMAX_MAX , FLAGS } ,
{ " anchors " , " anchors, splited by '&' " , OFFSET2 ( anchors_str ) , AV_OPT_TYPE_STRING , { . str = NULL } , 0 , 0 , FLAGS } ,
{ NULL }
{ NULL }
} ;
} ;
AVFILTER_DEFINE_CLASS ( dnn_detect ) ;
AVFILTER_DEFINE_CLASS ( dnn_detect ) ;
static int dnn_detect_get_label_id ( int nb_classes , int cell_size , float * label_data )
{
float max_prob = 0 ;
int label_id = 0 ;
for ( int i = 0 ; i < nb_classes ; i + + ) {
if ( label_data [ i * cell_size ] > max_prob ) {
max_prob = label_data [ i * cell_size ] ;
label_id = i ;
}
}
return label_id ;
}
static int dnn_detect_parse_anchors ( char * anchors_str , float * * anchors )
{
char * saveptr = NULL , * token ;
float * anchors_buf ;
int nb_anchor = 0 , i = 0 ;
while ( anchors_str [ i ] ! = ' \0 ' ) {
if ( anchors_str [ i ] = = ' & ' )
nb_anchor + + ;
i + + ;
}
nb_anchor + + ;
anchors_buf = av_mallocz ( nb_anchor * sizeof ( * anchors ) ) ;
if ( ! anchors_buf ) {
return 0 ;
}
for ( int i = 0 ; i < nb_anchor ; i + + ) {
token = av_strtok ( anchors_str , " & " , & saveptr ) ;
anchors_buf [ i ] = strtof ( token , NULL ) ;
anchors_str = NULL ;
}
* anchors = anchors_buf ;
return nb_anchor ;
}
/* Calculate Intersection Over Union */
static float dnn_detect_IOU ( AVDetectionBBox * bbox1 , AVDetectionBBox * bbox2 )
{
float overlapping_width = FFMIN ( bbox1 - > x + bbox1 - > w , bbox2 - > x + bbox2 - > w ) - FFMAX ( bbox1 - > x , bbox2 - > x ) ;
float overlapping_height = FFMIN ( bbox1 - > y + bbox1 - > h , bbox2 - > y + bbox2 - > h ) - FFMAX ( bbox1 - > y , bbox2 - > y ) ;
float intersection_area =
( overlapping_width < 0 | | overlapping_height < 0 ) ? 0 : overlapping_height * overlapping_width ;
float union_area = bbox1 - > w * bbox1 - > h + bbox2 - > w * bbox2 - > h - intersection_area ;
return intersection_area / union_area ;
}
static int dnn_detect_parse_yolo_output ( AVFrame * frame , DNNData * output , int output_index ,
AVFilterContext * filter_ctx )
{
DnnDetectContext * ctx = filter_ctx - > priv ;
float conf_threshold = ctx - > confidence ;
int detection_boxes , box_size , cell_w , cell_h , scale_w , scale_h ;
int nb_classes = ctx - > nb_classes ;
float * output_data = output [ output_index ] . data ;
float * anchors = ctx - > anchors ;
AVDetectionBBox * bbox ;
if ( ctx - > model_type = = DDMT_YOLOV1V2 ) {
cell_w = ctx - > cell_w ;
cell_h = ctx - > cell_h ;
scale_w = cell_w ;
scale_h = cell_h ;
}
box_size = nb_classes + 5 ;
if ( ! cell_h | | ! cell_w ) {
av_log ( filter_ctx , AV_LOG_ERROR , " cell_w and cell_h are detected \n " ) ;
return AVERROR ( EINVAL ) ;
}
if ( ! nb_classes ) {
av_log ( filter_ctx , AV_LOG_ERROR , " nb_classes is not set \n " ) ;
return AVERROR ( EINVAL ) ;
}
if ( ! anchors ) {
av_log ( filter_ctx , AV_LOG_ERROR , " anchors is not set \n " ) ;
return AVERROR ( EINVAL ) ;
}
if ( output [ output_index ] . channels * output [ output_index ] . width *
output [ output_index ] . height % ( box_size * cell_w * cell_h ) ) {
av_log ( filter_ctx , AV_LOG_ERROR , " wrong cell_w, cell_h or nb_classes \n " ) ;
return AVERROR ( EINVAL ) ;
}
detection_boxes = output [ output_index ] . channels *
output [ output_index ] . height *
output [ output_index ] . width / box_size / cell_w / cell_h ;
/**
* find all candidate bbox
* yolo output can be reshaped to [ B , N * D , Cx , Cy ]
* Detection box ' D ' has format [ ` x ` , ` y ` , ` h ` , ` w ` , ` box_score ` , ` class_no_1 ` , . . . , ]
* */
for ( int box_id = 0 ; box_id < detection_boxes ; box_id + + ) {
for ( int cx = 0 ; cx < cell_w ; cx + + )
for ( int cy = 0 ; cy < cell_h ; cy + + ) {
float x , y , w , h , conf ;
float * detection_boxes_data ;
int label_id ;
detection_boxes_data = output_data + box_id * box_size * cell_w * cell_h ;
conf = detection_boxes_data [ cy * cell_w + cx + 4 * cell_w * cell_h ] ;
if ( conf < conf_threshold ) {
continue ;
}
x = detection_boxes_data [ cy * cell_w + cx ] ;
y = detection_boxes_data [ cy * cell_w + cx + cell_w * cell_h ] ;
w = detection_boxes_data [ cy * cell_w + cx + 2 * cell_w * cell_h ] ;
h = detection_boxes_data [ cy * cell_w + cx + 3 * cell_w * cell_h ] ;
label_id = dnn_detect_get_label_id ( ctx - > nb_classes , cell_w * cell_h ,
detection_boxes_data + cy * cell_w + cx + 5 * cell_w * cell_h ) ;
conf = conf * detection_boxes_data [ cy * cell_w + cx + ( label_id + 5 ) * cell_w * cell_h ] ;
bbox = av_mallocz ( sizeof ( * bbox ) ) ;
if ( ! bbox )
return AVERROR ( ENOMEM ) ;
bbox - > w = exp ( w ) * anchors [ box_id * 2 ] * frame - > width / scale_w ;
bbox - > h = exp ( h ) * anchors [ box_id * 2 + 1 ] * frame - > height / scale_h ;
bbox - > x = ( cx + x ) / cell_w * frame - > width - bbox - > w / 2 ;
bbox - > y = ( cy + y ) / cell_h * frame - > height - bbox - > h / 2 ;
bbox - > detect_confidence = av_make_q ( ( int ) ( conf * 10000 ) , 10000 ) ;
if ( ctx - > labels & & label_id < ctx - > label_count ) {
av_strlcpy ( bbox - > detect_label , ctx - > labels [ label_id ] , sizeof ( bbox - > detect_label ) ) ;
} else {
snprintf ( bbox - > detect_label , sizeof ( bbox - > detect_label ) , " %d " , label_id ) ;
}
if ( av_fifo_write ( ctx - > bboxes_fifo , & bbox , 1 ) < 0 ) {
av_freep ( & bbox ) ;
return AVERROR ( ENOMEM ) ;
}
}
}
return 0 ;
}
static int dnn_detect_fill_side_data ( AVFrame * frame , AVFilterContext * filter_ctx )
{
DnnDetectContext * ctx = filter_ctx - > priv ;
float conf_threshold = ctx - > confidence ;
AVDetectionBBox * bbox ;
int nb_bboxes = 0 ;
AVDetectionBBoxHeader * header ;
if ( av_fifo_can_read ( ctx - > bboxes_fifo ) = = 0 ) {
av_log ( filter_ctx , AV_LOG_VERBOSE , " nothing detected in this frame. \n " ) ;
return 0 ;
}
/* remove overlap bboxes */
for ( int i = 0 ; i < av_fifo_can_read ( ctx - > bboxes_fifo ) ; i + + ) {
av_fifo_peek ( ctx - > bboxes_fifo , & bbox , 1 , i ) ;
for ( int j = 0 ; j < av_fifo_can_read ( ctx - > bboxes_fifo ) ; j + + ) {
AVDetectionBBox * overlap_bbox ;
av_fifo_peek ( ctx - > bboxes_fifo , & overlap_bbox , 1 , j ) ;
if ( ! strcmp ( bbox - > detect_label , overlap_bbox - > detect_label ) & &
av_cmp_q ( bbox - > detect_confidence , overlap_bbox - > detect_confidence ) < 0 & &
dnn_detect_IOU ( bbox , overlap_bbox ) > = conf_threshold ) {
bbox - > classify_count = - 1 ; // bad result
nb_bboxes + + ;
break ;
}
}
}
nb_bboxes = av_fifo_can_read ( ctx - > bboxes_fifo ) - nb_bboxes ;
header = av_detection_bbox_create_side_data ( frame , nb_bboxes ) ;
if ( ! header ) {
av_log ( filter_ctx , AV_LOG_ERROR , " failed to create side data with %d bounding boxes \n " , nb_bboxes ) ;
return - 1 ;
}
av_strlcpy ( header - > source , ctx - > dnnctx . model_filename , sizeof ( header - > source ) ) ;
while ( av_fifo_can_read ( ctx - > bboxes_fifo ) ) {
AVDetectionBBox * candidate_bbox ;
av_fifo_read ( ctx - > bboxes_fifo , & candidate_bbox , 1 ) ;
if ( nb_bboxes > 0 & & candidate_bbox - > classify_count ! = - 1 ) {
bbox = av_get_detection_bbox ( header , header - > nb_bboxes - nb_bboxes ) ;
memcpy ( bbox , candidate_bbox , sizeof ( * bbox ) ) ;
nb_bboxes - - ;
}
av_freep ( & candidate_bbox ) ;
}
return 0 ;
}
static int dnn_detect_post_proc_yolo ( AVFrame * frame , DNNData * output , AVFilterContext * filter_ctx )
{
int ret = 0 ;
ret = dnn_detect_parse_yolo_output ( frame , output , 0 , filter_ctx ) ;
if ( ret < 0 )
return ret ;
ret = dnn_detect_fill_side_data ( frame , filter_ctx ) ;
if ( ret < 0 )
return ret ;
return 0 ;
}
static int dnn_detect_post_proc_ssd ( AVFrame * frame , DNNData * output , AVFilterContext * filter_ctx )
static int dnn_detect_post_proc_ssd ( AVFrame * frame , DNNData * output , AVFilterContext * filter_ctx )
{
{
DnnDetectContext * ctx = filter_ctx - > priv ;
DnnDetectContext * ctx = filter_ctx - > priv ;
@ -158,6 +376,10 @@ static int dnn_detect_post_proc_ov(AVFrame *frame, DNNData *output, AVFilterCont
if ( ret < 0 )
if ( ret < 0 )
return ret ;
return ret ;
break ;
break ;
case DDMT_YOLOV1V2 :
ret = dnn_detect_post_proc_yolo ( frame , output , filter_ctx ) ;
if ( ret < 0 )
return ret ;
}
}
return 0 ;
return 0 ;
@ -356,11 +578,22 @@ static av_cold int dnn_detect_init(AVFilterContext *context)
ret = check_output_nb ( ctx , dnn_ctx - > backend_type , dnn_ctx - > nb_outputs ) ;
ret = check_output_nb ( ctx , dnn_ctx - > backend_type , dnn_ctx - > nb_outputs ) ;
if ( ret < 0 )
if ( ret < 0 )
return ret ;
return ret ;
ctx - > bboxes_fifo = av_fifo_alloc2 ( 1 , sizeof ( AVDetectionBBox * ) , AV_FIFO_FLAG_AUTO_GROW ) ;
if ( ! ctx - > bboxes_fifo )
return AVERROR ( ENOMEM ) ;
ff_dnn_set_detect_post_proc ( & ctx - > dnnctx , dnn_detect_post_proc ) ;
ff_dnn_set_detect_post_proc ( & ctx - > dnnctx , dnn_detect_post_proc ) ;
if ( ctx - > labels_filename ) {
if ( ctx - > labels_filename ) {
return read_detect_label_file ( context ) ;
return read_detect_label_file ( context ) ;
}
}
if ( ctx - > anchors_str ) {
ret = dnn_detect_parse_anchors ( ctx - > anchors_str , & ctx - > anchors ) ;
if ( ! ctx - > anchors ) {
av_log ( context , AV_LOG_ERROR , " failed to parse anchors_str \n " ) ;
return AVERROR ( EINVAL ) ;
}
ctx - > nb_anchor = ret ;
}
return 0 ;
return 0 ;
}
}
@ -460,7 +693,14 @@ static int dnn_detect_activate(AVFilterContext *filter_ctx)
static av_cold void dnn_detect_uninit ( AVFilterContext * context )
static av_cold void dnn_detect_uninit ( AVFilterContext * context )
{
{
DnnDetectContext * ctx = context - > priv ;
DnnDetectContext * ctx = context - > priv ;
AVDetectionBBox * bbox ;
ff_dnn_uninit ( & ctx - > dnnctx ) ;
ff_dnn_uninit ( & ctx - > dnnctx ) ;
while ( av_fifo_can_read ( ctx - > bboxes_fifo ) ) {
av_fifo_read ( ctx - > bboxes_fifo , & bbox , 1 ) ;
av_freep ( & bbox ) ;
}
av_fifo_freep2 ( & ctx - > bboxes_fifo ) ;
av_freep ( & ctx - > anchors ) ;
free_detect_labels ( ctx ) ;
free_detect_labels ( ctx ) ;
}
}