diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c index 7923e1db69..76cc037b94 100644 --- a/libavfilter/dnn/dnn_backend_tf.c +++ b/libavfilter/dnn/dnn_backend_tf.c @@ -29,14 +29,20 @@ #include "dnn_backend_native_layer_depth2space.h" #include "libavformat/avio.h" #include "libavutil/avassert.h" +#include "../internal.h" #include "dnn_backend_native_layer_pad.h" #include "dnn_backend_native_layer_maximum.h" #include "dnn_io_proc.h" #include +typedef struct TFOptions{ + char *sess_config; +} TFOptions; + typedef struct TFContext { const AVClass *class; + TFOptions options; } TFContext; typedef struct TFModel{ @@ -47,14 +53,15 @@ typedef struct TFModel{ TF_Status *status; } TFModel; -static const AVClass dnn_tensorflow_class = { - .class_name = "dnn_tensorflow", - .item_name = av_default_item_name, - .option = NULL, - .version = LIBAVUTIL_VERSION_INT, - .category = AV_CLASS_CATEGORY_FILTER, +#define OFFSET(x) offsetof(TFContext, x) +#define FLAGS AV_OPT_FLAG_FILTERING_PARAM +static const AVOption dnn_tensorflow_options[] = { + { "sess_config", "config for SessionOptions", OFFSET(options.sess_config), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS }, + { NULL } }; +AVFILTER_DEFINE_CLASS(dnn_tensorflow); + static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame, const char **output_names, uint32_t nb_output, AVFrame *out_frame, int do_ioproc); @@ -194,10 +201,64 @@ static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename TF_ImportGraphDefOptions *graph_opts; TF_SessionOptions *sess_opts; const TF_Operation *init_op; + uint8_t *sess_config = NULL; + int sess_config_length = 0; + + // prepare the sess config data + if (tf_model->ctx.options.sess_config != NULL) { + /* + tf_model->ctx.options.sess_config is hex to present the serialized proto + required by TF_SetConfig below, so we need to first generate the serialized + proto in a python script, the following is a script example to generate + serialized proto which specifies one GPU, we can change the script to add + more options. + + import tensorflow as tf + gpu_options = tf.GPUOptions(visible_device_list='0') + config = tf.ConfigProto(gpu_options=gpu_options) + s = config.SerializeToString() + b = ''.join("%02x" % int(ord(b)) for b in s[::-1]) + print('0x%s' % b) + + the script output looks like: 0xab...cd, and then pass 0xab...cd to sess_config. + */ + char tmp[3]; + tmp[2] = '\0'; + + if (strncmp(tf_model->ctx.options.sess_config, "0x", 2) != 0) { + av_log(ctx, AV_LOG_ERROR, "sess_config should start with '0x'\n"); + return DNN_ERROR; + } + + sess_config_length = strlen(tf_model->ctx.options.sess_config); + if (sess_config_length % 2 != 0) { + av_log(ctx, AV_LOG_ERROR, "the length of sess_config is not even (%s), " + "please re-generate the config.\n", + tf_model->ctx.options.sess_config); + return DNN_ERROR; + } + + sess_config_length -= 2; //ignore the first '0x' + sess_config_length /= 2; //get the data length in byte + + sess_config = av_malloc(sess_config_length); + if (!sess_config) { + av_log(ctx, AV_LOG_ERROR, "failed to allocate memory\n"); + return DNN_ERROR; + } + + for (int i = 0; i < sess_config_length; i++) { + int index = 2 + (sess_config_length - 1 - i) * 2; + tmp[0] = tf_model->ctx.options.sess_config[index]; + tmp[1] = tf_model->ctx.options.sess_config[index + 1]; + sess_config[i] = strtol(tmp, NULL, 16); + } + } graph_def = read_graph(model_filename); if (!graph_def){ av_log(ctx, AV_LOG_ERROR, "Failed to read model \"%s\" graph\n", model_filename); + av_freep(&sess_config); return DNN_ERROR; } tf_model->graph = TF_NewGraph(); @@ -210,11 +271,23 @@ static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename TF_DeleteGraph(tf_model->graph); TF_DeleteStatus(tf_model->status); av_log(ctx, AV_LOG_ERROR, "Failed to import serialized graph to model graph\n"); + av_freep(&sess_config); return DNN_ERROR; } init_op = TF_GraphOperationByName(tf_model->graph, "init"); sess_opts = TF_NewSessionOptions(); + + if (sess_config) { + TF_SetConfig(sess_opts, sess_config, sess_config_length,tf_model->status); + av_freep(&sess_config); + if (TF_GetCode(tf_model->status) != TF_OK) { + av_log(ctx, AV_LOG_ERROR, "Failed to set config for sess options with %s\n", + tf_model->ctx.options.sess_config); + return DNN_ERROR; + } + } + tf_model->session = TF_NewSession(tf_model->graph, sess_opts, tf_model->status); TF_DeleteSessionOptions(sess_opts); if (TF_GetCode(tf_model->status) != TF_OK) @@ -609,6 +682,15 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, tf_model->ctx.class = &dnn_tensorflow_class; tf_model->model = model; + //parse options + av_opt_set_defaults(&tf_model->ctx); + if (av_opt_set_from_string(&tf_model->ctx, options, NULL, "=", "&") < 0) { + av_log(&tf_model->ctx, AV_LOG_ERROR, "Failed to parse options \"%s\"\n", options); + av_freep(&tf_model); + av_freep(&model); + return NULL; + } + if (load_tf_model(tf_model, model_filename) != DNN_SUCCESS){ if (load_native_model(tf_model, model_filename) != DNN_SUCCESS){ av_freep(&tf_model);