You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
175 lines
6.4 KiB
175 lines
6.4 KiB
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
import contextlib |
|
import filelock |
|
import os |
|
import tempfile |
|
import numpy as np |
|
import random |
|
from urllib.parse import urlparse, unquote |
|
|
|
import paddle |
|
|
|
from paddlers.models.ppseg.utils import logger, seg_env |
|
from paddlers.models.ppseg.utils.download import download_file_and_uncompress |
|
|
|
|
|
@contextlib.contextmanager
def generate_tempdir(directory: str = None, **kwargs):
    """Context manager that yields the path of a fresh temporary directory.

    Args:
        directory (str, optional): Parent directory in which the temporary
            directory is created. When it is None or empty,
            ``seg_env.TMP_HOME`` is used instead.
        **kwargs: Forwarded unchanged to ``tempfile.TemporaryDirectory``.

    Yields:
        str: Path of the temporary directory. It is managed by
            ``tempfile.TemporaryDirectory`` and cleaned up when the
            ``with`` block exits.
    """
    parent_dir = directory or seg_env.TMP_HOME
    with tempfile.TemporaryDirectory(dir=parent_dir, **kwargs) as tmp_path:
        yield tmp_path
|
|
|
|
|
def load_entire_model(model, pretrained):
    """Load pretrained weights into ``model`` or warn that none were given.

    Args:
        model: The model to initialise; passed on to
            ``load_pretrained_model``.
        pretrained: Location of the pretrained weights handed to
            ``load_pretrained_model``, or None to skip loading.
    """
    if pretrained is None:
        # Nothing to load for the whole model: only emit the warning.
        message = ('Not all pretrained params of {} are loaded, '
                   'training from scratch or a pretrained backbone.')
        logger.warning(message.format(model.__class__.__name__))
        return

    load_pretrained_model(model, pretrained)
|
|
|
|
|
def download_pretrained_model(pretrained_model):
    """
    Download pretrained model from url.

    Args:
        pretrained_model (str): the url of pretrained weight

    Returns:
        str: the path of pretrained weight, i.e. ``model.pdparams`` joined
            onto the path returned by ``download_file_and_uncompress``

    Raises:
        AssertionError: if the url has no network location component.
    """
    assert urlparse(pretrained_model).netloc, "The url is not valid."

    url = unquote(pretrained_model)
    segments = url.split('/')
    last_segment = segments[-1]
    if last_segment.endswith(('tgz', 'tar.gz', 'tar', 'zip')):
        # Archive link: name the weights after the file name up to its
        # first dot.
        savename = last_segment.split('.')[0]
    else:
        # Otherwise fall back to the second-to-last URL segment.
        savename = segments[-2]

    with generate_tempdir() as _dir:
        # The lock file lives under TMP_HOME and is named after the weights;
        # presumably this serialises concurrent downloads of the same model
        # (TODO confirm against download_file_and_uncompress).
        lock_path = os.path.join(seg_env.TMP_HOME, savename)
        with filelock.FileLock(lock_path):
            extracted_dir = download_file_and_uncompress(
                url,
                savepath=_dir,
                extrapath=seg_env.PRETRAINED_MODEL_HOME,
                extraname=savename)
            weights_path = os.path.join(extracted_dir, 'model.pdparams')
    return weights_path
|
|
|
|
|
def load_pretrained_model(model, pretrained_model):
    """Copy shape-compatible pretrained parameters into ``model``.

    Args:
        model: Model exposing ``state_dict()`` and ``set_dict()``.
        pretrained_model (str|None): Local path or url of the pretrained
            weights. A url (anything with a network location) is first
            fetched with ``download_pretrained_model``. None skips loading.

    Raises:
        ValueError: if the (possibly downloaded) path does not exist.
    """
    if pretrained_model is None:
        logger.info(
            'No pretrained model to load, {} will be trained from scratch.'.
            format(model.__class__.__name__))
        return

    logger.info('Loading pretrained model from {}'.format(pretrained_model))

    if urlparse(pretrained_model).netloc:
        pretrained_model = download_pretrained_model(pretrained_model)

    if not os.path.exists(pretrained_model):
        raise ValueError(
            'The pretrained model directory is not Found: {}'.format(
                pretrained_model))

    para_state_dict = paddle.load(pretrained_model)
    model_state_dict = model.state_dict()
    num_params_loaded = 0
    for name in model_state_dict.keys():
        if name not in para_state_dict:
            logger.warning("{} is not in pretrained model".format(name))
            continue

        pretrained_shape = para_state_dict[name].shape
        actual_shape = model_state_dict[name].shape
        if list(pretrained_shape) != list(actual_shape):
            # Mismatched parameters keep the model's current value.
            logger.warning(
                "[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})"
                .format(name, pretrained_shape, actual_shape))
            continue

        model_state_dict[name] = para_state_dict[name]
        num_params_loaded += 1

    model.set_dict(model_state_dict)
    logger.info("There are {}/{} variables loaded into {}.".format(
        num_params_loaded, len(model_state_dict), model.__class__.__name__))
