@@ -38,7 +38,6 @@
 #include "dnn_io_proc.h"
 #include "dnn_backend_common.h"
 #include "safe_queue.h"
-#include "queue.h"
 #include <tensorflow/c/c_api.h>

 typedef struct TFOptions {
@@ -59,6 +58,7 @@ typedef struct TFModel{
     TF_Status *status;
     SafeQueue *request_queue;
     Queue *inference_queue;
+    Queue *task_queue;
 } TFModel;

 /**
@@ -75,7 +75,7 @@ typedef struct TFInferRequest {
 typedef struct TFRequestItem {
     TFInferRequest *infer_request;
     InferenceItem *inference;
-    // further properties will be added later for async
+    DNNAsyncExecModule exec_module;
 } TFRequestItem;

 #define OFFSET(x) offsetof(TFContext, x)
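TFRequestItem now embeds the common async executor in place of the old placeholder comment. For reference, DNNAsyncExecModule is declared in dnn_backend_common.h; at the time of this patch it looked roughly like the sketch below (paraphrased, not copied; with pthread support it also carries thread bookkeeping under HAVE_PTHREAD_CANCEL):

        typedef struct DNNAsyncExecModule {
            // Backend's synchronous inference routine, called with the request item.
            DNNReturnType (*start_inference)(void *request);
            // Completion callback, run after start_inference returns.
            void (*callback)(void *args);
            // Argument handed to both functions; here it is the TFRequestItem itself.
            void *args;
        } DNNAsyncExecModule;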
@@ -89,6 +89,7 @@ static const AVOption dnn_tensorflow_options[] = {
 AVFILTER_DEFINE_CLASS(dnn_tensorflow);

 static DNNReturnType execute_model_tf(TFRequestItem *request, Queue *inference_queue);
+static void infer_completion_callback(void *args);

 static void free_buffer(void *data, size_t length)
 {
@@ -886,6 +887,9 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, DNNFunctionType func_
             av_freep(&item);
             goto err;
         }
+        item->exec_module.start_inference = &tf_start_inference;
+        item->exec_module.callback = &infer_completion_callback;
+        item->exec_module.args = item;

         if (ff_safe_queue_push_back(tf_model->request_queue, item) < 0) {
             av_freep(&item->infer_request);
@@ -899,6 +903,11 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, DNNFunctionType func_
         goto err;
     }

+    tf_model->task_queue = ff_queue_create();
+    if (!tf_model->task_queue) {
+        goto err;
+    }
+
     model->model = tf_model;
     model->get_input = &get_input_tf;
     model->get_output = &get_output_tf;
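Note that task_queue is the dnn module's plain FIFO, not the thread-safe SafeQueue used for request items: only the filter thread touches it. The queue.h entry points this patch relies on have roughly the following signatures (listed from memory as a sketch, not authoritative):

        Queue *ff_queue_create(void);
        size_t ff_queue_size(Queue *q);
        int    ff_queue_push_back(Queue *q, void *v);  // negative on failure
        void  *ff_queue_pop_front(Queue *q);
        void   ff_queue_destroy(Queue *q);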
@@ -1061,7 +1070,6 @@ static DNNReturnType execute_model_tf(TFRequestItem *request, Queue *inference_q
 {
     TFModel *tf_model;
     TFContext *ctx;
-    TFInferRequest *infer_request;
     InferenceItem *inference;
     TaskItem *task;
@@ -1074,23 +1082,14 @@ static DNNReturnType execute_model_tf(TFRequestItem *request, Queue *inference_q
     tf_model = task->model;
     ctx = &tf_model->ctx;

-    if (task->async) {
-        avpriv_report_missing_feature(ctx, "Async execution not supported");
+    if (fill_model_input_tf(tf_model, request) != DNN_SUCCESS) {
         return DNN_ERROR;
-    } else {
-        if (fill_model_input_tf(tf_model, request) != DNN_SUCCESS) {
-            return DNN_ERROR;
-        }
+    }

-        infer_request = request->infer_request;
-        TF_SessionRun(tf_model->session, NULL,
-                      infer_request->tf_input, &infer_request->input_tensor, 1,
-                      infer_request->tf_outputs, infer_request->output_tensors,
-                      task->nb_output, NULL, 0, NULL,
-                      tf_model->status);
-        if (TF_GetCode(tf_model->status) != TF_OK) {
-            tf_free_request(infer_request);
-            av_log(ctx, AV_LOG_ERROR, "Failed to run session when executing model\n");
+    if (task->async) {
+        return ff_dnn_start_inference_async(ctx, &request->exec_module);
+    } else {
+        if (tf_start_inference(request) != DNN_SUCCESS) {
             return DNN_ERROR;
         }
         infer_completion_callback(request);
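The TF_SessionRun block removed above does not disappear: elsewhere in this patch it is wrapped into tf_start_inference(), which both the synchronous path and the async executor invoke. A sketch of that wrapper, assuming it simply absorbs the deleted code (the real function is not shown in these hunks):

        static DNNReturnType tf_start_inference(void *args)
        {
            TFRequestItem *request = args;
            TFInferRequest *infer_request = request->infer_request;
            InferenceItem *inference = request->inference;
            TaskItem *task = inference->task;
            TFModel *tf_model = task->model;

            // Run the session exactly as the old inline code did.
            TF_SessionRun(tf_model->session, NULL,
                          infer_request->tf_input, &infer_request->input_tensor, 1,
                          infer_request->tf_outputs, infer_request->output_tensors,
                          task->nb_output, NULL, 0, NULL,
                          tf_model->status);
            if (TF_GetCode(tf_model->status) != TF_OK) {
                tf_free_request(infer_request);
                av_log(&tf_model->ctx, AV_LOG_ERROR, "Failed to run session when executing model\n");
                return DNN_ERROR;
            }
            return DNN_SUCCESS;
        }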
@@ -1127,6 +1126,83 @@ DNNModel *ff_dnn_execute_model_tf(const DNNModel *model, DNNExecBaseParams *
     return execute_model_tf(request, tf_model->inference_queue);
 }

+DNNReturnType ff_dnn_execute_model_async_tf(const DNNModel *model, DNNExecBaseParams *exec_params) {
+    TFModel *tf_model = model->model;
+    TFContext *ctx = &tf_model->ctx;
+    TaskItem *task;
+    TFRequestItem *request;
+
+    if (ff_check_exec_params(ctx, DNN_TF, model->func_type, exec_params) != 0) {
+        return DNN_ERROR;
+    }
+
+    task = av_malloc(sizeof(*task));
+    if (!task) {
+        av_log(ctx, AV_LOG_ERROR, "unable to alloc memory for task item.\n");
+        return DNN_ERROR;
+    }
+
+    if (ff_dnn_fill_task(task, exec_params, tf_model, 1, 1) != DNN_SUCCESS) {
+        av_freep(&task);
+        return DNN_ERROR;
+    }
+
+    if (ff_queue_push_back(tf_model->task_queue, task) < 0) {
+        av_freep(&task);
+        av_log(ctx, AV_LOG_ERROR, "unable to push back task_queue.\n");
+        return DNN_ERROR;
+    }
+
+    if (extract_inference_from_task(task, tf_model->inference_queue) != DNN_SUCCESS) {
+        av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n");
+        return DNN_ERROR;
+    }
+
+    request = ff_safe_queue_pop_front(tf_model->request_queue);
+    if (!request) {
+        av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
+        return DNN_ERROR;
+    }
+
+    return execute_model_tf(request, tf_model->inference_queue);
+}
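Taken together, the async entry point records the whole frame as a TaskItem, splits it into inference items, and hands one request to execute_model_tf(), which now returns as soon as the worker is started. A hypothetical caller-side loop (the helper name and control flow are illustrative, not part of this patch):

        // Hypothetical: submit one frame, then drain whatever has completed.
        static int run_one_async(const DNNModel *model, DNNExecBaseParams *params)
        {
            AVFrame *in = NULL, *out = NULL;

            if (ff_dnn_execute_model_async_tf(model, params) != DNN_SUCCESS)
                return AVERROR(EIO);

            // DAST_NOT_READY: inference still running; DAST_EMPTY_QUEUE: nothing pending.
            while (ff_dnn_get_async_result_tf(model, &in, &out) == DAST_SUCCESS) {
                // hand 'out' downstream; 'in' is the frame that produced it
            }
            return 0;
        }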
+
+DNNAsyncStatusType ff_dnn_get_async_result_tf(const DNNModel *model, AVFrame **in, AVFrame **out)
+{
+    TFModel *tf_model = model->model;
+    return ff_dnn_get_async_result_common(tf_model->task_queue, in, out);
+}
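ff_dnn_get_async_result_common() comes from the companion dnn_backend_common change. Paraphrased from memory, its body is roughly the fragment below (field names are my recollection of TaskItem, treat as a sketch):

        // Peek the oldest task; if its inferences are not all done, report
        // DAST_NOT_READY; otherwise pop it, return its frames, and free it.
        TaskItem *task = ff_queue_peek_front(task_queue);
        if (!task)
            return DAST_EMPTY_QUEUE;
        if (task->inference_done != task->inference_todo)
            return DAST_NOT_READY;
        *in  = task->in_frame;
        *out = task->out_frame;
        ff_queue_pop_front(task_queue);
        av_freep(&task);
        return DAST_SUCCESS;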
+
+DNNReturnType ff_dnn_flush_tf(const DNNModel *model)
+{
+    TFModel *tf_model = model->model;
+    TFContext *ctx = &tf_model->ctx;
+    TFRequestItem *request;
+    DNNReturnType ret;
+
+    if (ff_queue_size(tf_model->inference_queue) == 0) {
+        // no pending task needs to be flushed
+        return DNN_SUCCESS;
+    }
+
+    request = ff_safe_queue_pop_front(tf_model->request_queue);
+    if (!request) {
+        av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
+        return DNN_ERROR;
+    }
+
+    ret = fill_model_input_tf(tf_model, request);
+    if (ret != DNN_SUCCESS) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to fill model input.\n");
+        if (ff_safe_queue_push_back(tf_model->request_queue, request) < 0) {
+            av_freep(&request->infer_request);
+            av_freep(&request);
+        }
+        return ret;
+    }
+
+    return ff_dnn_start_inference_async(ctx, &request->exec_module);
+}
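ff_dnn_flush_tf() only pushes the last, partially filled request into the async executor; the caller still has to wait for results. A hypothetical EOF drain on the filter side (the polling interval and structure are illustrative):

        DNNAsyncStatusType status;
        AVFrame *in = NULL, *out = NULL;

        if (ff_dnn_flush_tf(model) != DNN_SUCCESS)
            return AVERROR(EIO);
        do {
            av_usleep(5000);  // assumed polling interval; needs libavutil/time.h
            status = ff_dnn_get_async_result_tf(model, &in, &out);
        } while (status == DAST_NOT_READY);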
+
 void ff_dnn_free_model_tf(DNNModel **model)
 {
     TFModel *tf_model;
@@ -1135,6 +1211,7 @@ void ff_dnn_free_model_tf(DNNModel **model)
         tf_model = (*model)->model;
         while (ff_safe_queue_size(tf_model->request_queue) != 0) {
             TFRequestItem *item = ff_safe_queue_pop_front(tf_model->request_queue);
+            ff_dnn_async_module_cleanup(&item->exec_module);
             tf_free_request(item->infer_request);
             av_freep(&item->infer_request);
             av_freep(&item);
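ff_dnn_async_module_cleanup() is also part of the common async module. My understanding is that it joins any still-running worker thread before the request item is freed, roughly (paraphrase; the real code also inspects the join status):

        // Wait for an in-flight inference, then reset the module so the
        // request item can be freed safely.
    #if HAVE_PTHREAD_CANCEL
        pthread_join(async_module->thread_id, NULL);
    #endif
        async_module->start_inference = NULL;
        async_module->callback = NULL;
        async_module->args = NULL;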
@@ -1147,6 +1224,14 @@ void ff_dnn_free_model_tf(DNNModel **model)
         }
         ff_queue_destroy(tf_model->inference_queue);

+        while (ff_queue_size(tf_model->task_queue) != 0) {
+            TaskItem *item = ff_queue_pop_front(tf_model->task_queue);
+            av_frame_free(&item->in_frame);
+            av_frame_free(&item->out_frame);
+            av_freep(&item);
+        }
+        ff_queue_destroy(tf_model->task_queue);
+
         if (tf_model->graph) {
             TF_DeleteGraph(tf_model->graph);
         }