@ -76,7 +76,7 @@ static TF_Buffer *read_graph(const char *model_filename)
return graph_buf ;
}
static DNNReturnType set_input_output_tf ( void * model , DNNData * input , DNNData * output )
static DNNReturnType set_input_output_tf ( void * model , DNNData * input , const char * input_name , DNNData * output , const char * output_name )
{
TFModel * tf_model = ( TFModel * ) model ;
int64_t input_dims [ ] = { 1 , input - > height , input - > width , input - > channels } ;
@ -84,8 +84,8 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *o
const TF_Operation * init_op = TF_GraphOperationByName ( tf_model - > graph , " init " ) ;
TF_Tensor * output_tensor ;
// Input operation should be named 'x'
tf_model - > input . oper = TF_GraphOperationByName ( tf_model - > graph , " x " ) ;
// Input operation
tf_model - > input . oper = TF_GraphOperationByName ( tf_model - > graph , input_name ) ;
if ( ! tf_model - > input . oper ) {
return DNN_ERROR ;
}
@ -100,8 +100,8 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *o
}
input - > data = ( float * ) TF_TensorData ( tf_model - > input_tensor ) ;
// Output operation should be named 'y'
tf_model - > output . oper = TF_GraphOperationByName ( tf_model - > graph , " y " ) ;
// Output operation
tf_model - > output . oper = TF_GraphOperationByName ( tf_model - > graph , output_name ) ;
if ( ! tf_model - > output . oper ) {
return DNN_ERROR ;
}