You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

81 lines
2.6 KiB

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import cv2
import numpy as np
from PIL import Image
import paddle
class BasePredictor(object):
    """Common plumbing for predictors that run a Paddle model.

    This class reads attributes it never sets itself; presumably a subclass
    provides them: ``weight_path`` (a directory, read by
    ``build_inference_model`` in static-graph mode) and ``model`` (a callable,
    read by ``base_forward`` in dynamic-graph mode). Subclasses must also
    implement ``run``.
    """

    def __init__(self):
        pass

    def build_inference_model(self):
        """Load a saved static-graph inference program from ``self.weight_path``.

        In dynamic-graph mode this is currently a no-op (model construction is
        still a TODO). In static-graph mode it creates an executor on the
        current device, picks the model and parameter files out of
        ``self.weight_path`` by file name, and loads them.

        Sets ``self.exe``, ``self.program``, ``self.feed_names`` and
        ``self.fetch_targets``.

        Raises:
            FileNotFoundError: if ``self.weight_path`` does not contain both a
                file whose name contains ``'model'`` and one whose name
                contains ``'param'`` (``os.listdir`` also raises it when the
                directory itself is missing).
        """
        if paddle.in_dynamic_mode():
            # TODO: self.model = build_model(self.cfg)
            return

        place = paddle.get_device()
        self.exe = paddle.static.Executor(place)

        model_file = None
        param_file = None
        # NOTE(review): files are classified by substring. A name containing
        # both substrings (e.g. 'model.pdiparams') is taken as the model file,
        # and when several names match, whichever os.listdir yields last wins
        # (its order is arbitrary). Confirm this against the weight layouts
        # actually shipped for each predictor.
        for file_name in os.listdir(self.weight_path):
            if 'model' in file_name:
                model_file = file_name
            elif 'param' in file_name:
                param_file = file_name

        # Fail with a clear message rather than an UnboundLocalError below.
        if model_file is None or param_file is None:
            raise FileNotFoundError(
                "Expected a model file (name containing 'model') and a params "
                "file (name containing 'param') in {!r}; found model_file={!r}, "
                "param_file={!r}".format(self.weight_path, model_file,
                                         param_file))

        (self.program, self.feed_names,
         self.fetch_targets) = paddle.static.load_inference_model(
             self.weight_path,
             executor=self.exe,
             model_filename=model_file,
             params_filename=param_file)

    def base_forward(self, inputs):
        """Run one forward pass and return the model outputs.

        Args:
            inputs: in dynamic-graph mode, passed straight to ``self.model``.
                In static-graph mode, one of:
                - a dict mapping feed names to values (used as-is);
                - a list/tuple matched positionally to ``self.feed_names``;
                - any other value, fed to the first feed name.

        Returns:
            The result of ``self.model(inputs)`` in dynamic mode, or the result
            of ``self.exe.run`` for ``self.fetch_targets`` in static mode.
        """
        if paddle.in_dynamic_mode():
            return self.model(inputs)

        if isinstance(inputs, dict):
            feed_dict = inputs
        elif isinstance(inputs, (list, tuple)):
            # Positional match: raises IndexError when there are fewer inputs
            # than feed names; extra inputs are ignored.
            feed_dict = {
                feed_name: inputs[i]
                for i, feed_name in enumerate(self.feed_names)
            }
        else:
            feed_dict = {self.feed_names[0]: inputs}
        return self.exe.run(self.program,
                            fetch_list=self.fetch_targets,
                            feed=feed_dict)

    def is_image(self, input):
        """Return True if ``input`` looks like an image, otherwise False.

        ``np.ndarray`` and ``PIL.Image.Image`` instances are accepted as-is. A
        ``str`` is accepted if it names an existing file that PIL can open.
        Anything else, or any error raised while checking, yields False.
        """
        try:
            if isinstance(input, (np.ndarray, Image.Image)):
                return True
            if not isinstance(input, str) or not os.path.isfile(input):
                return False
            # Image.open is lazy and keeps the file open; the context manager
            # releases the handle once the header has been probed.
            with Image.open(input) as img:
                _ = img.size
            return True
        except Exception:
            # Best-effort probe: any failure means "not an image". Unlike a
            # bare except, this lets KeyboardInterrupt/SystemExit propagate.
            return False

    def run(self):
        """Entry point of a concrete predictor; subclasses must override."""
        raise NotImplementedError