|
|
|
@ -111,10 +111,10 @@ class BaseChangeDetector(BaseModel): |
|
|
|
|
if mode == 'test': |
|
|
|
|
origin_shape = inputs[2] |
|
|
|
|
if self.status == 'Infer': |
|
|
|
|
label_map_list, score_map_list = self._postprocess( |
|
|
|
|
label_map_list, score_map_list = self.postprocess( |
|
|
|
|
net_out, origin_shape, transforms=inputs[3]) |
|
|
|
|
else: |
|
|
|
|
logit_list = self._postprocess( |
|
|
|
|
logit_list = self.postprocess( |
|
|
|
|
logit, origin_shape, transforms=inputs[3]) |
|
|
|
|
label_map_list = [] |
|
|
|
|
score_map_list = [] |
|
|
|
@ -142,7 +142,7 @@ class BaseChangeDetector(BaseModel): |
|
|
|
|
raise ValueError("Expected label.ndim == 4 but got {}".format( |
|
|
|
|
label.ndim)) |
|
|
|
|
origin_shape = [label.shape[-2:]] |
|
|
|
|
pred = self._postprocess( |
|
|
|
|
pred = self.postprocess( |
|
|
|
|
pred, origin_shape, transforms=inputs[3])[0] # NCHW |
|
|
|
|
intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area( |
|
|
|
|
pred, label, self.num_classes) |
|
|
|
@ -553,7 +553,7 @@ class BaseChangeDetector(BaseModel): |
|
|
|
|
images = [img_file] |
|
|
|
|
else: |
|
|
|
|
images = img_file |
|
|
|
|
batch_im1, batch_im2, batch_origin_shape = self._preprocess( |
|
|
|
|
batch_im1, batch_im2, batch_origin_shape = self.preprocess( |
|
|
|
|
images, transforms, self.model_type) |
|
|
|
|
self.net.eval() |
|
|
|
|
data = (batch_im1, batch_im2, batch_origin_shape, transforms.transforms) |
|
|
|
@ -664,7 +664,7 @@ class BaseChangeDetector(BaseModel): |
|
|
|
|
dst_data = None |
|
|
|
|
print("GeoTiff saved in {}.".format(save_file)) |
|
|
|
|
|
|
|
|
|
def _preprocess(self, images, transforms, to_tensor=True): |
|
|
|
|
def preprocess(self, images, transforms, to_tensor=True): |
|
|
|
|
self._check_transforms(transforms, 'test') |
|
|
|
|
batch_im1, batch_im2 = list(), list() |
|
|
|
|
batch_ori_shape = list() |
|
|
|
@ -736,7 +736,7 @@ class BaseChangeDetector(BaseModel): |
|
|
|
|
batch_restore_list.append(restore_list) |
|
|
|
|
return batch_restore_list |
|
|
|
|
|
|
|
|
|
def _postprocess(self, batch_pred, batch_origin_shape, transforms): |
|
|
|
|
def postprocess(self, batch_pred, batch_origin_shape, transforms): |
|
|
|
|
batch_restore_list = BaseChangeDetector.get_transforms_shape_info( |
|
|
|
|
batch_origin_shape, transforms) |
|
|
|
|
if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer': |
|
|
|
|