[Fix] Fix tar weight loading error when using mp

main
Bobholamovic 2 years ago
parent d4d8ffe2dc
commit b250d8bb00
  1. 8
      .gitignore
  2. 8
      paddlers/utils/download.py

8
.gitignore vendored

@ -132,11 +132,5 @@ dmypy.json
# Pyre type checker
.pyre/
# test data
tutorials/train/change_detection/DataSet/
tutorials/train/classification/DataSet/
optic_disc_seg.tar
optic_disc_seg/
output/
/tutorials/train/**/output/
/log

@ -202,10 +202,13 @@ def download_and_decompress(url, path='.'):
local_rank = paddle.distributed.get_rank()
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
pth_path = fullname + '.path'
if nranks <= 1:
dst_dir = url2dir(url, path)
if dst_dir is not None:
with open(pth_path, 'w') as f:
f.write(dst_dir)
fullname = dst_dir
else:
lock_path = fullname + '.lock'
@ -215,9 +218,14 @@ def download_and_decompress(url, path='.'):
if local_rank == 0:
dst_dir = url2dir(url, path)
if dst_dir is not None:
with open(pth_path, 'w') as f:
f.write(dst_dir)
fullname = dst_dir
os.remove(lock_path)
else:
while os.path.exists(lock_path):
time.sleep(1)
if os.path.exists(pth_path):
with open(pth_path, 'r') as f:
fullname = next(f)
return fullname

Loading…
Cancel
Save