@ -29,14 +29,20 @@
# include "dnn_backend_native_layer_depth2space.h"
# include "libavformat/avio.h"
# include "libavutil/avassert.h"
# include "../internal.h"
# include "dnn_backend_native_layer_pad.h"
# include "dnn_backend_native_layer_maximum.h"
# include "dnn_io_proc.h"
# include <tensorflow/c/c_api.h>
typedef struct TFOptions {
char * sess_config ;
} TFOptions ;
typedef struct TFContext {
const AVClass * class ;
TFOptions options ;
} TFContext ;
typedef struct TFModel {
@ -47,14 +53,15 @@ typedef struct TFModel{
TF_Status * status ;
} TFModel ;
static const AVClass dnn_tensorflow_class = {
. class_name = " dnn_tensorflow " ,
. item_name = av_default_item_name ,
. option = NULL ,
. version = LIBAVUTIL_VERSION_INT ,
. category = AV_CLASS_CATEGORY_FILTER ,
# define OFFSET(x) offsetof(TFContext, x)
# define FLAGS AV_OPT_FLAG_FILTERING_PARAM
static const AVOption dnn_tensorflow_options [ ] = {
{ " sess_config " , " config for SessionOptions " , OFFSET ( options . sess_config ) , AV_OPT_TYPE_STRING , { . str = NULL } , 0 , 0 , FLAGS } ,
{ NULL }
} ;
AVFILTER_DEFINE_CLASS ( dnn_tensorflow ) ;
static DNNReturnType execute_model_tf ( const DNNModel * model , const char * input_name , AVFrame * in_frame ,
const char * * output_names , uint32_t nb_output , AVFrame * out_frame ,
int do_ioproc ) ;
@ -194,10 +201,64 @@ static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename
TF_ImportGraphDefOptions * graph_opts ;
TF_SessionOptions * sess_opts ;
const TF_Operation * init_op ;
uint8_t * sess_config = NULL ;
int sess_config_length = 0 ;
// prepare the sess config data
if ( tf_model - > ctx . options . sess_config ! = NULL ) {
/*
tf_model - > ctx . options . sess_config is hex to present the serialized proto
required by TF_SetConfig below , so we need to first generate the serialized
proto in a python script , the following is a script example to generate
serialized proto which specifies one GPU , we can change the script to add
more options .
import tensorflow as tf
gpu_options = tf . GPUOptions ( visible_device_list = ' 0 ' )
config = tf . ConfigProto ( gpu_options = gpu_options )
s = config . SerializeToString ( )
b = ' ' . join ( " %02x " % int ( ord ( b ) ) for b in s [ : : - 1 ] )
print ( ' 0 x % s ' % b )
the script output looks like : 0xab . . . cd , and then pass 0xab . . . cd to sess_config .
*/
char tmp [ 3 ] ;
tmp [ 2 ] = ' \0 ' ;
if ( strncmp ( tf_model - > ctx . options . sess_config , " 0x " , 2 ) ! = 0 ) {
av_log ( ctx , AV_LOG_ERROR , " sess_config should start with '0x' \n " ) ;
return DNN_ERROR ;
}
sess_config_length = strlen ( tf_model - > ctx . options . sess_config ) ;
if ( sess_config_length % 2 ! = 0 ) {
av_log ( ctx , AV_LOG_ERROR , " the length of sess_config is not even (%s), "
" please re-generate the config. \n " ,
tf_model - > ctx . options . sess_config ) ;
return DNN_ERROR ;
}
sess_config_length - = 2 ; //ignore the first '0x'
sess_config_length / = 2 ; //get the data length in byte
sess_config = av_malloc ( sess_config_length ) ;
if ( ! sess_config ) {
av_log ( ctx , AV_LOG_ERROR , " failed to allocate memory \n " ) ;
return DNN_ERROR ;
}
for ( int i = 0 ; i < sess_config_length ; i + + ) {
int index = 2 + ( sess_config_length - 1 - i ) * 2 ;
tmp [ 0 ] = tf_model - > ctx . options . sess_config [ index ] ;
tmp [ 1 ] = tf_model - > ctx . options . sess_config [ index + 1 ] ;
sess_config [ i ] = strtol ( tmp , NULL , 16 ) ;
}
}
graph_def = read_graph ( model_filename ) ;
if ( ! graph_def ) {
av_log ( ctx , AV_LOG_ERROR , " Failed to read model \" %s \" graph \n " , model_filename ) ;
av_freep ( & sess_config ) ;
return DNN_ERROR ;
}
tf_model - > graph = TF_NewGraph ( ) ;
@ -210,11 +271,23 @@ static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename
TF_DeleteGraph ( tf_model - > graph ) ;
TF_DeleteStatus ( tf_model - > status ) ;
av_log ( ctx , AV_LOG_ERROR , " Failed to import serialized graph to model graph \n " ) ;
av_freep ( & sess_config ) ;
return DNN_ERROR ;
}
init_op = TF_GraphOperationByName ( tf_model - > graph , " init " ) ;
sess_opts = TF_NewSessionOptions ( ) ;
if ( sess_config ) {
TF_SetConfig ( sess_opts , sess_config , sess_config_length , tf_model - > status ) ;
av_freep ( & sess_config ) ;
if ( TF_GetCode ( tf_model - > status ) ! = TF_OK ) {
av_log ( ctx , AV_LOG_ERROR , " Failed to set config for sess options with %s \n " ,
tf_model - > ctx . options . sess_config ) ;
return DNN_ERROR ;
}
}
tf_model - > session = TF_NewSession ( tf_model - > graph , sess_opts , tf_model - > status ) ;
TF_DeleteSessionOptions ( sess_opts ) ;
if ( TF_GetCode ( tf_model - > status ) ! = TF_OK )
@ -609,6 +682,15 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options,
tf_model - > ctx . class = & dnn_tensorflow_class ;
tf_model - > model = model ;
//parse options
av_opt_set_defaults ( & tf_model - > ctx ) ;
if ( av_opt_set_from_string ( & tf_model - > ctx , options , NULL , " = " , " & " ) < 0 ) {
av_log ( & tf_model - > ctx , AV_LOG_ERROR , " Failed to parse options \" %s \" \n " , options ) ;
av_freep ( & tf_model ) ;
av_freep ( & model ) ;
return NULL ;
}
if ( load_tf_model ( tf_model , model_filename ) ! = DNN_SUCCESS ) {
if ( load_native_model ( tf_model , model_filename ) ! = DNN_SUCCESS ) {
av_freep ( & tf_model ) ;