You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
151 lines
5.9 KiB
151 lines
5.9 KiB
3 years ago
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
import os
|
||
|
import math
|
||
|
|
||
|
import cv2
|
||
|
import numpy as np
|
||
|
import paddle
|
||
|
|
||
|
from paddlers.models.ppseg import utils
|
||
|
from paddlers.models.ppseg.core import infer
|
||
|
from paddlers.models.ppseg.utils import logger, progbar, visualize
|
||
|
|
||
|
|
||
|
def mkdir(path):
    """Ensure the parent directory of ``path`` exists.

    Args:
        path (str): A file path. Its containing directory is created
            (including intermediate directories) if missing. A bare filename
            with no directory component needs nothing and is a no-op.
    """
    sub_dir = os.path.dirname(path)
    # os.path.dirname returns '' for a bare filename, and os.makedirs('')
    # would raise FileNotFoundError.
    if sub_dir:
        # exist_ok=True avoids the check-then-create race when several
        # processes (e.g. distributed ranks) create the same directory.
        os.makedirs(sub_dir, exist_ok=True)
|
||
|
|
||
|
|
||
|
def partition_list(arr, m):
    """Split the sequence ``arr`` into exactly ``m`` contiguous pieces.

    Each piece holds ``ceil(len(arr) / m)`` items, except that trailing
    pieces may be shorter or empty when there are not enough items. Always
    returning ``m`` pieces lets callers index the result by rank
    (``0 <= rank < m``) without running off the end.

    Args:
        arr (list): The sequence to split (any sliceable sequence).
        m (int): The number of pieces to produce (a positive integer).

    Returns:
        list: ``m`` slices of ``arr`` which, concatenated in order, give
        back ``arr``.
    """
    n = int(math.ceil(len(arr) / float(m)))
    # Slice by piece index rather than stepping through arr by n: this
    # always yields m pieces and also works when arr is empty (n == 0).
    return [arr[i * n:(i + 1) * n] for i in range(m)]
|
||
|
|
||
|
|
||
|
def _relative_save_name(im_path, image_dir):
    """Return the output-relative file name used to save results for ``im_path``."""
    if image_dir is None:
        im_file = os.path.basename(im_path)
    elif im_path.startswith(image_dir):
        # Strip only the leading prefix. str.replace would also delete
        # matches elsewhere in the path (e.g. image_dir='data' turns
        # 'data/mydata/x.png' into '/my/x.png').
        im_file = im_path[len(image_dir):]
    else:
        im_file = im_path
    # Drop leading separators so os.path.join keeps the result under the
    # output directory instead of treating it as an absolute path.
    return im_file.lstrip('/\\')


def predict(model,
            model_path,
            transforms,
            image_list,
            image_dir=None,
            save_dir='output',
            aug_pred=False,
            scales=1.0,
            flip_horizontal=True,
            flip_vertical=False,
            is_slide=False,
            stride=None,
            crop_size=None,
            custom_color=None):
    """
    Predict and visualize the images in `image_list`.

    In a multi-process run, `image_list` is split into one contiguous chunk
    per rank and each rank only predicts its own chunk. For every image, two
    files are written: the prediction blended over the input image under
    `<save_dir>/added_prediction/`, and a pseudo-color mask saved as PNG
    under `<save_dir>/pseudo_color_prediction/`.

    Args:
        model (nn.Layer): Used to predict for input image.
        model_path (str): The path of pretrained model.
        transforms (transform.Compose): Preprocess for input image.
        image_list (list): A list of image path to be predicted.
        image_dir (str, optional): The root directory of the images predicted. When given,
            output files mirror the directory structure below it. Default: None.
        save_dir (str, optional): The directory to save the visualized results. Default: 'output'.
        aug_pred (bool, optional): Whether to use multi-scale and flip augmentation for prediction. Default: False.
        scales (list|float, optional): Scales for augment. It is valid when `aug_pred` is True. Default: 1.0.
        flip_horizontal (bool, optional): Whether to use flip horizontally augment. It is valid when `aug_pred` is True. Default: True.
        flip_vertical (bool, optional): Whether to use flip vertically augment. It is valid when `aug_pred` is True. Default: False.
        is_slide (bool, optional): Whether to predict by sliding window. Default: False.
        stride (tuple|list, optional): The stride of sliding window, the first is width and the second is height.
            It should be provided when `is_slide` is True.
        crop_size (tuple|list, optional): The crop size of sliding window, the first is width and the second is height.
            It should be provided when `is_slide` is True.
        custom_color (list, optional): Save images with a custom color map. Default: None, use paddleseg's default color map.

    Raises:
        OSError: If an image in `image_list` cannot be read by cv2.imread.
    """
    utils.utils.load_entire_model(model, model_path)
    model.eval()
    nranks = paddle.distributed.get_world_size()
    local_rank = paddle.distributed.get_rank()
    if nranks > 1:
        img_lists = partition_list(image_list, nranks)
    else:
        img_lists = [image_list]
    # With few images the partition may hold fewer chunks than ranks; a
    # rank without a chunk simply has nothing to predict.
    rank_images = img_lists[local_rank] if local_rank < len(img_lists) else []

    added_saved_dir = os.path.join(save_dir, 'added_prediction')
    pred_saved_dir = os.path.join(save_dir, 'pseudo_color_prediction')

    logger.info("Start to predict...")
    # Size the progress bar by this rank's own share, not by rank 0's.
    progbar_pred = progbar.Progbar(target=len(rank_images), verbose=1)
    color_map = visualize.get_color_map_list(256, custom_color=custom_color)
    with paddle.no_grad():
        for i, im_path in enumerate(rank_images):
            im = cv2.imread(im_path)
            # cv2.imread signals failure by returning None instead of raising.
            if im is None:
                raise OSError("Failed to read image: {}".format(im_path))
            ori_shape = im.shape[:2]
            im, _ = transforms(im)
            # Add a leading batch axis before converting to a tensor.
            im = im[np.newaxis, ...]
            im = paddle.to_tensor(im)

            if aug_pred:
                pred, _ = infer.aug_inference(
                    model,
                    im,
                    ori_shape=ori_shape,
                    transforms=transforms.transforms,
                    scales=scales,
                    flip_horizontal=flip_horizontal,
                    flip_vertical=flip_vertical,
                    is_slide=is_slide,
                    stride=stride,
                    crop_size=crop_size)
            else:
                pred, _ = infer.inference(
                    model,
                    im,
                    ori_shape=ori_shape,
                    transforms=transforms.transforms,
                    is_slide=is_slide,
                    stride=stride,
                    crop_size=crop_size)
            pred = paddle.squeeze(pred)
            pred = pred.numpy().astype('uint8')

            im_file = _relative_save_name(im_path, image_dir)

            # Save the prediction blended over the original image.
            added_image = utils.visualize.visualize(
                im_path, pred, color_map, weight=0.6)
            added_image_path = os.path.join(added_saved_dir, im_file)
            mkdir(added_image_path)
            cv2.imwrite(added_image_path, added_image)

            # Save the pseudo-color prediction; always written with a .png
            # extension regardless of the input image's format.
            pred_mask = utils.visualize.get_pseudo_color_map(pred, color_map)
            pred_saved_path = os.path.join(
                pred_saved_dir, os.path.splitext(im_file)[0] + ".png")
            mkdir(pred_saved_path)
            pred_mask.save(pred_saved_path)

            progbar_pred.update(i + 1)
|