diff --git a/configure b/configure index 5d68695192..39fabb4ad5 100755 --- a/configure +++ b/configure @@ -2628,6 +2628,7 @@ cbs_vp9_select="cbs" dct_select="rdft" dirac_parse_select="golomb" dnn_suggest="libtensorflow libopenvino" +dnn_deps="swscale" error_resilience_select="me_cmp" faandct_deps="faan" faandct_select="fdctdsp" @@ -3532,7 +3533,6 @@ derain_filter_select="dnn" deshake_filter_select="pixelutils" deshake_opencl_filter_deps="opencl" dilation_opencl_filter_deps="opencl" -dnn_processing_filter_deps="swscale" dnn_processing_filter_select="dnn" drawtext_filter_deps="libfreetype" drawtext_filter_suggest="libfontconfig libfribidi" diff --git a/libavfilter/dnn/Makefile b/libavfilter/dnn/Makefile index e0957073ee..ee08cc5243 100644 --- a/libavfilter/dnn/Makefile +++ b/libavfilter/dnn/Makefile @@ -1,4 +1,5 @@ OBJS-$(CONFIG_DNN) += dnn/dnn_interface.o +OBJS-$(CONFIG_DNN) += dnn/dnn_io_proc.o OBJS-$(CONFIG_DNN) += dnn/dnn_backend_native.o OBJS-$(CONFIG_DNN) += dnn/dnn_backend_native_layers.o OBJS-$(CONFIG_DNN) += dnn/dnn_backend_native_layer_avgpool.o diff --git a/libavfilter/dnn/dnn_backend_native.c b/libavfilter/dnn/dnn_backend_native.c index 830ec19c80..14e878b6b8 100644 --- a/libavfilter/dnn/dnn_backend_native.c +++ b/libavfilter/dnn/dnn_backend_native.c @@ -27,6 +27,7 @@ #include "libavutil/avassert.h" #include "dnn_backend_native_layer_conv2d.h" #include "dnn_backend_native_layers.h" +#include "dnn_io_proc.h" #define OFFSET(x) offsetof(NativeContext, x) #define FLAGS AV_OPT_FLAG_FILTERING_PARAM @@ -69,11 +70,12 @@ static DNNReturnType get_input_native(void *model, DNNData *input, const char *i return DNN_ERROR; } -static DNNReturnType set_input_native(void *model, DNNData *input, const char *input_name) +static DNNReturnType set_input_native(void *model, AVFrame *frame, const char *input_name) { NativeModel *native_model = (NativeModel *)model; NativeContext *ctx = &native_model->ctx; DnnOperand *oprd = NULL; + DNNData input; if (native_model->layers_num <= 0 || native_model->operands_num <= 0) { av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n"); @@ -97,10 +99,8 @@ static DNNReturnType set_input_native(void *model, DNNData *input, const char *i return DNN_ERROR; } - oprd->dims[0] = 1; - oprd->dims[1] = input->height; - oprd->dims[2] = input->width; - oprd->dims[3] = input->channels; + oprd->dims[1] = frame->height; + oprd->dims[2] = frame->width; av_freep(&oprd->data); oprd->length = calculate_operand_data_length(oprd); @@ -114,7 +114,16 @@ static DNNReturnType set_input_native(void *model, DNNData *input, const char *i return DNN_ERROR; } - input->data = oprd->data; + input.height = oprd->dims[1]; + input.width = oprd->dims[2]; + input.channels = oprd->dims[3]; + input.data = oprd->data; + input.dt = oprd->data_type; + if (native_model->model->pre_proc != NULL) { + native_model->model->pre_proc(frame, &input, native_model->model->userdata); + } else { + proc_from_frame_to_dnn(frame, &input, ctx); + } return DNN_SUCCESS; } @@ -185,6 +194,7 @@ DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *optio if (av_opt_set_from_string(&native_model->ctx, model->options, NULL, "=", "&") < 0) goto fail; model->model = (void *)native_model; + native_model->model = model; #if !HAVE_PTHREAD_CANCEL if (native_model->ctx.options.conv2d_threads > 1){ @@ -275,11 +285,19 @@ fail: return NULL; } -DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output) +DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame) { NativeModel *native_model = (NativeModel *)model->model; NativeContext *ctx = &native_model->ctx; int32_t layer; + DNNData output; + + if (nb_output != 1) { + // currently, the filter does not need multiple outputs, + // so we just pending the support until we really need it. + av_log(ctx, AV_LOG_ERROR, "do not support multiple outputs\n"); + return DNN_ERROR; + } if (native_model->layers_num <= 0 || native_model->operands_num <= 0) { av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n"); @@ -317,11 +335,22 @@ DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, DNNData *output return DNN_ERROR; } - outputs[i].data = oprd->data; - outputs[i].height = oprd->dims[1]; - outputs[i].width = oprd->dims[2]; - outputs[i].channels = oprd->dims[3]; - outputs[i].dt = oprd->data_type; + output.data = oprd->data; + output.height = oprd->dims[1]; + output.width = oprd->dims[2]; + output.channels = oprd->dims[3]; + output.dt = oprd->data_type; + + if (out_frame->width != output.width || out_frame->height != output.height) { + out_frame->width = output.width; + out_frame->height = output.height; + } else { + if (native_model->model->post_proc != NULL) { + native_model->model->post_proc(out_frame, &output, native_model->model->userdata); + } else { + proc_from_dnn_to_frame(out_frame, &output, ctx); + } + } } return DNN_SUCCESS; diff --git a/libavfilter/dnn/dnn_backend_native.h b/libavfilter/dnn/dnn_backend_native.h index 33634118a8..553438bd22 100644 --- a/libavfilter/dnn/dnn_backend_native.h +++ b/libavfilter/dnn/dnn_backend_native.h @@ -119,6 +119,7 @@ typedef struct NativeContext { // Represents simple feed-forward convolutional network. typedef struct NativeModel{ NativeContext ctx; + DNNModel *model; Layer *layers; int32_t layers_num; DnnOperand *operands; @@ -127,7 +128,7 @@ typedef struct NativeModel{ DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *options, void *userdata); -DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output); +DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame); void ff_dnn_free_model_native(DNNModel **model); diff --git a/libavfilter/dnn/dnn_backend_openvino.c b/libavfilter/dnn/dnn_backend_openvino.c index 01e1a1d4c8..b1bad3f659 100644 --- a/libavfilter/dnn/dnn_backend_openvino.c +++ b/libavfilter/dnn/dnn_backend_openvino.c @@ -24,6 +24,7 @@ */ #include "dnn_backend_openvino.h" +#include "dnn_io_proc.h" #include "libavformat/avio.h" #include "libavutil/avassert.h" #include "libavutil/opt.h" @@ -42,6 +43,7 @@ typedef struct OVContext { typedef struct OVModel{ OVContext ctx; + DNNModel *model; ie_core_t *core; ie_network_t *network; ie_executable_network_t *exe_network; @@ -131,7 +133,7 @@ static DNNReturnType get_input_ov(void *model, DNNData *input, const char *input return DNN_ERROR; } -static DNNReturnType set_input_ov(void *model, DNNData *input, const char *input_name) +static DNNReturnType set_input_ov(void *model, AVFrame *frame, const char *input_name) { OVModel *ov_model = (OVModel *)model; OVContext *ctx = &ov_model->ctx; @@ -139,10 +141,7 @@ static DNNReturnType set_input_ov(void *model, DNNData *input, const char *input dimensions_t dims; precision_e precision; ie_blob_buffer_t blob_buffer; - - status = ie_exec_network_create_infer_request(ov_model->exe_network, &ov_model->infer_request); - if (status != OK) - goto err; + DNNData input; status = ie_infer_request_get_blob(ov_model->infer_request, input_name, &ov_model->input_blob); if (status != OK) @@ -153,23 +152,26 @@ static DNNReturnType set_input_ov(void *model, DNNData *input, const char *input if (status != OK) goto err; - av_assert0(input->channels == dims.dims[1]); - av_assert0(input->height == dims.dims[2]); - av_assert0(input->width == dims.dims[3]); - av_assert0(input->dt == precision_to_datatype(precision)); - status = ie_blob_get_buffer(ov_model->input_blob, &blob_buffer); if (status != OK) goto err; - input->data = blob_buffer.buffer; + + input.height = dims.dims[2]; + input.width = dims.dims[3]; + input.channels = dims.dims[1]; + input.data = blob_buffer.buffer; + input.dt = precision_to_datatype(precision); + if (ov_model->model->pre_proc != NULL) { + ov_model->model->pre_proc(frame, &input, ov_model->model->userdata); + } else { + proc_from_frame_to_dnn(frame, &input, ctx); + } return DNN_SUCCESS; err: if (ov_model->input_blob) ie_blob_free(&ov_model->input_blob); - if (ov_model->infer_request) - ie_infer_request_free(&ov_model->infer_request); av_log(ctx, AV_LOG_ERROR, "Failed to create inference instance or get input data/dims/precision/memory\n"); return DNN_ERROR; } @@ -184,7 +186,7 @@ DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, ie_config_t config = {NULL, NULL, NULL}; ie_available_devices_t a_dev; - model = av_malloc(sizeof(DNNModel)); + model = av_mallocz(sizeof(DNNModel)); if (!model){ return NULL; } @@ -192,6 +194,7 @@ DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, ov_model = av_mallocz(sizeof(OVModel)); if (!ov_model) goto err; + ov_model->model = model; ov_model->ctx.class = &dnn_openvino_class; ctx = &ov_model->ctx; @@ -226,6 +229,10 @@ DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, goto err; } + status = ie_exec_network_create_infer_request(ov_model->exe_network, &ov_model->infer_request); + if (status != OK) + goto err; + model->model = (void *)ov_model; model->set_input = &set_input_ov; model->get_input = &get_input_ov; @@ -238,6 +245,8 @@ err: if (model) av_freep(&model); if (ov_model) { + if (ov_model->infer_request) + ie_infer_request_free(&ov_model->infer_request); if (ov_model->exe_network) ie_exec_network_free(&ov_model->exe_network); if (ov_model->network) @@ -249,7 +258,7 @@ err: return NULL; } -DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output) +DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame) { char *model_output_name = NULL; char *all_output_names = NULL; @@ -258,8 +267,18 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNData *outputs, c ie_blob_buffer_t blob_buffer; OVModel *ov_model = (OVModel *)model->model; OVContext *ctx = &ov_model->ctx; - IEStatusCode status = ie_infer_request_infer(ov_model->infer_request); + IEStatusCode status; size_t model_output_count = 0; + DNNData output; + + if (nb_output != 1) { + // currently, the filter does not need multiple outputs, + // so we just pending the support until we really need it. + av_log(ctx, AV_LOG_ERROR, "do not support multiple outputs\n"); + return DNN_ERROR; + } + + status = ie_infer_request_infer(ov_model->infer_request); if (status != OK) { av_log(ctx, AV_LOG_ERROR, "Failed to start synchronous model inference\n"); return DNN_ERROR; @@ -296,11 +315,21 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNData *outputs, c return DNN_ERROR; } - outputs[i].channels = dims.dims[1]; - outputs[i].height = dims.dims[2]; - outputs[i].width = dims.dims[3]; - outputs[i].dt = precision_to_datatype(precision); - outputs[i].data = blob_buffer.buffer; + output.channels = dims.dims[1]; + output.height = dims.dims[2]; + output.width = dims.dims[3]; + output.dt = precision_to_datatype(precision); + output.data = blob_buffer.buffer; + if (out_frame->width != output.width || out_frame->height != output.height) { + out_frame->width = output.width; + out_frame->height = output.height; + } else { + if (ov_model->model->post_proc != NULL) { + ov_model->model->post_proc(out_frame, &output, ov_model->model->userdata); + } else { + proc_from_dnn_to_frame(out_frame, &output, ctx); + } + } } return DNN_SUCCESS; diff --git a/libavfilter/dnn/dnn_backend_openvino.h b/libavfilter/dnn/dnn_backend_openvino.h index f69bc5ca0c..efb349cb49 100644 --- a/libavfilter/dnn/dnn_backend_openvino.h +++ b/libavfilter/dnn/dnn_backend_openvino.h @@ -31,7 +31,7 @@ DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, void *userdata); -DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output); +DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame); void ff_dnn_free_model_ov(DNNModel **model); diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c index bac7d8c420..c2d8c06931 100644 --- a/libavfilter/dnn/dnn_backend_tf.c +++ b/libavfilter/dnn/dnn_backend_tf.c @@ -31,6 +31,7 @@ #include "libavutil/avassert.h" #include "dnn_backend_native_layer_pad.h" #include "dnn_backend_native_layer_maximum.h" +#include "dnn_io_proc.h" #include @@ -40,13 +41,12 @@ typedef struct TFContext { typedef struct TFModel{ TFContext ctx; + DNNModel *model; TF_Graph *graph; TF_Session *session; TF_Status *status; TF_Output input; TF_Tensor *input_tensor; - TF_Tensor **output_tensors; - uint32_t nb_output; } TFModel; static const AVClass dnn_tensorflow_class = { @@ -152,13 +152,19 @@ static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input return DNN_SUCCESS; } -static DNNReturnType set_input_tf(void *model, DNNData *input, const char *input_name) +static DNNReturnType set_input_tf(void *model, AVFrame *frame, const char *input_name) { TFModel *tf_model = (TFModel *)model; TFContext *ctx = &tf_model->ctx; + DNNData input; TF_SessionOptions *sess_opts; const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init"); + if (get_input_tf(model, &input, input_name) != DNN_SUCCESS) + return DNN_ERROR; + input.height = frame->height; + input.width = frame->width; + // Input operation tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, input_name); if (!tf_model->input.oper){ @@ -169,12 +175,18 @@ static DNNReturnType set_input_tf(void *model, DNNData *input, const char *input if (tf_model->input_tensor){ TF_DeleteTensor(tf_model->input_tensor); } - tf_model->input_tensor = allocate_input_tensor(input); + tf_model->input_tensor = allocate_input_tensor(&input); if (!tf_model->input_tensor){ av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for input tensor\n"); return DNN_ERROR; } - input->data = (float *)TF_TensorData(tf_model->input_tensor); + input.data = (float *)TF_TensorData(tf_model->input_tensor); + + if (tf_model->model->pre_proc != NULL) { + tf_model->model->pre_proc(frame, &input, tf_model->model->userdata); + } else { + proc_from_frame_to_dnn(frame, &input, ctx); + } // session if (tf_model->session){ @@ -591,7 +603,7 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, DNNModel *model = NULL; TFModel *tf_model = NULL; - model = av_malloc(sizeof(DNNModel)); + model = av_mallocz(sizeof(DNNModel)); if (!model){ return NULL; } @@ -602,6 +614,7 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, return NULL; } tf_model->ctx.class = &dnn_tensorflow_class; + tf_model->model = model; if (load_tf_model(tf_model, model_filename) != DNN_SUCCESS){ if (load_native_model(tf_model, model_filename) != DNN_SUCCESS){ @@ -621,11 +634,20 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, return model; } -DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output) +DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame) { TF_Output *tf_outputs; TFModel *tf_model = (TFModel *)model->model; TFContext *ctx = &tf_model->ctx; + DNNData output; + TF_Tensor **output_tensors; + + if (nb_output != 1) { + // currently, the filter does not need multiple outputs, + // so we just pending the support until we really need it. + av_log(ctx, AV_LOG_ERROR, "do not support multiple outputs\n"); + return DNN_ERROR; + } tf_outputs = av_malloc_array(nb_output, sizeof(*tf_outputs)); if (tf_outputs == NULL) { @@ -633,18 +655,8 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, c return DNN_ERROR; } - if (tf_model->output_tensors) { - for (uint32_t i = 0; i < tf_model->nb_output; ++i) { - if (tf_model->output_tensors[i]) { - TF_DeleteTensor(tf_model->output_tensors[i]); - tf_model->output_tensors[i] = NULL; - } - } - } - av_freep(&tf_model->output_tensors); - tf_model->nb_output = nb_output; - tf_model->output_tensors = av_mallocz_array(nb_output, sizeof(*tf_model->output_tensors)); - if (!tf_model->output_tensors) { + output_tensors = av_mallocz_array(nb_output, sizeof(*output_tensors)); + if (!output_tensors) { av_freep(&tf_outputs); av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for output tensor\n"); \ return DNN_ERROR; @@ -654,6 +666,7 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, c tf_outputs[i].oper = TF_GraphOperationByName(tf_model->graph, output_names[i]); if (!tf_outputs[i].oper) { av_freep(&tf_outputs); + av_freep(&output_tensors); av_log(ctx, AV_LOG_ERROR, "Could not find output \"%s\" in model\n", output_names[i]); \ return DNN_ERROR; } @@ -662,22 +675,40 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, c TF_SessionRun(tf_model->session, NULL, &tf_model->input, &tf_model->input_tensor, 1, - tf_outputs, tf_model->output_tensors, nb_output, + tf_outputs, output_tensors, nb_output, NULL, 0, NULL, tf_model->status); if (TF_GetCode(tf_model->status) != TF_OK) { av_freep(&tf_outputs); + av_freep(&output_tensors); av_log(ctx, AV_LOG_ERROR, "Failed to run session when executing model\n"); return DNN_ERROR; } for (uint32_t i = 0; i < nb_output; ++i) { - outputs[i].height = TF_Dim(tf_model->output_tensors[i], 1); - outputs[i].width = TF_Dim(tf_model->output_tensors[i], 2); - outputs[i].channels = TF_Dim(tf_model->output_tensors[i], 3); - outputs[i].data = TF_TensorData(tf_model->output_tensors[i]); - outputs[i].dt = TF_TensorType(tf_model->output_tensors[i]); + output.height = TF_Dim(output_tensors[i], 1); + output.width = TF_Dim(output_tensors[i], 2); + output.channels = TF_Dim(output_tensors[i], 3); + output.data = TF_TensorData(output_tensors[i]); + output.dt = TF_TensorType(output_tensors[i]); + + if (out_frame->width != output.width || out_frame->height != output.height) { + out_frame->width = output.width; + out_frame->height = output.height; + } else { + if (tf_model->model->post_proc != NULL) { + tf_model->model->post_proc(out_frame, &output, tf_model->model->userdata); + } else { + proc_from_dnn_to_frame(out_frame, &output, ctx); + } + } } + for (uint32_t i = 0; i < nb_output; ++i) { + if (output_tensors[i]) { + TF_DeleteTensor(output_tensors[i]); + } + } + av_freep(&output_tensors); av_freep(&tf_outputs); return DNN_SUCCESS; } @@ -701,15 +732,6 @@ void ff_dnn_free_model_tf(DNNModel **model) if (tf_model->input_tensor){ TF_DeleteTensor(tf_model->input_tensor); } - if (tf_model->output_tensors) { - for (uint32_t i = 0; i < tf_model->nb_output; ++i) { - if (tf_model->output_tensors[i]) { - TF_DeleteTensor(tf_model->output_tensors[i]); - tf_model->output_tensors[i] = NULL; - } - } - } - av_freep(&tf_model->output_tensors); av_freep(&tf_model); av_freep(model); } diff --git a/libavfilter/dnn/dnn_backend_tf.h b/libavfilter/dnn/dnn_backend_tf.h index 1cf5cc9e76..f379e83d8d 100644 --- a/libavfilter/dnn/dnn_backend_tf.h +++ b/libavfilter/dnn/dnn_backend_tf.h @@ -31,7 +31,7 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, void *userdata); -DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output); +DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame); void ff_dnn_free_model_tf(DNNModel **model); diff --git a/libavfilter/dnn/dnn_io_proc.c b/libavfilter/dnn/dnn_io_proc.c new file mode 100644 index 0000000000..8ce1959b42 --- /dev/null +++ b/libavfilter/dnn/dnn_io_proc.c @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2020 + * + * This file is part of FFmpeg. + * + * FFmpeg is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * FFmpeg is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with FFmpeg; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#include "dnn_io_proc.h" +#include "libavutil/imgutils.h" +#include "libswscale/swscale.h" + +DNNReturnType proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx) +{ + struct SwsContext *sws_ctx; + int bytewidth = av_image_get_linesize(frame->format, frame->width, 0); + if (output->dt != DNN_FLOAT) { + av_log(log_ctx, AV_LOG_ERROR, "do not support data type rather than DNN_FLOAT\n"); + return DNN_ERROR; + } + + switch (frame->format) { + case AV_PIX_FMT_RGB24: + case AV_PIX_FMT_BGR24: + sws_ctx = sws_getContext(frame->width * 3, + frame->height, + AV_PIX_FMT_GRAYF32, + frame->width * 3, + frame->height, + AV_PIX_FMT_GRAY8, + 0, NULL, NULL, NULL); + sws_scale(sws_ctx, (const uint8_t *[4]){(const uint8_t *)output->data, 0, 0, 0}, + (const int[4]){frame->width * 3 * sizeof(float), 0, 0, 0}, 0, frame->height, + (uint8_t * const*)frame->data, frame->linesize); + sws_freeContext(sws_ctx); + return DNN_SUCCESS; + case AV_PIX_FMT_GRAYF32: + av_image_copy_plane(frame->data[0], frame->linesize[0], + output->data, bytewidth, + bytewidth, frame->height); + return DNN_SUCCESS; + case AV_PIX_FMT_YUV420P: + case AV_PIX_FMT_YUV422P: + case AV_PIX_FMT_YUV444P: + case AV_PIX_FMT_YUV410P: + case AV_PIX_FMT_YUV411P: + case AV_PIX_FMT_GRAY8: + sws_ctx = sws_getContext(frame->width, + frame->height, + AV_PIX_FMT_GRAYF32, + frame->width, + frame->height, + AV_PIX_FMT_GRAY8, + 0, NULL, NULL, NULL); + sws_scale(sws_ctx, (const uint8_t *[4]){(const uint8_t *)output->data, 0, 0, 0}, + (const int[4]){frame->width * sizeof(float), 0, 0, 0}, 0, frame->height, + (uint8_t * const*)frame->data, frame->linesize); + sws_freeContext(sws_ctx); + return DNN_SUCCESS; + default: + av_log(log_ctx, AV_LOG_ERROR, "do not support frame format %d\n", frame->format); + return DNN_ERROR; + } + + return DNN_SUCCESS; +} + +DNNReturnType proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, void *log_ctx) +{ + struct SwsContext *sws_ctx; + int bytewidth = av_image_get_linesize(frame->format, frame->width, 0); + if (input->dt != DNN_FLOAT) { + av_log(log_ctx, AV_LOG_ERROR, "do not support data type rather than DNN_FLOAT\n"); + return DNN_ERROR; + } + + switch (frame->format) { + case AV_PIX_FMT_RGB24: + case AV_PIX_FMT_BGR24: + sws_ctx = sws_getContext(frame->width * 3, + frame->height, + AV_PIX_FMT_GRAY8, + frame->width * 3, + frame->height, + AV_PIX_FMT_GRAYF32, + 0, NULL, NULL, NULL); + sws_scale(sws_ctx, (const uint8_t **)frame->data, + frame->linesize, 0, frame->height, + (uint8_t * const*)(&input->data), + (const int [4]){frame->width * 3 * sizeof(float), 0, 0, 0}); + sws_freeContext(sws_ctx); + break; + case AV_PIX_FMT_GRAYF32: + av_image_copy_plane(input->data, bytewidth, + frame->data[0], frame->linesize[0], + bytewidth, frame->height); + break; + case AV_PIX_FMT_YUV420P: + case AV_PIX_FMT_YUV422P: + case AV_PIX_FMT_YUV444P: + case AV_PIX_FMT_YUV410P: + case AV_PIX_FMT_YUV411P: + case AV_PIX_FMT_GRAY8: + sws_ctx = sws_getContext(frame->width, + frame->height, + AV_PIX_FMT_GRAY8, + frame->width, + frame->height, + AV_PIX_FMT_GRAYF32, + 0, NULL, NULL, NULL); + sws_scale(sws_ctx, (const uint8_t **)frame->data, + frame->linesize, 0, frame->height, + (uint8_t * const*)(&input->data), + (const int [4]){frame->width * sizeof(float), 0, 0, 0}); + sws_freeContext(sws_ctx); + break; + default: + av_log(log_ctx, AV_LOG_ERROR, "do not support frame format %d\n", frame->format); + return DNN_ERROR; + } + + return DNN_SUCCESS; +} diff --git a/libavfilter/dnn/dnn_io_proc.h b/libavfilter/dnn/dnn_io_proc.h new file mode 100644 index 0000000000..4c7dc7c1a2 --- /dev/null +++ b/libavfilter/dnn/dnn_io_proc.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2020 + * + * This file is part of FFmpeg. + * + * FFmpeg is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * FFmpeg is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with FFmpeg; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +/** + * @file + * DNN input&output process between AVFrame and DNNData. + */ + + +#ifndef AVFILTER_DNN_DNN_IO_PROC_H +#define AVFILTER_DNN_DNN_IO_PROC_H + +#include "../dnn_interface.h" +#include "libavutil/frame.h" + +DNNReturnType proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, void *log_ctx); +DNNReturnType proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx); + +#endif diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h index 702c8306e0..6debc50607 100644 --- a/libavfilter/dnn_interface.h +++ b/libavfilter/dnn_interface.h @@ -27,6 +27,7 @@ #define AVFILTER_DNN_INTERFACE_H #include +#include "libavutil/frame.h" typedef enum {DNN_SUCCESS, DNN_ERROR} DNNReturnType; @@ -50,17 +51,23 @@ typedef struct DNNModel{ // Gets model input information // Just reuse struct DNNData here, actually the DNNData.data field is not needed. DNNReturnType (*get_input)(void *model, DNNData *input, const char *input_name); - // Sets model input and output. - // Should be called at least once before model execution. - DNNReturnType (*set_input)(void *model, DNNData *input, const char *input_name); + // Sets model input. + // Should be called every time before model execution. + DNNReturnType (*set_input)(void *model, AVFrame *frame, const char *input_name); + // set the pre process to transfer data from AVFrame to DNNData + // the default implementation within DNN is used if it is not provided by the filter + int (*pre_proc)(AVFrame *frame_in, DNNData *model_input, void *user_data); + // set the post process to transfer data from DNNData to AVFrame + // the default implementation within DNN is used if it is not provided by the filter + int (*post_proc)(AVFrame *frame_out, DNNData *model_output, void *user_data); } DNNModel; // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends. typedef struct DNNModule{ // Loads model and parameters from given file. Returns NULL if it is not possible. DNNModel *(*load_model)(const char *model_filename, const char *options, void *userdata); - // Executes model with specified input and output. Returns DNN_ERROR otherwise. - DNNReturnType (*execute_model)(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output); + // Executes model with specified output. Returns DNN_ERROR otherwise. + DNNReturnType (*execute_model)(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame); // Frees memory allocated for model. void (*free_model)(DNNModel **model); } DNNModule; diff --git a/libavfilter/vf_derain.c b/libavfilter/vf_derain.c index c251d55ee7..a59cd6e941 100644 --- a/libavfilter/vf_derain.c +++ b/libavfilter/vf_derain.c @@ -39,11 +39,8 @@ typedef struct DRContext { DNNBackendType backend_type; DNNModule *dnn_module; DNNModel *model; - DNNData input; - DNNData output; } DRContext; -#define CLIP(x, min, max) (x < min ? min : (x > max ? max : x)) #define OFFSET(x) offsetof(DRContext, x) #define FLAGS AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM static const AVOption derain_options[] = { @@ -74,25 +71,6 @@ static int query_formats(AVFilterContext *ctx) return ff_set_common_formats(ctx, formats); } -static int config_inputs(AVFilterLink *inlink) -{ - AVFilterContext *ctx = inlink->dst; - DRContext *dr_context = ctx->priv; - DNNReturnType result; - - dr_context->input.width = inlink->w; - dr_context->input.height = inlink->h; - dr_context->input.channels = 3; - - result = (dr_context->model->set_input)(dr_context->model->model, &dr_context->input, "x"); - if (result != DNN_SUCCESS) { - av_log(ctx, AV_LOG_ERROR, "could not set input and output for the model\n"); - return AVERROR(EIO); - } - - return 0; -} - static int filter_frame(AVFilterLink *inlink, AVFrame *in) { AVFilterContext *ctx = inlink->dst; @@ -100,43 +78,30 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in) DRContext *dr_context = ctx->priv; DNNReturnType dnn_result; const char *model_output_name = "y"; + AVFrame *out; - AVFrame *out = ff_get_video_buffer(outlink, outlink->w, outlink->h); + dnn_result = (dr_context->model->set_input)(dr_context->model->model, in, "x"); + if (dnn_result != DNN_SUCCESS) { + av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n"); + av_frame_free(&in); + return AVERROR(EIO); + } + + out = ff_get_video_buffer(outlink, outlink->w, outlink->h); if (!out) { av_log(ctx, AV_LOG_ERROR, "could not allocate memory for output frame\n"); av_frame_free(&in); return AVERROR(ENOMEM); } - av_frame_copy_props(out, in); - for (int i = 0; i < in->height; i++){ - for(int j = 0; j < in->width * 3; j++){ - int k = i * in->linesize[0] + j; - int t = i * in->width * 3 + j; - ((float *)dr_context->input.data)[t] = in->data[0][k] / 255.0; - } - } - - dnn_result = (dr_context->dnn_module->execute_model)(dr_context->model, &dr_context->output, &model_output_name, 1); + dnn_result = (dr_context->dnn_module->execute_model)(dr_context->model, &model_output_name, 1, out); if (dnn_result != DNN_SUCCESS){ av_log(ctx, AV_LOG_ERROR, "failed to execute model\n"); + av_frame_free(&in); return AVERROR(EIO); } - out->height = dr_context->output.height; - out->width = dr_context->output.width; - outlink->h = dr_context->output.height; - outlink->w = dr_context->output.width; - - for (int i = 0; i < out->height; i++){ - for(int j = 0; j < out->width * 3; j++){ - int k = i * out->linesize[0] + j; - int t = i * out->width * 3 + j; - out->data[0][k] = CLIP((int)((((float *)dr_context->output.data)[t]) * 255), 0, 255); - } - } - av_frame_free(&in); return ff_filter_frame(outlink, out); @@ -146,7 +111,6 @@ static av_cold int init(AVFilterContext *ctx) { DRContext *dr_context = ctx->priv; - dr_context->input.dt = DNN_FLOAT; dr_context->dnn_module = ff_get_dnn_module(dr_context->backend_type); if (!dr_context->dnn_module) { av_log(ctx, AV_LOG_ERROR, "could not create DNN module for requested backend\n"); @@ -184,7 +148,6 @@ static const AVFilterPad derain_inputs[] = { { .name = "default", .type = AVMEDIA_TYPE_VIDEO, - .config_props = config_inputs, .filter_frame = filter_frame, }, { NULL } diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c index f120bf9df4..d7462bc828 100644 --- a/libavfilter/vf_dnn_processing.c +++ b/libavfilter/vf_dnn_processing.c @@ -46,12 +46,6 @@ typedef struct DnnProcessingContext { DNNModule *dnn_module; DNNModel *model; - // input & output of the model at execution time - DNNData input; - DNNData output; - - struct SwsContext *sws_gray8_to_grayf32; - struct SwsContext *sws_grayf32_to_gray8; struct SwsContext *sws_uv_scale; int sws_uv_height; } DnnProcessingContext; @@ -103,7 +97,7 @@ static av_cold int init(AVFilterContext *context) return AVERROR(EINVAL); } - ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename, ctx->backend_options, NULL); + ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename, ctx->backend_options, ctx); if (!ctx->model) { av_log(ctx, AV_LOG_ERROR, "could not load DNN model\n"); return AVERROR(EINVAL); @@ -148,6 +142,10 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin model_input->width, inlink->w); return AVERROR(EIO); } + if (model_input->dt != DNN_FLOAT) { + av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32.\n"); + return AVERROR(EIO); + } switch (fmt) { case AV_PIX_FMT_RGB24: @@ -156,20 +154,6 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin LOG_FORMAT_CHANNEL_MISMATCH(); return AVERROR(EIO); } - if (model_input->dt != DNN_FLOAT && model_input->dt != DNN_UINT8) { - av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32 and uint8.\n"); - return AVERROR(EIO); - } - return 0; - case AV_PIX_FMT_GRAY8: - if (model_input->channels != 1) { - LOG_FORMAT_CHANNEL_MISMATCH(); - return AVERROR(EIO); - } - if (model_input->dt != DNN_UINT8) { - av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type uint8.\n"); - return AVERROR(EIO); - } return 0; case AV_PIX_FMT_GRAYF32: case AV_PIX_FMT_YUV420P: @@ -181,10 +165,6 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin LOG_FORMAT_CHANNEL_MISMATCH(); return AVERROR(EIO); } - if (model_input->dt != DNN_FLOAT) { - av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type float32.\n"); - return AVERROR(EIO); - } return 0; default: av_log(ctx, AV_LOG_ERROR, "%s not supported.\n", av_get_pix_fmt_name(fmt)); @@ -213,74 +193,24 @@ static int config_input(AVFilterLink *inlink) return check; } - ctx->input.width = inlink->w; - ctx->input.height = inlink->h; - ctx->input.channels = model_input.channels; - ctx->input.dt = model_input.dt; - - result = (ctx->model->set_input)(ctx->model->model, - &ctx->input, ctx->model_inputname); - if (result != DNN_SUCCESS) { - av_log(ctx, AV_LOG_ERROR, "could not set input and output for the model\n"); - return AVERROR(EIO); - } - return 0; } -static int prepare_sws_context(AVFilterLink *outlink) +static av_always_inline int isPlanarYUV(enum AVPixelFormat pix_fmt) +{ + const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(pix_fmt); + av_assert0(desc); + return !(desc->flags & AV_PIX_FMT_FLAG_RGB) && desc->nb_components == 3; +} + +static int prepare_uv_scale(AVFilterLink *outlink) { AVFilterContext *context = outlink->src; DnnProcessingContext *ctx = context->priv; AVFilterLink *inlink = context->inputs[0]; enum AVPixelFormat fmt = inlink->format; - DNNDataType input_dt = ctx->input.dt; - DNNDataType output_dt = ctx->output.dt; - - switch (fmt) { - case AV_PIX_FMT_RGB24: - case AV_PIX_FMT_BGR24: - if (input_dt == DNN_FLOAT) { - ctx->sws_gray8_to_grayf32 = sws_getContext(inlink->w * 3, - inlink->h, - AV_PIX_FMT_GRAY8, - inlink->w * 3, - inlink->h, - AV_PIX_FMT_GRAYF32, - 0, NULL, NULL, NULL); - } - if (output_dt == DNN_FLOAT) { - ctx->sws_grayf32_to_gray8 = sws_getContext(outlink->w * 3, - outlink->h, - AV_PIX_FMT_GRAYF32, - outlink->w * 3, - outlink->h, - AV_PIX_FMT_GRAY8, - 0, NULL, NULL, NULL); - } - return 0; - case AV_PIX_FMT_YUV420P: - case AV_PIX_FMT_YUV422P: - case AV_PIX_FMT_YUV444P: - case AV_PIX_FMT_YUV410P: - case AV_PIX_FMT_YUV411P: - av_assert0(input_dt == DNN_FLOAT); - av_assert0(output_dt == DNN_FLOAT); - ctx->sws_gray8_to_grayf32 = sws_getContext(inlink->w, - inlink->h, - AV_PIX_FMT_GRAY8, - inlink->w, - inlink->h, - AV_PIX_FMT_GRAYF32, - 0, NULL, NULL, NULL); - ctx->sws_grayf32_to_gray8 = sws_getContext(outlink->w, - outlink->h, - AV_PIX_FMT_GRAYF32, - outlink->w, - outlink->h, - AV_PIX_FMT_GRAY8, - 0, NULL, NULL, NULL); + if (isPlanarYUV(fmt)) { if (inlink->w != outlink->w || inlink->h != outlink->h) { const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(fmt); int sws_src_h = AV_CEIL_RSHIFT(inlink->h, desc->log2_chroma_h); @@ -292,10 +222,6 @@ static int prepare_sws_context(AVFilterLink *outlink) SWS_BICUBIC, NULL, NULL, NULL); ctx->sws_uv_height = sws_src_h; } - return 0; - default: - //do nothing - break; } return 0; @@ -306,120 +232,34 @@ static int config_output(AVFilterLink *outlink) AVFilterContext *context = outlink->src; DnnProcessingContext *ctx = context->priv; DNNReturnType result; + AVFilterLink *inlink = context->inputs[0]; + AVFrame *out = NULL; - // have a try run in case that the dnn model resize the frame - result = (ctx->dnn_module->execute_model)(ctx->model, &ctx->output, (const char **)&ctx->model_outputname, 1); - if (result != DNN_SUCCESS){ - av_log(ctx, AV_LOG_ERROR, "failed to execute model\n"); + AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h); + result = (ctx->model->set_input)(ctx->model->model, fake_in, ctx->model_inputname); + if (result != DNN_SUCCESS) { + av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n"); return AVERROR(EIO); } - outlink->w = ctx->output.width; - outlink->h = ctx->output.height; - - prepare_sws_context(outlink); - - return 0; -} - -static int copy_from_frame_to_dnn(DnnProcessingContext *ctx, const AVFrame *frame) -{ - int bytewidth = av_image_get_linesize(frame->format, frame->width, 0); - DNNData *dnn_input = &ctx->input; - - switch (frame->format) { - case AV_PIX_FMT_RGB24: - case AV_PIX_FMT_BGR24: - if (dnn_input->dt == DNN_FLOAT) { - sws_scale(ctx->sws_gray8_to_grayf32, (const uint8_t **)frame->data, frame->linesize, - 0, frame->height, (uint8_t * const*)(&dnn_input->data), - (const int [4]){frame->width * 3 * sizeof(float), 0, 0, 0}); - } else { - av_assert0(dnn_input->dt == DNN_UINT8); - av_image_copy_plane(dnn_input->data, bytewidth, - frame->data[0], frame->linesize[0], - bytewidth, frame->height); - } - return 0; - case AV_PIX_FMT_GRAY8: - case AV_PIX_FMT_GRAYF32: - av_image_copy_plane(dnn_input->data, bytewidth, - frame->data[0], frame->linesize[0], - bytewidth, frame->height); - return 0; - case AV_PIX_FMT_YUV420P: - case AV_PIX_FMT_YUV422P: - case AV_PIX_FMT_YUV444P: - case AV_PIX_FMT_YUV410P: - case AV_PIX_FMT_YUV411P: - sws_scale(ctx->sws_gray8_to_grayf32, (const uint8_t **)frame->data, frame->linesize, - 0, frame->height, (uint8_t * const*)(&dnn_input->data), - (const int [4]){frame->width * sizeof(float), 0, 0, 0}); - return 0; - default: + // have a try run in case that the dnn model resize the frame + out = ff_get_video_buffer(inlink, inlink->w, inlink->h); + result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out); + if (result != DNN_SUCCESS){ + av_log(ctx, AV_LOG_ERROR, "failed to execute model\n"); return AVERROR(EIO); } - return 0; -} + outlink->w = out->width; + outlink->h = out->height; -static int copy_from_dnn_to_frame(DnnProcessingContext *ctx, AVFrame *frame) -{ - int bytewidth = av_image_get_linesize(frame->format, frame->width, 0); - DNNData *dnn_output = &ctx->output; - - switch (frame->format) { - case AV_PIX_FMT_RGB24: - case AV_PIX_FMT_BGR24: - if (dnn_output->dt == DNN_FLOAT) { - sws_scale(ctx->sws_grayf32_to_gray8, (const uint8_t *[4]){(const uint8_t *)dnn_output->data, 0, 0, 0}, - (const int[4]){frame->width * 3 * sizeof(float), 0, 0, 0}, - 0, frame->height, (uint8_t * const*)frame->data, frame->linesize); - - } else { - av_assert0(dnn_output->dt == DNN_UINT8); - av_image_copy_plane(frame->data[0], frame->linesize[0], - dnn_output->data, bytewidth, - bytewidth, frame->height); - } - return 0; - case AV_PIX_FMT_GRAY8: - // it is possible that data type of dnn output is float32, - // need to add support for such case when needed. - av_assert0(dnn_output->dt == DNN_UINT8); - av_image_copy_plane(frame->data[0], frame->linesize[0], - dnn_output->data, bytewidth, - bytewidth, frame->height); - return 0; - case AV_PIX_FMT_GRAYF32: - av_assert0(dnn_output->dt == DNN_FLOAT); - av_image_copy_plane(frame->data[0], frame->linesize[0], - dnn_output->data, bytewidth, - bytewidth, frame->height); - return 0; - case AV_PIX_FMT_YUV420P: - case AV_PIX_FMT_YUV422P: - case AV_PIX_FMT_YUV444P: - case AV_PIX_FMT_YUV410P: - case AV_PIX_FMT_YUV411P: - sws_scale(ctx->sws_grayf32_to_gray8, (const uint8_t *[4]){(const uint8_t *)dnn_output->data, 0, 0, 0}, - (const int[4]){frame->width * sizeof(float), 0, 0, 0}, - 0, frame->height, (uint8_t * const*)frame->data, frame->linesize); - return 0; - default: - return AVERROR(EIO); - } + av_frame_free(&fake_in); + av_frame_free(&out); + prepare_uv_scale(outlink); return 0; } -static av_always_inline int isPlanarYUV(enum AVPixelFormat pix_fmt) -{ - const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(pix_fmt); - av_assert0(desc); - return !(desc->flags & AV_PIX_FMT_FLAG_RGB) && desc->nb_components == 3; -} - static int copy_uv_planes(DnnProcessingContext *ctx, AVFrame *out, const AVFrame *in) { const AVPixFmtDescriptor *desc; @@ -453,11 +293,9 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in) DNNReturnType dnn_result; AVFrame *out; - copy_from_frame_to_dnn(ctx, in); - - dnn_result = (ctx->dnn_module->execute_model)(ctx->model, &ctx->output, (const char **)&ctx->model_outputname, 1); - if (dnn_result != DNN_SUCCESS){ - av_log(ctx, AV_LOG_ERROR, "failed to execute model\n"); + dnn_result = (ctx->model->set_input)(ctx->model->model, in, ctx->model_inputname); + if (dnn_result != DNN_SUCCESS) { + av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n"); av_frame_free(&in); return AVERROR(EIO); } @@ -467,9 +305,15 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in) av_frame_free(&in); return AVERROR(ENOMEM); } - av_frame_copy_props(out, in); - copy_from_dnn_to_frame(ctx, out); + + dnn_result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out); + if (dnn_result != DNN_SUCCESS){ + av_log(ctx, AV_LOG_ERROR, "failed to execute model\n"); + av_frame_free(&in); + av_frame_free(&out); + return AVERROR(EIO); + } if (isPlanarYUV(in->format)) copy_uv_planes(ctx, out, in); @@ -482,8 +326,6 @@ static av_cold void uninit(AVFilterContext *ctx) { DnnProcessingContext *context = ctx->priv; - sws_freeContext(context->sws_gray8_to_grayf32); - sws_freeContext(context->sws_grayf32_to_gray8); sws_freeContext(context->sws_uv_scale); if (context->dnn_module) diff --git a/libavfilter/vf_sr.c b/libavfilter/vf_sr.c index 445777f0c6..2eda8c3219 100644 --- a/libavfilter/vf_sr.c +++ b/libavfilter/vf_sr.c @@ -41,11 +41,10 @@ typedef struct SRContext { DNNBackendType backend_type; DNNModule *dnn_module; DNNModel *model; - DNNData input; - DNNData output; int scale_factor; - struct SwsContext *sws_contexts[3]; - int sws_slice_h, sws_input_linesize, sws_output_linesize; + struct SwsContext *sws_uv_scale; + int sws_uv_height; + struct SwsContext *sws_pre_scale; } SRContext; #define OFFSET(x) offsetof(SRContext, x) @@ -87,11 +86,6 @@ static av_cold int init(AVFilterContext *context) return AVERROR(EIO); } - sr_context->input.dt = DNN_FLOAT; - sr_context->sws_contexts[0] = NULL; - sr_context->sws_contexts[1] = NULL; - sr_context->sws_contexts[2] = NULL; - return 0; } @@ -111,95 +105,63 @@ static int query_formats(AVFilterContext *context) return ff_set_common_formats(context, formats_list); } -static int config_props(AVFilterLink *inlink) +static int config_output(AVFilterLink *outlink) { - AVFilterContext *context = inlink->dst; - SRContext *sr_context = context->priv; - AVFilterLink *outlink = context->outputs[0]; + AVFilterContext *context = outlink->src; + SRContext *ctx = context->priv; DNNReturnType result; - int sws_src_h, sws_src_w, sws_dst_h, sws_dst_w; + AVFilterLink *inlink = context->inputs[0]; + AVFrame *out = NULL; const char *model_output_name = "y"; - sr_context->input.width = inlink->w * sr_context->scale_factor; - sr_context->input.height = inlink->h * sr_context->scale_factor; - sr_context->input.channels = 1; - - result = (sr_context->model->set_input)(sr_context->model->model, &sr_context->input, "x"); - if (result != DNN_SUCCESS){ - av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n"); + AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h); + result = (ctx->model->set_input)(ctx->model->model, fake_in, "x"); + if (result != DNN_SUCCESS) { + av_log(context, AV_LOG_ERROR, "could not set input for the model\n"); return AVERROR(EIO); } - result = (sr_context->dnn_module->execute_model)(sr_context->model, &sr_context->output, &model_output_name, 1); + // have a try run in case that the dnn model resize the frame + out = ff_get_video_buffer(inlink, inlink->w, inlink->h); + result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&model_output_name, 1, out); if (result != DNN_SUCCESS){ av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n"); return AVERROR(EIO); } - if (sr_context->input.height != sr_context->output.height || sr_context->input.width != sr_context->output.width){ - sr_context->input.width = inlink->w; - sr_context->input.height = inlink->h; - result = (sr_context->model->set_input)(sr_context->model->model, &sr_context->input, "x"); - if (result != DNN_SUCCESS){ - av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n"); - return AVERROR(EIO); - } - result = (sr_context->dnn_module->execute_model)(sr_context->model, &sr_context->output, &model_output_name, 1); - if (result != DNN_SUCCESS){ - av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n"); - return AVERROR(EIO); - } - sr_context->scale_factor = 0; - } - outlink->h = sr_context->output.height; - outlink->w = sr_context->output.width; - sr_context->sws_contexts[1] = sws_getContext(sr_context->input.width, sr_context->input.height, AV_PIX_FMT_GRAY8, - sr_context->input.width, sr_context->input.height, AV_PIX_FMT_GRAYF32, - 0, NULL, NULL, NULL); - sr_context->sws_input_linesize = sr_context->input.width << 2; - sr_context->sws_contexts[2] = sws_getContext(sr_context->output.width, sr_context->output.height, AV_PIX_FMT_GRAYF32, - sr_context->output.width, sr_context->output.height, AV_PIX_FMT_GRAY8, - 0, NULL, NULL, NULL); - sr_context->sws_output_linesize = sr_context->output.width << 2; - if (!sr_context->sws_contexts[1] || !sr_context->sws_contexts[2]){ - av_log(context, AV_LOG_ERROR, "could not create SwsContext for conversions\n"); - return AVERROR(ENOMEM); - } - if (sr_context->scale_factor){ - sr_context->sws_contexts[0] = sws_getContext(inlink->w, inlink->h, inlink->format, - outlink->w, outlink->h, outlink->format, - SWS_BICUBIC, NULL, NULL, NULL); - if (!sr_context->sws_contexts[0]){ - av_log(context, AV_LOG_ERROR, "could not create SwsContext for scaling\n"); - return AVERROR(ENOMEM); - } - sr_context->sws_slice_h = inlink->h; - } else { + if (fake_in->width != out->width || fake_in->height != out->height) { + //espcn + outlink->w = out->width; + outlink->h = out->height; if (inlink->format != AV_PIX_FMT_GRAY8){ const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(inlink->format); - sws_src_h = AV_CEIL_RSHIFT(sr_context->input.height, desc->log2_chroma_h); - sws_src_w = AV_CEIL_RSHIFT(sr_context->input.width, desc->log2_chroma_w); - sws_dst_h = AV_CEIL_RSHIFT(sr_context->output.height, desc->log2_chroma_h); - sws_dst_w = AV_CEIL_RSHIFT(sr_context->output.width, desc->log2_chroma_w); - - sr_context->sws_contexts[0] = sws_getContext(sws_src_w, sws_src_h, AV_PIX_FMT_GRAY8, - sws_dst_w, sws_dst_h, AV_PIX_FMT_GRAY8, - SWS_BICUBIC, NULL, NULL, NULL); - if (!sr_context->sws_contexts[0]){ - av_log(context, AV_LOG_ERROR, "could not create SwsContext for scaling\n"); - return AVERROR(ENOMEM); - } - sr_context->sws_slice_h = sws_src_h; + int sws_src_h = AV_CEIL_RSHIFT(inlink->h, desc->log2_chroma_h); + int sws_src_w = AV_CEIL_RSHIFT(inlink->w, desc->log2_chroma_w); + int sws_dst_h = AV_CEIL_RSHIFT(outlink->h, desc->log2_chroma_h); + int sws_dst_w = AV_CEIL_RSHIFT(outlink->w, desc->log2_chroma_w); + ctx->sws_uv_scale = sws_getContext(sws_src_w, sws_src_h, AV_PIX_FMT_GRAY8, + sws_dst_w, sws_dst_h, AV_PIX_FMT_GRAY8, + SWS_BICUBIC, NULL, NULL, NULL); + ctx->sws_uv_height = sws_src_h; } + } else { + //srcnn + outlink->w = out->width * ctx->scale_factor; + outlink->h = out->height * ctx->scale_factor; + ctx->sws_pre_scale = sws_getContext(inlink->w, inlink->h, inlink->format, + outlink->w, outlink->h, outlink->format, + SWS_BICUBIC, NULL, NULL, NULL); } + av_frame_free(&fake_in); + av_frame_free(&out); return 0; } static int filter_frame(AVFilterLink *inlink, AVFrame *in) { AVFilterContext *context = inlink->dst; - SRContext *sr_context = context->priv; + SRContext *ctx = context->priv; AVFilterLink *outlink = context->outputs[0]; AVFrame *out = ff_get_video_buffer(outlink, outlink->w, outlink->h); DNNReturnType dnn_result; @@ -211,45 +173,44 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in) return AVERROR(ENOMEM); } av_frame_copy_props(out, in); - out->height = sr_context->output.height; - out->width = sr_context->output.width; - if (sr_context->scale_factor){ - sws_scale(sr_context->sws_contexts[0], (const uint8_t **)in->data, in->linesize, - 0, sr_context->sws_slice_h, out->data, out->linesize); - sws_scale(sr_context->sws_contexts[1], (const uint8_t **)out->data, out->linesize, - 0, out->height, (uint8_t * const*)(&sr_context->input.data), - (const int [4]){sr_context->sws_input_linesize, 0, 0, 0}); + if (ctx->sws_pre_scale) { + sws_scale(ctx->sws_pre_scale, + (const uint8_t **)in->data, in->linesize, 0, in->height, + out->data, out->linesize); + dnn_result = (ctx->model->set_input)(ctx->model->model, out, "x"); } else { - if (sr_context->sws_contexts[0]){ - sws_scale(sr_context->sws_contexts[0], (const uint8_t **)(in->data + 1), in->linesize + 1, - 0, sr_context->sws_slice_h, out->data + 1, out->linesize + 1); - sws_scale(sr_context->sws_contexts[0], (const uint8_t **)(in->data + 2), in->linesize + 2, - 0, sr_context->sws_slice_h, out->data + 2, out->linesize + 2); - } + dnn_result = (ctx->model->set_input)(ctx->model->model, in, "x"); + } - sws_scale(sr_context->sws_contexts[1], (const uint8_t **)in->data, in->linesize, - 0, in->height, (uint8_t * const*)(&sr_context->input.data), - (const int [4]){sr_context->sws_input_linesize, 0, 0, 0}); + if (dnn_result != DNN_SUCCESS) { + av_frame_free(&in); + av_frame_free(&out); + av_log(context, AV_LOG_ERROR, "could not set input for the model\n"); + return AVERROR(EIO); } - av_frame_free(&in); - dnn_result = (sr_context->dnn_module->execute_model)(sr_context->model, &sr_context->output, &model_output_name, 1); + dnn_result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&model_output_name, 1, out); if (dnn_result != DNN_SUCCESS){ - av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n"); + av_log(ctx, AV_LOG_ERROR, "failed to execute loaded model\n"); + av_frame_free(&in); + av_frame_free(&out); return AVERROR(EIO); } - sws_scale(sr_context->sws_contexts[2], (const uint8_t *[4]){(const uint8_t *)sr_context->output.data, 0, 0, 0}, - (const int[4]){sr_context->sws_output_linesize, 0, 0, 0}, - 0, out->height, (uint8_t * const*)out->data, out->linesize); + if (ctx->sws_uv_scale) { + sws_scale(ctx->sws_uv_scale, (const uint8_t **)(in->data + 1), in->linesize + 1, + 0, ctx->sws_uv_height, out->data + 1, out->linesize + 1); + sws_scale(ctx->sws_uv_scale, (const uint8_t **)(in->data + 2), in->linesize + 2, + 0, ctx->sws_uv_height, out->data + 2, out->linesize + 2); + } + av_frame_free(&in); return ff_filter_frame(outlink, out); } static av_cold void uninit(AVFilterContext *context) { - int i; SRContext *sr_context = context->priv; if (sr_context->dnn_module){ @@ -257,16 +218,14 @@ static av_cold void uninit(AVFilterContext *context) av_freep(&sr_context->dnn_module); } - for (i = 0; i < 3; ++i){ - sws_freeContext(sr_context->sws_contexts[i]); - } + sws_freeContext(sr_context->sws_uv_scale); + sws_freeContext(sr_context->sws_pre_scale); } static const AVFilterPad sr_inputs[] = { { .name = "default", .type = AVMEDIA_TYPE_VIDEO, - .config_props = config_props, .filter_frame = filter_frame, }, { NULL } @@ -275,6 +234,7 @@ static const AVFilterPad sr_inputs[] = { static const AVFilterPad sr_outputs[] = { { .name = "default", + .config_props = config_output, .type = AVMEDIA_TYPE_VIDEO, }, { NULL }