You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
213 lines
6.2 KiB
213 lines
6.2 KiB
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import os |
|
import requests |
|
import shutil |
|
import tarfile |
|
import tqdm |
|
import zipfile |
|
|
|
from ppcls.arch import similar_architectures |
|
from ppcls.utils import logger |
|
|
|
__all__ = ['get'] |
|
|
|
DOWNLOAD_RETRY_LIMIT = 3 |
|
|
|
|
|
class UrlError(Exception): |
|
""" UrlError |
|
""" |
|
|
|
def __init__(self, url='', code=''): |
|
message = "Downloading from {} failed with code {}!".format(url, code) |
|
super(UrlError, self).__init__(message) |
|
|
|
|
|
class ModelNameError(Exception): |
|
""" ModelNameError |
|
""" |
|
|
|
def __init__(self, message=''): |
|
super(ModelNameError, self).__init__(message) |
|
|
|
|
|
class RetryError(Exception): |
|
""" RetryError |
|
""" |
|
|
|
def __init__(self, url='', times=''): |
|
message = "Download from {} failed. Retry({}) limit reached".format( |
|
url, times) |
|
super(RetryError, self).__init__(message) |
|
|
|
|
|
def _get_url(architecture, postfix="pdparams"): |
|
prefix = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/" |
|
fname = architecture + "_pretrained." + postfix |
|
return prefix + fname |
|
|
|
|
|
def _move_and_merge_tree(src, dst): |
|
""" |
|
Move src directory to dst, if dst is already exists, |
|
merge src to dst |
|
""" |
|
if not os.path.exists(dst): |
|
shutil.move(src, dst) |
|
elif os.path.isfile(src): |
|
shutil.move(src, dst) |
|
else: |
|
for fp in os.listdir(src): |
|
src_fp = os.path.join(src, fp) |
|
dst_fp = os.path.join(dst, fp) |
|
if os.path.isdir(src_fp): |
|
if os.path.isdir(dst_fp): |
|
_move_and_merge_tree(src_fp, dst_fp) |
|
else: |
|
shutil.move(src_fp, dst_fp) |
|
elif os.path.isfile(src_fp) and \ |
|
not os.path.isfile(dst_fp): |
|
shutil.move(src_fp, dst_fp) |
|
|
|
|
|
def _download(url, path): |
|
""" |
|
Download from url, save to path. |
|
url (str): download url |
|
path (str): download to given path |
|
""" |
|
if not os.path.exists(path): |
|
os.makedirs(path) |
|
|
|
fname = os.path.split(url)[-1] |
|
fullname = os.path.join(path, fname) |
|
retry_cnt = 0 |
|
|
|
while not os.path.exists(fullname): |
|
if retry_cnt < DOWNLOAD_RETRY_LIMIT: |
|
retry_cnt += 1 |
|
else: |
|
raise RetryError(url, DOWNLOAD_RETRY_LIMIT) |
|
|
|
logger.info("Downloading {} from {}".format(fname, url)) |
|
|
|
req = requests.get(url, stream=True) |
|
if req.status_code != 200: |
|
raise UrlError(url, req.status_code) |
|
|
|
# For protecting download interupted, download to |
|
# tmp_fullname firstly, move tmp_fullname to fullname |
|
# after download finished |
|
tmp_fullname = fullname + "_tmp" |
|
total_size = req.headers.get('content-length') |
|
with open(tmp_fullname, 'wb') as f: |
|
if total_size: |
|
for chunk in tqdm.tqdm( |
|
req.iter_content(chunk_size=1024), |
|
total=(int(total_size) + 1023) // 1024, |
|
unit='KB'): |
|
f.write(chunk) |
|
else: |
|
for chunk in req.iter_content(chunk_size=1024): |
|
if chunk: |
|
f.write(chunk) |
|
shutil.move(tmp_fullname, fullname) |
|
|
|
return fullname |
|
|
|
|
|
def _decompress(fname): |
|
""" |
|
Decompress for zip and tar file |
|
""" |
|
logger.info("Decompressing {}...".format(fname)) |
|
|
|
# For protecting decompressing interupted, |
|
# decompress to fpath_tmp directory firstly, if decompress |
|
# successed, move decompress files to fpath and delete |
|
# fpath_tmp and remove download compress file. |
|
fpath = os.path.split(fname)[0] |
|
fpath_tmp = os.path.join(fpath, 'tmp') |
|
if os.path.isdir(fpath_tmp): |
|
shutil.rmtree(fpath_tmp) |
|
os.makedirs(fpath_tmp) |
|
|
|
if fname.find('tar') >= 0: |
|
with tarfile.open(fname) as tf: |
|
tf.extractall(path=fpath_tmp) |
|
elif fname.find('zip') >= 0: |
|
with zipfile.ZipFile(fname) as zf: |
|
zf.extractall(path=fpath_tmp) |
|
else: |
|
raise TypeError("Unsupport compress file type {}".format(fname)) |
|
|
|
fs = os.listdir(fpath_tmp) |
|
assert len( |
|
fs |
|
) == 1, "There should just be 1 pretrained path in an archive file but got {}.".format( |
|
len(fs)) |
|
|
|
f = fs[0] |
|
src_dir = os.path.join(fpath_tmp, f) |
|
dst_dir = os.path.join(fpath, f) |
|
_move_and_merge_tree(src_dir, dst_dir) |
|
|
|
shutil.rmtree(fpath_tmp) |
|
os.remove(fname) |
|
|
|
return f |
|
|
|
|
|
def _get_pretrained(): |
|
with open('./ppcls/utils/pretrained.list') as flist: |
|
pretrained = [line.strip() for line in flist] |
|
return pretrained |
|
|
|
|
|
def _check_pretrained_name(architecture): |
|
assert isinstance(architecture, str), \ |
|
("the type of architecture({}) should be str". format(architecture)) |
|
pretrained = _get_pretrained() |
|
similar_names = similar_architectures(architecture, pretrained) |
|
model_list = ', '.join(similar_names) |
|
err = "{} is not exist! Maybe you want: [{}]" \ |
|
"".format(architecture, model_list) |
|
if architecture not in similar_names: |
|
raise ModelNameError(err) |
|
|
|
|
|
def list_models(): |
|
pretrained = _get_pretrained() |
|
msg = "All avialable pretrained models are as follows: {}".format( |
|
pretrained) |
|
logger.info(msg) |
|
return |
|
|
|
|
|
def get(architecture, path, decompress=False, postfix="pdparams"): |
|
""" |
|
Get the pretrained model. |
|
""" |
|
_check_pretrained_name(architecture) |
|
url = _get_url(architecture, postfix=postfix) |
|
fname = _download(url, path) |
|
if postfix == "tar" and decompress: |
|
_decompress(fname) |
|
logger.info("download {} finished ".format(fname))
|
|
|