You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
147 lines
5.7 KiB
147 lines
5.7 KiB
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
import os |
|
import math |
|
|
|
import cv2 |
|
import numpy as np |
|
import paddle |
|
|
|
from paddlers.models.ppseg import utils |
|
from paddlers.models.ppseg.core import infer |
|
from paddlers.models.ppseg.utils import logger, progbar, visualize |
|
|
|
|
|
def mkdir(path):
    """Create the parent directory of `path` if it does not exist yet.

    Args:
        path (str): Path of a file that is about to be written. Only its
            directory component is created; the file itself is not touched.
    """
    sub_dir = os.path.dirname(path)
    # A bare file name has an empty dirname, which os.makedirs rejects.
    if sub_dir:
        # exist_ok avoids the check-then-create race when several workers
        # create the same output directory at the same time.
        os.makedirs(sub_dir, exist_ok=True)
|
|
|
|
|
def partition_list(arr, m):
    """Split the list `arr` into `m` consecutive pieces.

    Every piece except possibly the last non-empty one holds
    ceil(len(arr) / m) items. When there are not enough items to fill all
    pieces, the remaining ones are empty, so the result always has at least
    `m` entries and can be indexed safely by any rank in range(m).

    Args:
        arr (list): The items to distribute.
        m (int): The number of pieces; expected to be a positive integer.

    Returns:
        list: The pieces, in the original order of `arr`.
    """
    n = int(math.ceil(len(arr) / float(m)))
    # n is 0 for an empty input, and range() rejects a zero step.
    pieces = [arr[i:i + n] for i in range(0, len(arr), n)] if n > 0 else []
    # Pad with empty pieces of the same sequence type as `arr`.
    while len(pieces) < m:
        pieces.append(arr[:0])
    return pieces
|
|
|
|
|
def preprocess(im_path, transforms):
    """Run `transforms` on one image and wrap the result as a paddle tensor.

    Args:
        im_path (str): Stored under the 'img' key of the dict handed to
            `transforms`.
        transforms (callable): Called with that dict; the value it returns
            must expose the transformed image under 'img'.

    Returns:
        The object returned by `transforms`, with its 'img' entry given a new
        leading axis and converted with `paddle.to_tensor`.
    """
    sample = transforms({'img': im_path})
    # Add a leading axis before the tensor conversion.
    batched = sample['img'][np.newaxis, ...]
    sample['img'] = paddle.to_tensor(batched)
    return sample
|
|
|
|
|
def _saved_name(im_path, image_dir=None):
    """Return the relative file name under which results for `im_path` are saved."""
    if image_dir is None:
        name = os.path.basename(im_path)
    else:
        try:
            # relpath removes only the leading root; str.replace would also
            # delete later occurrences of `image_dir` inside the path.
            name = os.path.relpath(im_path, image_dir)
        except ValueError:
            # Raised for an empty path or, on Windows, for different drives.
            name = im_path
        if name == os.pardir or name.startswith(os.pardir + os.sep):
            # The image is not located under `image_dir`; keep its own path.
            name = im_path
    # os.path.join discards the save directory when the joined name is
    # absolute, so strip the drive and every leading separator.
    name = os.path.splitdrive(name)[1]
    return name.lstrip('/\\')


def predict(model,
            model_path,
            transforms,
            image_list,
            image_dir=None,
            save_dir='output',
            aug_pred=False,
            scales=1.0,
            flip_horizontal=True,
            flip_vertical=False,
            is_slide=False,
            stride=None,
            crop_size=None,
            custom_color=None):
    """
    Predict and visualize the image_list.

    For every image two files are written below `save_dir`: the prediction
    blended onto the input image ('added_prediction') and the pseudo-color
    mask saved as PNG ('pseudo_color_prediction'). When several ranks are
    running, each rank processes only its own share of `image_list`.

    Args:
        model (nn.Layer): Used to predict for input image.
        model_path (str): The path of pretrained model.
        transforms (transform.Compose): Preprocess for input image.
        image_list (list): A list of image path to be predicted.
        image_dir (str, optional): The root directory of the images predicted. Default: None.
        save_dir (str, optional): The directory to save the visualized results. Default: 'output'.
        aug_pred (bool, optional): Whether to use multi-scales and flip augment for prediction. Default: False.
        scales (list|float, optional): Scales for augment. It is valid when `aug_pred` is True. Default: 1.0.
        flip_horizontal (bool, optional): Whether to use flip horizontally augment. It is valid when `aug_pred` is True. Default: True.
        flip_vertical (bool, optional): Whether to use flip vertically augment. It is valid when `aug_pred` is True. Default: False.
        is_slide (bool, optional): Whether to predict by sliding window. Default: False.
        stride (tuple|list, optional): The stride of sliding window, the first is width and the second is height.
            It should be provided when `is_slide` is True.
        crop_size (tuple|list, optional): The crop size of sliding window, the first is width and the second is height.
            It should be provided when `is_slide` is True.
        custom_color (list, optional): Save images with a custom color map. Default: None, use paddleseg's default color map.

    """
    utils.utils.load_entire_model(model, model_path)
    model.eval()
    nranks = paddle.distributed.get_world_size()
    local_rank = paddle.distributed.get_rank()
    if nranks > 1:
        img_lists = partition_list(image_list, nranks)
    else:
        img_lists = [image_list]
    # Guard against a partition holding fewer pieces than there are ranks:
    # such a rank simply has nothing to predict.
    local_images = img_lists[local_rank] if local_rank < len(img_lists) else []

    added_saved_dir = os.path.join(save_dir, 'added_prediction')
    pred_saved_dir = os.path.join(save_dir, 'pseudo_color_prediction')

    logger.info("Start to predict...")
    # The bar tracks this rank's own share, which may be shorter than rank 0's.
    progbar_pred = progbar.Progbar(target=len(local_images), verbose=1)
    color_map = visualize.get_color_map_list(256, custom_color=custom_color)
    with paddle.no_grad():
        for i, im_path in enumerate(local_images):
            data = preprocess(im_path, transforms)

            if aug_pred:
                pred, _ = infer.aug_inference(
                    model,
                    data['img'],
                    trans_info=data['trans_info'],
                    scales=scales,
                    flip_horizontal=flip_horizontal,
                    flip_vertical=flip_vertical,
                    is_slide=is_slide,
                    stride=stride,
                    crop_size=crop_size)
            else:
                pred, _ = infer.inference(
                    model,
                    data['img'],
                    trans_info=data['trans_info'],
                    is_slide=is_slide,
                    stride=stride,
                    crop_size=crop_size)
            pred = paddle.squeeze(pred)
            pred = pred.numpy().astype('uint8')

            # get the saved name
            im_file = _saved_name(im_path, image_dir)

            # save added image
            added_image = visualize.visualize(
                im_path, pred, color_map, weight=0.6)
            added_image_path = os.path.join(added_saved_dir, im_file)
            mkdir(added_image_path)
            cv2.imwrite(added_image_path, added_image)

            # save pseudo color prediction
            pred_mask = visualize.get_pseudo_color_map(pred, color_map)
            pred_saved_path = os.path.join(
                pred_saved_dir, os.path.splitext(im_file)[0] + ".png")
            mkdir(pred_saved_path)
            pred_mask.save(pred_saved_path)

            progbar_pred.update(i + 1)
|
|
|