[Feat] Add predict() method for ChangeDetector (#74)

own
Lin Manhui 3 years ago committed by GitHub
parent 574e5180cc
commit 8024d83606
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 44
      paddlers/tasks/change_detector.py

@ -381,6 +381,12 @@ class BaseChangeDetector(BaseModel):
Returns:
collections.OrderedDict with key-value pairs:
For binary change detection (number of classes == 2), the key-value pairs are like:
{"iou": `intersection over union for the change class`,
"f1": `F1 score for the change class`,
"oacc": `overall accuracy`,
"kappa": ` kappa coefficient`}.
For multi-class change detection (number of classes > 2), the key-value pairs are like:
{"miou": `mean intersection over union`,
"category_iou": `category-wise mean intersection over union`,
"oacc": `overall accuracy`,
@ -408,7 +414,7 @@ class BaseChangeDetector(BaseModel):
batch_size_each_card = 1
batch_size = batch_size_each_card * paddlers.env_info['num']
logging.warning(
"Segmenter only supports batch_size=1 for each gpu/cpu card " \
"ChangeDetector only supports batch_size=1 for each gpu/cpu card " \
"during evaluation, so batch_size " \
"is forcibly set to {}.".format(batch_size)
)
@ -471,11 +477,17 @@ class BaseChangeDetector(BaseModel):
label_area_all)
category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
label_area_all)
eval_metrics = OrderedDict(
zip([
'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
'category_F1-score'
], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))
if len(class_acc) > 2:
eval_metrics = OrderedDict(
zip([
'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
'category_F1-score'
], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))
else:
eval_metrics = OrderedDict(
zip(['iou', 'f1', 'oacc', 'kappa'],
[class_iou[1], category_f1score[1], oacc, kappa]))
if return_details:
conf_mat = sum(conf_mat_all)
@ -488,14 +500,14 @@ class BaseChangeDetector(BaseModel):
Do inference.
Args:
Args:
img_file(List[np.ndarray or str], str or np.ndarray):
Image path or decoded image data in a BGR format, which also could constitute a list,
meaning all images to be predicted as a mini-batch.
img_file(List[tuple], Tuple[str or np.ndarray]):
Tuple of image paths or decoded image data in a BGR format for bi-temporal images, which also could constitute
a list, meaning all image pairs to be predicted as a mini-batch.
transforms(paddlers.transforms.Compose or None, optional):
Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
Returns:
If img_file is a string or np.array, the result is a dict with key-value pairs:
If img_file is a tuple of string or np.array, the result is a dict with key-value pairs:
{"label map": `label map`, "score_map": `score map`}.
If img_file is a list, the result is a list composed of dicts with the corresponding fields:
label_map(np.ndarray): the predicted label map (HW)
@ -506,14 +518,18 @@ class BaseChangeDetector(BaseModel):
raise Exception("transforms need to be defined, now is None.")
if transforms is None:
transforms = self.test_transforms
if isinstance(img_file, (str, np.ndarray)):
if isinstance(img_file, tuple):
if not len(img_file) == 2 and any(
map(lambda obj: not isinstance(obj, (str, np.ndarray)),
img_file)):
raise TypeError
images = [img_file]
else:
images = img_file
batch_im, batch_origin_shape = self._preprocess(images, transforms,
self.model_type)
batch_im1, batch_im2, batch_origin_shape = self._preprocess(
images, transforms, self.model_type)
self.net.eval()
data = (batch_im, batch_origin_shape, transforms.transforms)
data = (batch_im1, batch_im2, batch_origin_shape, transforms.transforms)
outputs = self.run(self.net, data, 'test')
label_map_list = outputs['label_map']
score_map_list = outputs['score_map']

Loading…
Cancel
Save