|
|
|
|
|
def resume(model, optimizer, resume_model):
    """Restore model and optimizer state from a checkpoint directory.

    Args:
        model: Object whose ``set_state_dict`` receives the content of
            ``model.pdparams``.
        optimizer: Object whose ``set_state_dict`` receives the content of
            ``model.pdopt``.
        resume_model (str|None): Checkpoint directory. The text after the
            last underscore of its normalised path must be the iteration
            number (for example ``iter_1000``). None means nothing to resume.

    Returns:
        int|None: The iteration number parsed from the directory name, or
            None when ``resume_model`` is None.

    Raises:
        ValueError: if the directory does not exist, or if its name does not
            end with ``_<integer>``.
    """
    if resume_model is None:
        logger.info('No model needed to resume.')
        return None

    logger.info('Resume model from {}'.format(resume_model))
    if not os.path.exists(resume_model):
        raise ValueError(
            'Directory of the model needed to resume is not Found: {}'.
            format(resume_model))

    # Normalising drops any trailing separator so the iteration suffix is
    # really the last part of the path.
    resume_model = os.path.normpath(resume_model)
    para_state_dict = paddle.load(os.path.join(resume_model, 'model.pdparams'))
    opti_state_dict = paddle.load(os.path.join(resume_model, 'model.pdopt'))
    model.set_state_dict(para_state_dict)
    optimizer.set_state_dict(opti_state_dict)

    # Named to avoid shadowing the builtin ``iter``.
    iter_suffix = resume_model.split('_')[-1]
    try:
        return int(iter_suffix)
    except ValueError as err:
        raise ValueError(
            'Cannot parse the iteration number from the resume directory '
            '{}: expected a name ending with "_<integer>".'.format(
                resume_model)) from err
|
|
|
|
|
def worker_init_fn(worker_id):
    """Seed NumPy's global random generator for a worker.

    The seed is drawn from Python's ``random`` module in the range
    [0, 100000]; ``worker_id`` is accepted but not used.

    Args:
        worker_id: Identifier of the worker (unused).
    """
    seed = random.randint(0, 100000)
    np.random.seed(seed)
|
|
|
|
|
def get_image_list(image_path):
    """Collect image paths from a single image, a list file, or a directory.

    Args:
        image_path (str): One of:
            - the path of an image file with a supported suffix;
            - the path of a text file listing one image per line, relative
              to the directory of the list file (only the first
              whitespace-separated field of each line is used, blank lines
              are ignored);
            - a directory, which is walked recursively for image files.

    Returns:
        tuple: ``(image_list, image_dir)`` where ``image_list`` is the list
            of image paths and ``image_dir`` is the directory they are
            relative to (None when ``image_path`` is a single image).

    Raises:
        FileNotFoundError: if ``image_path`` is neither a file nor a
            directory.
        RuntimeError: if no image was found.
    """
    valid_suffix = ('.JPEG', '.jpeg', '.JPG', '.jpg', '.BMP', '.bmp', '.PNG',
                    '.png')
    image_list = []
    image_dir = None
    if os.path.isfile(image_path):
        if os.path.splitext(image_path)[-1] in valid_suffix:
            image_list.append(image_path)
        else:
            # Any other file is treated as a list of image paths.
            image_dir = os.path.dirname(image_path)
            with open(image_path, 'r') as f:
                for line in f:
                    fields = line.split()
                    # Skip blank lines: joining an empty name onto image_dir
                    # would add the directory itself as an image.
                    if not fields:
                        continue
                    image_list.append(os.path.join(image_dir, fields[0]))
    elif os.path.isdir(image_path):
        image_dir = image_path
        for root, _, files in os.walk(image_path):
            # Notebook checkpoint folders only hold stale copies.
            if '.ipynb_checkpoints' in root:
                continue
            image_list.extend(
                os.path.join(root, name) for name in files
                if os.path.splitext(name)[-1] in valid_suffix)
    else:
        raise FileNotFoundError(
            '`--image_path` is not found. it should be a path of image, or a file list containing image paths, or a directory including images.'
        )

    if not image_list:
        raise RuntimeError(
            'There are not image file in `--image_path`={}'.format(image_path))

    return image_list, image_dir
|
|
|