@ -79,10 +79,31 @@ static TF_Buffer *read_graph(const char *model_filename)
return graph_buf ;
}
static DNNReturnType set_input_output_tf ( void * model , DNNData * input , const char * input_name , const char * * output_names , uint32_t nb_out put)
static TF_Tensor * allocate_input_tensor ( const DNNInputData * in put)
{
TFModel * tf_model = ( TFModel * ) model ;
TF_DataType dt ;
size_t size ;
int64_t input_dims [ ] = { 1 , input - > height , input - > width , input - > channels } ;
switch ( input - > dt ) {
case DNN_FLOAT :
dt = TF_FLOAT ;
size = sizeof ( float ) ;
break ;
case DNN_UINT8 :
dt = TF_UINT8 ;
size = sizeof ( char ) ;
break ;
default :
av_assert0 ( ! " should not reach here " ) ;
}
return TF_AllocateTensor ( dt , input_dims , 4 ,
input_dims [ 1 ] * input_dims [ 2 ] * input_dims [ 3 ] * size ) ;
}
static DNNReturnType set_input_output_tf ( void * model , DNNInputData * input , const char * input_name , const char * * output_names , uint32_t nb_output )
{
TFModel * tf_model = ( TFModel * ) model ;
TF_SessionOptions * sess_opts ;
const TF_Operation * init_op = TF_GraphOperationByName ( tf_model - > graph , " init " ) ;
@ -95,8 +116,7 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char
if ( tf_model - > input_tensor ) {
TF_DeleteTensor ( tf_model - > input_tensor ) ;
}
tf_model - > input_tensor = TF_AllocateTensor ( TF_FLOAT , input_dims , 4 ,
input_dims [ 1 ] * input_dims [ 2 ] * input_dims [ 3 ] * sizeof ( float ) ) ;
tf_model - > input_tensor = allocate_input_tensor ( input ) ;
if ( ! tf_model - > input_tensor ) {
return DNN_ERROR ;
}