|
|
|
@ -103,11 +103,11 @@ class Predictor(object): |
|
|
|
|
config.enable_use_gpu(200, gpu_id) |
|
|
|
|
config.switch_ir_optim(True) |
|
|
|
|
if use_trt: |
|
|
|
|
if self._model.model_type == 'segmenter': |
|
|
|
|
if self.model_type == 'segmenter': |
|
|
|
|
logging.warning( |
|
|
|
|
"Semantic segmentation models do not support TensorRT acceleration, " |
|
|
|
|
"TensorRT is forcibly disabled.") |
|
|
|
|
elif self._model.model_type == 'detector' and 'RCNN' in self._model.__class__.__name__: |
|
|
|
|
elif self.model_type == 'detector' and 'RCNN' in self._model.__class__.__name__: |
|
|
|
|
logging.warning( |
|
|
|
|
"RCNN models do not support TensorRT acceleration, " |
|
|
|
|
"TensorRT is forcibly disabled.") |
|
|
|
@ -150,30 +150,29 @@ class Predictor(object): |
|
|
|
|
def preprocess(self, images, transforms): |
|
|
|
|
preprocessed_samples = self._model.preprocess( |
|
|
|
|
images, transforms, to_tensor=False) |
|
|
|
|
if self._model.model_type == 'classifier': |
|
|
|
|
if self.model_type == 'classifier': |
|
|
|
|
preprocessed_samples = {'image': preprocessed_samples[0]} |
|
|
|
|
elif self._model.model_type == 'segmenter': |
|
|
|
|
elif self.model_type == 'segmenter': |
|
|
|
|
preprocessed_samples = { |
|
|
|
|
'image': preprocessed_samples[0], |
|
|
|
|
'ori_shape': preprocessed_samples[1] |
|
|
|
|
} |
|
|
|
|
elif self._model.model_type == 'detector': |
|
|
|
|
elif self.model_type == 'detector': |
|
|
|
|
pass |
|
|
|
|
elif self._model.model_type == 'change_detector': |
|
|
|
|
elif self.model_type == 'change_detector': |
|
|
|
|
preprocessed_samples = { |
|
|
|
|
'image': preprocessed_samples[0], |
|
|
|
|
'image2': preprocessed_samples[1], |
|
|
|
|
'ori_shape': preprocessed_samples[2] |
|
|
|
|
} |
|
|
|
|
elif self._model.model_type == 'restorer': |
|
|
|
|
elif self.model_type == 'restorer': |
|
|
|
|
preprocessed_samples = { |
|
|
|
|
'image': preprocessed_samples[0], |
|
|
|
|
'tar_shape': preprocessed_samples[1] |
|
|
|
|
} |
|
|
|
|
else: |
|
|
|
|
logging.error( |
|
|
|
|
"Invalid model type {}".format(self._model.model_type), |
|
|
|
|
exit=True) |
|
|
|
|
"Invalid model type {}".format(self.model_type), exit=True) |
|
|
|
|
return preprocessed_samples |
|
|
|
|
|
|
|
|
|
def postprocess(self, |
|
|
|
@ -182,7 +181,7 @@ class Predictor(object): |
|
|
|
|
ori_shape=None, |
|
|
|
|
tar_shape=None, |
|
|
|
|
transforms=None): |
|
|
|
|
if self._model.model_type == 'classifier': |
|
|
|
|
if self.model_type == 'classifier': |
|
|
|
|
true_topk = min(self._model.num_classes, topk) |
|
|
|
|
if self._model.postprocess is None: |
|
|
|
|
self._model.build_postprocess_from_labels(topk) |
|
|
|
@ -198,7 +197,7 @@ class Predictor(object): |
|
|
|
|
'scores_map': s, |
|
|
|
|
'label_names_map': n, |
|
|
|
|
} for l, s, n in zip(class_ids, scores, label_names)] |
|
|
|
|
elif self._model.model_type in ('segmenter', 'change_detector'): |
|
|
|
|
elif self.model_type in ('segmenter', 'change_detector'): |
|
|
|
|
label_map, score_map = self._model.postprocess( |
|
|
|
|
net_outputs, |
|
|
|
|
batch_origin_shape=ori_shape, |
|
|
|
@ -207,13 +206,13 @@ class Predictor(object): |
|
|
|
|
'label_map': l, |
|
|
|
|
'score_map': s |
|
|
|
|
} for l, s in zip(label_map, score_map)] |
|
|
|
|
elif self._model.model_type == 'detector': |
|
|
|
|
elif self.model_type == 'detector': |
|
|
|
|
net_outputs = { |
|
|
|
|
k: v |
|
|
|
|
for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs) |
|
|
|
|
} |
|
|
|
|
preds = self._model.postprocess(net_outputs) |
|
|
|
|
elif self._model.model_type == 'restorer': |
|
|
|
|
elif self.model_type == 'restorer': |
|
|
|
|
res_maps = self._model.postprocess( |
|
|
|
|
net_outputs[0], |
|
|
|
|
batch_tar_shape=tar_shape, |
|
|
|
@ -221,8 +220,7 @@ class Predictor(object): |
|
|
|
|
preds = [{'res_map': res_map} for res_map in res_maps] |
|
|
|
|
else: |
|
|
|
|
logging.error( |
|
|
|
|
"Invalid model type {}.".format(self._model.model_type), |
|
|
|
|
exit=True) |
|
|
|
|
"Invalid model type {}.".format(self.model_type), exit=True) |
|
|
|
|
|
|
|
|
|
return preds |
|
|
|
|
|
|
|
|
@ -360,6 +358,12 @@ class Predictor(object): |
|
|
|
|
batch_size (int, optional): Batch size used in inference. Defaults to 1. |
|
|
|
|
quiet (bool, optional): If True, disable the progress bar. Defaults to False. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
if self.model_type not in ('segmenter', 'change_detector'): |
|
|
|
|
raise RuntimeError( |
|
|
|
|
"Model type is {}, which does not support inference with sliding windows.". |
|
|
|
|
format(self.model_type)) |
|
|
|
|
|
|
|
|
|
slider_predict( |
|
|
|
|
partial( |
|
|
|
|
self.predict, quiet=True), |
|
|
|
@ -375,3 +379,7 @@ class Predictor(object): |
|
|
|
|
|
|
|
|
|
def batch_predict(self, image_list, **params): |
|
|
|
|
return self.predict(img_file=image_list, **params) |
|
|
|
|
|
|
|
|
|
@property |
|
|
|
|
def model_type(self): |
|
|
|
|
return self._model.model_type |
|
|
|
|