Add model type check

own
Bobholamovic 2 years ago
parent 9fd3b7b00e
commit 9c1b2ea2fe
  1. 38
      paddlers/deploy/predictor.py

@ -103,11 +103,11 @@ class Predictor(object):
config.enable_use_gpu(200, gpu_id)
config.switch_ir_optim(True)
if use_trt:
if self._model.model_type == 'segmenter':
if self.model_type == 'segmenter':
logging.warning(
"Semantic segmentation models do not support TensorRT acceleration, "
"TensorRT is forcibly disabled.")
elif self._model.model_type == 'detector' and 'RCNN' in self._model.__class__.__name__:
elif self.model_type == 'detector' and 'RCNN' in self._model.__class__.__name__:
logging.warning(
"RCNN models do not support TensorRT acceleration, "
"TensorRT is forcibly disabled.")
@ -150,30 +150,29 @@ class Predictor(object):
def preprocess(self, images, transforms):
    """Preprocess input images into the feed dict expected by the
    inference engine.

    Args:
        images: Input image(s), in whatever form the wrapped model's
            ``preprocess`` method accepts.
        transforms: Data transforms applied before inference.

    Returns:
        Mapping from input tensor names to preprocessed data (a dict for
        most model types; detection models receive the samples as-is).
    """
    # to_tensor=False: the inference engine feeds numpy arrays, not tensors.
    preprocessed_samples = self._model.preprocess(
        images, transforms, to_tensor=False)
    # NOTE: the diff rendering interleaved the pre-change
    # `self._model.model_type` conditions with the post-change
    # `self.model_type` ones; this is the post-commit version, which
    # reads the type through the `model_type` property.
    if self.model_type == 'classifier':
        preprocessed_samples = {'image': preprocessed_samples[0]}
    elif self.model_type == 'segmenter':
        preprocessed_samples = {
            'image': preprocessed_samples[0],
            'ori_shape': preprocessed_samples[1]
        }
    elif self.model_type == 'detector':
        # Detection models consume the preprocessed samples unchanged.
        pass
    elif self.model_type == 'change_detector':
        preprocessed_samples = {
            'image': preprocessed_samples[0],
            'image2': preprocessed_samples[1],
            'ori_shape': preprocessed_samples[2]
        }
    elif self.model_type == 'restorer':
        preprocessed_samples = {
            'image': preprocessed_samples[0],
            'tar_shape': preprocessed_samples[1]
        }
    else:
        # Project-local logging helper: exit=True aborts after logging.
        logging.error(
            "Invalid model type {}".format(self.model_type), exit=True)
    return preprocessed_samples
def postprocess(self,
@ -182,7 +181,7 @@ class Predictor(object):
ori_shape=None,
tar_shape=None,
transforms=None):
if self._model.model_type == 'classifier':
if self.model_type == 'classifier':
true_topk = min(self._model.num_classes, topk)
if self._model.postprocess is None:
self._model.build_postprocess_from_labels(topk)
@ -198,7 +197,7 @@ class Predictor(object):
'scores_map': s,
'label_names_map': n,
} for l, s, n in zip(class_ids, scores, label_names)]
elif self._model.model_type in ('segmenter', 'change_detector'):
elif self.model_type in ('segmenter', 'change_detector'):
label_map, score_map = self._model.postprocess(
net_outputs,
batch_origin_shape=ori_shape,
@ -207,13 +206,13 @@ class Predictor(object):
'label_map': l,
'score_map': s
} for l, s in zip(label_map, score_map)]
elif self._model.model_type == 'detector':
elif self.model_type == 'detector':
net_outputs = {
k: v
for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs)
}
preds = self._model.postprocess(net_outputs)
elif self._model.model_type == 'restorer':
elif self.model_type == 'restorer':
res_maps = self._model.postprocess(
net_outputs[0],
batch_tar_shape=tar_shape,
@ -221,8 +220,7 @@ class Predictor(object):
preds = [{'res_map': res_map} for res_map in res_maps]
else:
logging.error(
"Invalid model type {}.".format(self._model.model_type),
exit=True)
"Invalid model type {}.".format(self.model_type), exit=True)
return preds
@ -360,6 +358,12 @@ class Predictor(object):
batch_size (int, optional): Batch size used in inference. Defaults to 1.
quiet (bool, optional): If True, disable the progress bar. Defaults to False.
"""
if self.model_type not in ('segmenter', 'change_detector'):
raise RuntimeError(
"Model type is {}, which does not support inference with sliding windows.".
format(self.model_type))
slider_predict(
partial(
self.predict, quiet=True),
@ -375,3 +379,7 @@ class Predictor(object):
def batch_predict(self, image_list, **params):
    """Backward-compatible alias: forward *image_list* to :meth:`predict`.

    Args:
        image_list: List of input images.
        **params: Extra keyword arguments passed through to ``predict``.

    Returns:
        Whatever ``predict`` returns for the given batch.
    """
    results = self.predict(img_file=image_list, **params)
    return results
@property
def model_type(self):
    """str: Type of the wrapped model (e.g. ``'classifier'``,
    ``'segmenter'``, ``'detector'``, ``'change_detector'``,
    ``'restorer'`` — the values the dispatch code above tests for),
    read from the private ``_model`` attribute."""
    return self._model.model_type

Loading…
Cancel
Save