diff --git a/paddlers/models/ppcls/utils/model_zoo.py b/paddlers/models/ppcls/utils/model_zoo.py index bad0752..d9436fe 100644 --- a/paddlers/models/ppcls/utils/model_zoo.py +++ b/paddlers/models/ppcls/utils/model_zoo.py @@ -151,26 +151,28 @@ def _decompress(fname): if fname.find('tar') >= 0: with tarfile.open(fname) as tf: - def is_within_directory(directory, target): - + + def _is_within_directory(directory, target): abs_directory = os.path.abspath(directory) abs_target = os.path.abspath(target) - + prefix = os.path.commonprefix([abs_directory, abs_target]) - + return prefix == abs_directory - - def safe_extract(tar, path=".", members=None, *, numeric_owner=False): - + + def _safe_extract(tar, + path=".", + members=None, + *, + numeric_owner=False): for member in tar.getmembers(): member_path = os.path.join(path, member.name) - if not is_within_directory(path, member_path): + if not _is_within_directory(path, member_path): raise Exception("Attempted Path Traversal in Tar File") - - tar.extractall(path, members, numeric_owner=numeric_owner) - - - safe_extract(tf, path=fpath_tmp) + + tar.extractall(path, members, numeric_owner=numeric_owner) + + _safe_extract(tf, path=fpath_tmp) elif fname.find('zip') >= 0: with zipfile.ZipFile(fname) as zf: zf.extractall(path=fpath_tmp) diff --git a/paddlers/utils/download.py b/paddlers/utils/download.py index 435ad17..56cf30e 100644 --- a/paddlers/utils/download.py +++ b/paddlers/utils/download.py @@ -151,26 +151,28 @@ def decompress(fname): if fname.find('tar') >= 0 or fname.find('tgz') >= 0: with tarfile.open(fname) as tf: - def is_within_directory(directory, target): - + + def _is_within_directory(directory, target): abs_directory = os.path.abspath(directory) abs_target = os.path.abspath(target) - + prefix = os.path.commonprefix([abs_directory, abs_target]) - + return prefix == abs_directory - - def safe_extract(tar, path=".", members=None, *, numeric_owner=False): - + + def _safe_extract(tar, + path=".", + members=None, + *, + numeric_owner=False): for member in tar.getmembers(): member_path = os.path.join(path, member.name) - if not is_within_directory(path, member_path): + if not _is_within_directory(path, member_path): raise Exception("Attempted Path Traversal in Tar File") - - tar.extractall(path, members, numeric_owner=numeric_owner) - - - safe_extract(tf, path=fpath_tmp) + + tar.extractall(path, members, numeric_owner=numeric_owner) + + _safe_extract(tf, path=fpath_tmp) elif fname.find('zip') >= 0: with zipfile.ZipFile(fname) as zf: zf.extractall(path=fpath_tmp)