mirror of https://github.com/opencv/opencv.git
commit
dd1494eebf
24 changed files with 985 additions and 226 deletions
@ -0,0 +1,3 @@ |
|||||||
|
*.caffemodel |
||||||
|
*.pb |
||||||
|
*.weights |
@ -0,0 +1,364 @@ |
|||||||
|
''' |
||||||
|
Helper module to download extra data from Internet |
||||||
|
''' |
||||||
|
from __future__ import print_function |
||||||
|
import os |
||||||
|
import cv2 |
||||||
|
import sys |
||||||
|
import yaml |
||||||
|
import argparse |
||||||
|
import tarfile |
||||||
|
import platform |
||||||
|
import tempfile |
||||||
|
import hashlib |
||||||
|
import requests |
||||||
|
import shutil |
||||||
|
from pathlib import Path |
||||||
|
from datetime import datetime |
||||||
|
if sys.version_info[0] < 3: |
||||||
|
from urllib2 import urlopen |
||||||
|
else: |
||||||
|
from urllib.request import urlopen |
||||||
|
import xml.etree.ElementTree as ET |
||||||
|
|
||||||
|
__all__ = ["downloadFile"] |
||||||
|
|
||||||
|
class HashMismatchException(Exception): |
||||||
|
def __init__(self, expected, actual): |
||||||
|
Exception.__init__(self) |
||||||
|
self.expected = expected |
||||||
|
self.actual = actual |
||||||
|
def __str__(self): |
||||||
|
return 'Hash mismatch: expected {} vs actual of {}'.format(self.expected, self.actual) |
||||||
|
|
||||||
|
def getHashsumFromFile(filepath): |
||||||
|
sha = hashlib.sha1() |
||||||
|
if os.path.exists(filepath): |
||||||
|
print(' there is already a file with the same name') |
||||||
|
with open(filepath, 'rb') as f: |
||||||
|
while True: |
||||||
|
buf = f.read(10*1024*1024) |
||||||
|
if not buf: |
||||||
|
break |
||||||
|
sha.update(buf) |
||||||
|
hashsum = sha.hexdigest() |
||||||
|
return hashsum |
||||||
|
|
||||||
|
def checkHashsum(expected_sha, filepath, silent=True): |
||||||
|
print(' expected SHA1: {}'.format(expected_sha)) |
||||||
|
actual_sha = getHashsumFromFile(filepath) |
||||||
|
print(' actual SHA1:{}'.format(actual_sha)) |
||||||
|
hashes_matched = expected_sha == actual_sha |
||||||
|
if not hashes_matched and not silent: |
||||||
|
raise HashMismatchException(expected_sha, actual_sha) |
||||||
|
return hashes_matched |
||||||
|
|
||||||
|
def isArchive(filepath): |
||||||
|
return tarfile.is_tarfile(filepath) |
||||||
|
|
||||||
|
class DownloadInstance: |
||||||
|
def __init__(self, **kwargs): |
||||||
|
self.name = kwargs.pop('name') |
||||||
|
self.filename = kwargs.pop('filename') |
||||||
|
self.loader = kwargs.pop('loader', None) |
||||||
|
self.save_dir = kwargs.pop('save_dir') |
||||||
|
self.sha = kwargs.pop('sha', None) |
||||||
|
|
||||||
|
def __str__(self): |
||||||
|
return 'DownloadInstance <{}>'.format(self.name) |
||||||
|
|
||||||
|
def get(self): |
||||||
|
print(" Working on " + self.name) |
||||||
|
print(" Getting file " + self.filename) |
||||||
|
if self.sha is None: |
||||||
|
print(' No expected hashsum provided, loading file') |
||||||
|
else: |
||||||
|
filepath = os.path.join(self.save_dir, self.sha, self.filename) |
||||||
|
if checkHashsum(self.sha, filepath): |
||||||
|
print(' hash match - file already exists, skipping') |
||||||
|
return filepath |
||||||
|
else: |
||||||
|
print(' hash didn\'t match, loading file') |
||||||
|
|
||||||
|
if not os.path.exists(self.save_dir): |
||||||
|
print(' creating directory: ' + self.save_dir) |
||||||
|
os.makedirs(self.save_dir) |
||||||
|
|
||||||
|
|
||||||
|
print(' hash check failed - loading') |
||||||
|
assert self.loader |
||||||
|
try: |
||||||
|
self.loader.load(self.filename, self.sha, self.save_dir) |
||||||
|
print(' done') |
||||||
|
print(' file {}'.format(self.filename)) |
||||||
|
if self.sha is None: |
||||||
|
download_path = os.path.join(self.save_dir, self.filename) |
||||||
|
self.sha = getHashsumFromFile(download_path) |
||||||
|
new_dir = os.path.join(self.save_dir, self.sha) |
||||||
|
|
||||||
|
if not os.path.exists(new_dir): |
||||||
|
os.makedirs(new_dir) |
||||||
|
filepath = os.path.join(new_dir, self.filename) |
||||||
|
if not (os.path.exists(filepath)): |
||||||
|
shutil.move(download_path, new_dir) |
||||||
|
print(' No expected hashsum provided, actual SHA is {}'.format(self.sha)) |
||||||
|
else: |
||||||
|
checkHashsum(self.sha, filepath, silent=False) |
||||||
|
except Exception as e: |
||||||
|
print(" There was some problem with loading file {} for {}".format(self.filename, self.name)) |
||||||
|
print(" Exception: {}".format(e)) |
||||||
|
return |
||||||
|
|
||||||
|
print(" Finished " + self.name) |
||||||
|
return filepath |
||||||
|
|
||||||
|
class Loader(object): |
||||||
|
MB = 1024*1024 |
||||||
|
BUFSIZE = 10*MB |
||||||
|
def __init__(self, download_name, download_sha, archive_member = None): |
||||||
|
self.download_name = download_name |
||||||
|
self.download_sha = download_sha |
||||||
|
self.archive_member = archive_member |
||||||
|
|
||||||
|
def load(self, requested_file, sha, save_dir): |
||||||
|
if self.download_sha is None: |
||||||
|
download_dir = save_dir |
||||||
|
else: |
||||||
|
# create a new folder in save_dir to avoid possible name conflicts |
||||||
|
download_dir = os.path.join(save_dir, self.download_sha) |
||||||
|
if not os.path.exists(download_dir): |
||||||
|
os.makedirs(download_dir) |
||||||
|
download_path = os.path.join(download_dir, self.download_name) |
||||||
|
print(" Preparing to download file " + self.download_name) |
||||||
|
if checkHashsum(self.download_sha, download_path): |
||||||
|
print(' hash match - file already exists, no need to download') |
||||||
|
else: |
||||||
|
filesize = self.download(download_path) |
||||||
|
print(' Downloaded {} with size {} Mb'.format(self.download_name, filesize/self.MB)) |
||||||
|
if self.download_sha is not None: |
||||||
|
checkHashsum(self.download_sha, download_path, silent=False) |
||||||
|
if self.download_name == requested_file: |
||||||
|
return |
||||||
|
else: |
||||||
|
if isArchive(download_path): |
||||||
|
if sha is not None: |
||||||
|
extract_dir = os.path.join(save_dir, sha) |
||||||
|
else: |
||||||
|
extract_dir = save_dir |
||||||
|
if not os.path.exists(extract_dir): |
||||||
|
os.makedirs(extract_dir) |
||||||
|
self.extract(requested_file, download_path, extract_dir) |
||||||
|
else: |
||||||
|
raise Exception("Downloaded file has different name") |
||||||
|
|
||||||
|
def download(self, filepath): |
||||||
|
print("Warning: download is not implemented, this is a base class") |
||||||
|
return 0 |
||||||
|
|
||||||
|
def extract(self, requested_file, archive_path, save_dir): |
||||||
|
filepath = os.path.join(save_dir, requested_file) |
||||||
|
try: |
||||||
|
with tarfile.open(archive_path) as f: |
||||||
|
if self.archive_member is None: |
||||||
|
pathDict = dict((os.path.split(elem)[1], os.path.split(elem)[0]) for elem in f.getnames()) |
||||||
|
self.archive_member = pathDict[requested_file] |
||||||
|
assert self.archive_member in f.getnames() |
||||||
|
self.save(filepath, f.extractfile(self.archive_member)) |
||||||
|
except Exception as e: |
||||||
|
print(' catch {}'.format(e)) |
||||||
|
|
||||||
|
def save(self, filepath, r): |
||||||
|
with open(filepath, 'wb') as f: |
||||||
|
print(' progress ', end="") |
||||||
|
sys.stdout.flush() |
||||||
|
while True: |
||||||
|
buf = r.read(self.BUFSIZE) |
||||||
|
if not buf: |
||||||
|
break |
||||||
|
f.write(buf) |
||||||
|
print('>', end="") |
||||||
|
sys.stdout.flush() |
||||||
|
|
||||||
|
class URLLoader(Loader): |
||||||
|
def __init__(self, download_name, download_sha, url, archive_member = None): |
||||||
|
super(URLLoader, self).__init__(download_name, download_sha, archive_member) |
||||||
|
self.download_name = download_name |
||||||
|
self.download_sha = download_sha |
||||||
|
self.url = url |
||||||
|
|
||||||
|
def download(self, filepath): |
||||||
|
r = urlopen(self.url, timeout=60) |
||||||
|
self.printRequest(r) |
||||||
|
self.save(filepath, r) |
||||||
|
return os.path.getsize(filepath) |
||||||
|
|
||||||
|
def printRequest(self, r): |
||||||
|
def getMB(r): |
||||||
|
d = dict(r.info()) |
||||||
|
for c in ['content-length', 'Content-Length']: |
||||||
|
if c in d: |
||||||
|
return int(d[c]) / self.MB |
||||||
|
return '<unknown>' |
||||||
|
print(' {} {} [{} Mb]'.format(r.getcode(), r.msg, getMB(r))) |
||||||
|
|
||||||
|
class GDriveLoader(Loader): |
||||||
|
BUFSIZE = 1024 * 1024 |
||||||
|
PROGRESS_SIZE = 10 * 1024 * 1024 |
||||||
|
def __init__(self, download_name, download_sha, gid, archive_member = None): |
||||||
|
super(GDriveLoader, self).__init__(download_name, download_sha, archive_member) |
||||||
|
self.download_name = download_name |
||||||
|
self.download_sha = download_sha |
||||||
|
self.gid = gid |
||||||
|
|
||||||
|
def download(self, filepath): |
||||||
|
session = requests.Session() # re-use cookies |
||||||
|
|
||||||
|
URL = "https://docs.google.com/uc?export=download" |
||||||
|
response = session.get(URL, params = { 'id' : self.gid }, stream = True) |
||||||
|
|
||||||
|
def get_confirm_token(response): # in case of large files |
||||||
|
for key, value in response.cookies.items(): |
||||||
|
if key.startswith('download_warning'): |
||||||
|
return value |
||||||
|
return None |
||||||
|
token = get_confirm_token(response) |
||||||
|
|
||||||
|
if token: |
||||||
|
params = { 'id' : self.gid, 'confirm' : token } |
||||||
|
response = session.get(URL, params = params, stream = True) |
||||||
|
|
||||||
|
sz = 0 |
||||||
|
progress_sz = self.PROGRESS_SIZE |
||||||
|
with open(filepath, "wb") as f: |
||||||
|
for chunk in response.iter_content(self.BUFSIZE): |
||||||
|
if not chunk: |
||||||
|
continue # keep-alive |
||||||
|
|
||||||
|
f.write(chunk) |
||||||
|
sz += len(chunk) |
||||||
|
if sz >= progress_sz: |
||||||
|
progress_sz += self.PROGRESS_SIZE |
||||||
|
print('>', end='') |
||||||
|
sys.stdout.flush() |
||||||
|
print('') |
||||||
|
return sz |
||||||
|
|
||||||
|
def produceDownloadInstance(instance_name, filename, sha, url, save_dir, download_name=None, download_sha=None, archive_member=None): |
||||||
|
spec_param = url |
||||||
|
loader = URLLoader |
||||||
|
if download_name is None: |
||||||
|
download_name = filename |
||||||
|
if download_sha is None: |
||||||
|
download_sha = sha |
||||||
|
if "drive.google.com" in url: |
||||||
|
token = "" |
||||||
|
token_part = url.rsplit('/', 1)[-1] |
||||||
|
if "&id=" not in token_part: |
||||||
|
token_part = url.rsplit('/', 1)[-2] |
||||||
|
for param in token_part.split("&"): |
||||||
|
if param.startswith("id="): |
||||||
|
token = param[3:] |
||||||
|
if token: |
||||||
|
loader = GDriveLoader |
||||||
|
spec_param = token |
||||||
|
else: |
||||||
|
print("Warning: possibly wrong Google Drive link") |
||||||
|
return DownloadInstance( |
||||||
|
name=instance_name, |
||||||
|
filename=filename, |
||||||
|
sha=sha, |
||||||
|
save_dir=save_dir, |
||||||
|
loader=loader(download_name, download_sha, spec_param, archive_member) |
||||||
|
) |
||||||
|
|
||||||
|
def getSaveDir(): |
||||||
|
env_path = os.environ.get("OPENCV_DOWNLOAD_DATA_PATH", None) |
||||||
|
if env_path: |
||||||
|
save_dir = env_path |
||||||
|
else: |
||||||
|
# TODO reuse binding function cv2.utils.fs.getCacheDirectory when issue #19011 is fixed |
||||||
|
if platform.system() == "Darwin": |
||||||
|
#On Apple devices |
||||||
|
temp_env = os.environ.get("TMPDIR", None) |
||||||
|
if temp_env is None or not os.path.isdir(temp_env): |
||||||
|
temp_dir = Path("/tmp") |
||||||
|
print("Using world accessible cache directory. This may be not secure: ", temp_dir) |
||||||
|
else: |
||||||
|
temp_dir = temp_env |
||||||
|
elif platform.system() == "Windows": |
||||||
|
temp_dir = tempfile.gettempdir() |
||||||
|
else: |
||||||
|
xdg_cache_env = os.environ.get("XDG_CACHE_HOME", None) |
||||||
|
if (xdg_cache_env and xdg_cache_env[0] and os.path.isdir(xdg_cache_env)): |
||||||
|
temp_dir = xdg_cache_env |
||||||
|
else: |
||||||
|
home_env = os.environ.get("HOME", None) |
||||||
|
if (home_env and home_env[0] and os.path.isdir(home_env)): |
||||||
|
home_path = os.path.join(home_env, ".cache/") |
||||||
|
if os.path.isdir(home_path): |
||||||
|
temp_dir = home_path |
||||||
|
else: |
||||||
|
temp_dir = tempfile.gettempdir() |
||||||
|
print("Using world accessible cache directory. This may be not secure: ", temp_dir) |
||||||
|
|
||||||
|
save_dir = os.path.join(temp_dir, "downloads") |
||||||
|
if not os.path.exists(save_dir): |
||||||
|
os.makedirs(save_dir) |
||||||
|
return save_dir |
||||||
|
|
||||||
|
def downloadFile(url, sha=None, save_dir=None, filename=None): |
||||||
|
if save_dir is None: |
||||||
|
save_dir = getSaveDir() |
||||||
|
if filename is None: |
||||||
|
filename = "download_" + datetime.now().__str__() |
||||||
|
name = filename |
||||||
|
return produceDownloadInstance(name, filename, sha, url, save_dir).get() |
||||||
|
|
||||||
|
def parseMetalinkFile(metalink_filepath, save_dir): |
||||||
|
NS = {'ml': 'urn:ietf:params:xml:ns:metalink'} |
||||||
|
models = [] |
||||||
|
for file_elem in ET.parse(metalink_filepath).getroot().findall('ml:file', NS): |
||||||
|
url = file_elem.find('ml:url', NS).text |
||||||
|
fname = file_elem.attrib['name'] |
||||||
|
name = file_elem.find('ml:identity', NS).text |
||||||
|
hash_sum = file_elem.find('ml:hash', NS).text |
||||||
|
models.append(produceDownloadInstance(name, fname, hash_sum, url, save_dir)) |
||||||
|
return models |
||||||
|
|
||||||
|
def parseYAMLFile(yaml_filepath, save_dir): |
||||||
|
models = [] |
||||||
|
with open(yaml_filepath, 'r') as stream: |
||||||
|
data_loaded = yaml.safe_load(stream) |
||||||
|
for name, params in data_loaded.items(): |
||||||
|
load_info = params.get("load_info", None) |
||||||
|
if load_info: |
||||||
|
fname = os.path.basename(params.get("model")) |
||||||
|
hash_sum = load_info.get("sha1") |
||||||
|
url = load_info.get("url") |
||||||
|
download_sha = load_info.get("download_sha") |
||||||
|
download_name = load_info.get("download_name") |
||||||
|
archive_member = load_info.get("member") |
||||||
|
models.append(produceDownloadInstance(name, fname, hash_sum, url, save_dir, |
||||||
|
download_name=download_name, download_sha=download_sha, archive_member=archive_member)) |
||||||
|
|
||||||
|
return models |
||||||
|
|
||||||
|
if __name__ == '__main__': |
||||||
|
parser = argparse.ArgumentParser(description='This is a utility script for downloading DNN models for samples.') |
||||||
|
|
||||||
|
parser.add_argument('--save_dir', action="store", default=os.getcwd(), |
||||||
|
help='Path to the directory to store downloaded files') |
||||||
|
parser.add_argument('model_name', type=str, default="", nargs='?', action="store", |
||||||
|
help='name of the model to download') |
||||||
|
args = parser.parse_args() |
||||||
|
models = [] |
||||||
|
save_dir = args.save_dir |
||||||
|
selected_model_name = args.model_name |
||||||
|
models.extend(parseMetalinkFile('face_detector/weights.meta4', save_dir)) |
||||||
|
models.extend(parseYAMLFile('models.yml', save_dir)) |
||||||
|
for m in models: |
||||||
|
print(m) |
||||||
|
if selected_model_name and not m.name.startswith(selected_model_name): |
||||||
|
continue |
||||||
|
print('Model: ' + selected_model_name) |
||||||
|
m.get() |
@ -1,74 +0,0 @@ |
|||||||
#!/usr/bin/env python |
|
||||||
|
|
||||||
from __future__ import print_function |
|
||||||
import hashlib |
|
||||||
import time |
|
||||||
import sys |
|
||||||
import xml.etree.ElementTree as ET |
|
||||||
if sys.version_info[0] < 3: |
|
||||||
from urllib2 import urlopen |
|
||||||
else: |
|
||||||
from urllib.request import urlopen |
|
||||||
|
|
||||||
class HashMismatchException(Exception): |
|
||||||
def __init__(self, expected, actual): |
|
||||||
Exception.__init__(self) |
|
||||||
self.expected = expected |
|
||||||
self.actual = actual |
|
||||||
def __str__(self): |
|
||||||
return 'Hash mismatch: {} vs {}'.format(self.expected, self.actual) |
|
||||||
|
|
||||||
class MetalinkDownloader(object): |
|
||||||
BUFSIZE = 10*1024*1024 |
|
||||||
NS = {'ml': 'urn:ietf:params:xml:ns:metalink'} |
|
||||||
tick = 0 |
|
||||||
|
|
||||||
def download(self, metalink_file): |
|
||||||
status = True |
|
||||||
for file_elem in ET.parse(metalink_file).getroot().findall('ml:file', self.NS): |
|
||||||
url = file_elem.find('ml:url', self.NS).text |
|
||||||
fname = file_elem.attrib['name'] |
|
||||||
hash_sum = file_elem.find('ml:hash', self.NS).text |
|
||||||
print('*** {}'.format(fname)) |
|
||||||
try: |
|
||||||
self.verify(hash_sum, fname) |
|
||||||
except Exception as ex: |
|
||||||
print(' {}'.format(ex)) |
|
||||||
try: |
|
||||||
print(' {}'.format(url)) |
|
||||||
with open(fname, 'wb') as file_stream: |
|
||||||
self.buffered_read(urlopen(url), file_stream.write) |
|
||||||
self.verify(hash_sum, fname) |
|
||||||
except Exception as ex: |
|
||||||
print(' {}'.format(ex)) |
|
||||||
print(' FAILURE') |
|
||||||
status = False |
|
||||||
continue |
|
||||||
print(' SUCCESS') |
|
||||||
return status |
|
||||||
|
|
||||||
def print_progress(self, msg, timeout = 0): |
|
||||||
if time.time() - self.tick > timeout: |
|
||||||
print(msg, end='') |
|
||||||
sys.stdout.flush() |
|
||||||
self.tick = time.time() |
|
||||||
|
|
||||||
def buffered_read(self, in_stream, processing): |
|
||||||
self.print_progress(' >') |
|
||||||
while True: |
|
||||||
buf = in_stream.read(self.BUFSIZE) |
|
||||||
if not buf: |
|
||||||
break |
|
||||||
processing(buf) |
|
||||||
self.print_progress('>', 5) |
|
||||||
print(' done') |
|
||||||
|
|
||||||
def verify(self, hash_sum, fname): |
|
||||||
sha = hashlib.sha1() |
|
||||||
with open(fname, 'rb') as file_stream: |
|
||||||
self.buffered_read(file_stream, sha.update) |
|
||||||
if hash_sum != sha.hexdigest(): |
|
||||||
raise HashMismatchException(hash_sum, sha.hexdigest()) |
|
||||||
|
|
||||||
if __name__ == '__main__': |
|
||||||
sys.exit(0 if MetalinkDownloader().download('weights.meta4') else 1) |
|
Loading…
Reference in new issue