diff --git a/libavfilter/dnn/dnn_interface.c b/libavfilter/dnn/dnn_interface.c index e9c3619abd..8dec3b671a 100644 --- a/libavfilter/dnn/dnn_interface.c +++ b/libavfilter/dnn/dnn_interface.c @@ -120,11 +120,16 @@ void *ff_dnn_child_next(DnnContext *obj, void *prev) { return NULL; } -const AVClass *ff_dnn_child_class_iterate(void **iter) +const AVClass *ff_dnn_child_class_iterate_with_mask(void **iter, uint32_t backend_mask) { - uintptr_t i = (uintptr_t) *iter; + for (uintptr_t i = (uintptr_t)*iter; i < FF_ARRAY_ELEMS(dnn_backend_info_list); i++) { + if (i > 0) { + const DNNModule *module = dnn_backend_info_list[i].module; + + if (!(module->type & backend_mask)) + continue; + } - if (i < FF_ARRAY_ELEMS(dnn_backend_info_list)) { *iter = (void *)(i + 1); return dnn_backend_info_list[i].class; } diff --git a/libavfilter/dnn_filter_common.h b/libavfilter/dnn_filter_common.h index b52b55a90d..42a4719997 100644 --- a/libavfilter/dnn_filter_common.h +++ b/libavfilter/dnn_filter_common.h @@ -26,6 +26,12 @@ #include "dnn_interface.h" +#define DNN_FILTER_CHILD_CLASS_ITERATE(name, backend_mask) \ + static const AVClass *name##_child_class_iterate(void **iter) \ + { \ + return ff_dnn_child_class_iterate_with_mask(iter, (backend_mask)); \ + } + #define AVFILTER_DNN_DEFINE_CLASS_EXT(name, desc, options) \ static const AVClass name##_class = { \ .class_name = desc, \ @@ -34,10 +40,11 @@ .version = LIBAVUTIL_VERSION_INT, \ .category = AV_CLASS_CATEGORY_FILTER, \ .child_next = ff_dnn_filter_child_next, \ - .child_class_iterate = ff_dnn_child_class_iterate, \ + .child_class_iterate = name##_child_class_iterate, \ } -#define AVFILTER_DNN_DEFINE_CLASS(fname) \ +#define AVFILTER_DNN_DEFINE_CLASS(fname, backend_mask) \ + DNN_FILTER_CHILD_CLASS_ITERATE(fname, backend_mask) \ AVFILTER_DNN_DEFINE_CLASS_EXT(fname, #fname, fname##_options) void *ff_dnn_filter_child_next(void *obj, void *prev); diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h index dd603534b2..697b9f3318 100644 --- a/libavfilter/dnn_interface.h +++ b/libavfilter/dnn_interface.h @@ -32,7 +32,11 @@ #define DNN_GENERIC_ERROR FFERRTAG('D','N','N','!') -typedef enum {DNN_TF = 1, DNN_OV, DNN_TH} DNNBackendType; +typedef enum { + DNN_TF = 1, + DNN_OV = 1 << 1, + DNN_TH = 1 << 2 +} DNNBackendType; typedef enum {DNN_FLOAT = 1, DNN_UINT8 = 4} DNNDataType; @@ -190,7 +194,7 @@ const DNNModule *ff_get_dnn_module(DNNBackendType backend_type, void *log_ctx); void ff_dnn_init_child_class(DnnContext *ctx); void *ff_dnn_child_next(DnnContext *obj, void *prev); -const AVClass *ff_dnn_child_class_iterate(void **iter); +const AVClass *ff_dnn_child_class_iterate_with_mask(void **iter, uint32_t backend_mask); static inline int dnn_get_width_idx_by_layout(DNNLayout layout) { diff --git a/libavfilter/vf_derain.c b/libavfilter/vf_derain.c index 7f665b73ab..5cefca6b55 100644 --- a/libavfilter/vf_derain.c +++ b/libavfilter/vf_derain.c @@ -49,7 +49,7 @@ static const AVOption derain_options[] = { { NULL } }; -AVFILTER_DNN_DEFINE_CLASS(derain); +AVFILTER_DNN_DEFINE_CLASS(derain, DNN_TF); static int filter_frame(AVFilterLink *inlink, AVFrame *in) { diff --git a/libavfilter/vf_dnn_classify.c b/libavfilter/vf_dnn_classify.c index 965779a8ab..f6d3678796 100644 --- a/libavfilter/vf_dnn_classify.c +++ b/libavfilter/vf_dnn_classify.c @@ -56,7 +56,7 @@ static const AVOption dnn_classify_options[] = { { NULL } }; -AVFILTER_DNN_DEFINE_CLASS(dnn_classify); +AVFILTER_DNN_DEFINE_CLASS(dnn_classify, DNN_OV); static int dnn_classify_post_proc(AVFrame *frame, DNNData *output, uint32_t bbox_index, AVFilterContext *filter_ctx) { diff --git a/libavfilter/vf_dnn_detect.c b/libavfilter/vf_dnn_detect.c index 1830bae181..2a277d4169 100644 --- a/libavfilter/vf_dnn_detect.c +++ b/libavfilter/vf_dnn_detect.c @@ -84,7 +84,7 @@ static const AVOption dnn_detect_options[] = { { NULL } }; -AVFILTER_DNN_DEFINE_CLASS(dnn_detect); +AVFILTER_DNN_DEFINE_CLASS(dnn_detect, DNN_TF | DNN_OV); static inline float sigmoid(float x) { return 1.f / (1.f + exp(-x)); diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c index 9a1dd2a356..7c0f84ec80 100644 --- a/libavfilter/vf_dnn_processing.c +++ b/libavfilter/vf_dnn_processing.c @@ -57,7 +57,7 @@ static const AVOption dnn_processing_options[] = { { NULL } }; -AVFILTER_DNN_DEFINE_CLASS(dnn_processing); +AVFILTER_DNN_DEFINE_CLASS(dnn_processing, DNN_TF | DNN_OV | DNN_TH); static av_cold int init(AVFilterContext *context) { diff --git a/libavfilter/vf_sr.c b/libavfilter/vf_sr.c index f14c0c0cd3..3bfca7f042 100644 --- a/libavfilter/vf_sr.c +++ b/libavfilter/vf_sr.c @@ -53,7 +53,7 @@ static const AVOption sr_options[] = { { NULL } }; -AVFILTER_DNN_DEFINE_CLASS(sr); +AVFILTER_DNN_DEFINE_CLASS(sr, DNN_TF); static av_cold int init(AVFilterContext *context) {