From a93ac1a6a8b967a3ab79acda16a327aa62f4a560 Mon Sep 17 00:00:00 2001 From: liugang Date: Sat, 23 Sep 2017 16:01:11 +0800 Subject: [PATCH] Improve download package When the wrap mechanism is used in an enterprise environment, some packages are very large (for example, an SDK package from a BSP vendor), so instead of buffering the whole download in memory: - open a file in the output directory with a temporary name - download a chunk, update the hash calculation, write the chunk to the file - when finished, close the file and check the hash - if the hash is incorrect, delete the temp file and raise an error - if the hash is correct, atomically rename the temp file to the final file Fixes issue: #2358 --- mesonbuild/wrap/wrap.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/mesonbuild/wrap/wrap.py b/mesonbuild/wrap/wrap.py index eb6d228a3..14529abf3 100644 --- a/mesonbuild/wrap/wrap.py +++ b/mesonbuild/wrap/wrap.py @@ -230,6 +230,8 @@ class Resolver: def get_data(self, url): blocksize = 10 * 1024 + h = hashlib.sha256() + tmpfile = tempfile.NamedTemporaryFile(mode='wb', dir=self.cachedir, delete=False) if url.startswith('https://wrapdb.mesonbuild.com'): resp = open_wrapdburl(url) else: @@ -241,26 +243,34 @@ class Resolver: dlsize = None if dlsize is None: print('Downloading file of unknown size.') - return resp.read() + while True: + block = resp.read(blocksize) + if block == b'': + break + h.update(block) + tmpfile.write(block) + hashvalue = h.hexdigest() + return hashvalue, tmpfile.name print('Download size:', dlsize) print('Downloading: ', end='') sys.stdout.flush() printed_dots = 0 - blocks = [] downloaded = 0 while True: block = resp.read(blocksize) if block == b'': break downloaded += len(block) - blocks.append(block) + h.update(block) + tmpfile.write(block) ratio = int(downloaded / dlsize * 10) while printed_dots < ratio: print('.', end='') sys.stdout.flush() printed_dots += 1 print('') - return b''.join(blocks) + hashvalue = h.hexdigest() + return hashvalue, tmpfile.name def get_hash(self, data): h = hashlib.sha256() @@ 
-275,24 +285,22 @@ class Resolver: else: srcurl = p.get('source_url') mlog.log('Downloading', mlog.bold(packagename), 'from', mlog.bold(srcurl)) - srcdata = self.get_data(srcurl) - dhash = self.get_hash(srcdata) + dhash, tmpfile = self.get_data(srcurl) expected = p.get('source_hash') if dhash != expected: + os.remove(tmpfile) raise RuntimeError('Incorrect hash for source %s:\n %s expected\n %s actual.' % (packagename, expected, dhash)) - with open(ofname, 'wb') as f: - f.write(srcdata) + os.rename(tmpfile, ofname) if p.has_patch(): purl = p.get('patch_url') mlog.log('Downloading patch from', mlog.bold(purl)) - pdata = self.get_data(purl) - phash = self.get_hash(pdata) + phash, tmpfile = self.get_data(purl) expected = p.get('patch_hash') if phash != expected: + os.remove(tmpfile) raise RuntimeError('Incorrect hash for patch %s:\n %s expected\n %s actual' % (packagename, expected, phash)) filename = os.path.join(self.cachedir, p.get('patch_filename')) - with open(filename, 'wb') as f: - f.write(pdata) + os.rename(tmpfile, filename) else: mlog.log('Package does not require patch.')