diff --git a/torch/hub.py b/torch/hub.py index 543519f7ed084..2704060563ddd 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -378,7 +378,12 @@ def _download_url_to_file(url, dst, hash_prefix, progress): if content_length is not None and len(content_length) > 0: file_size = int(content_length[0]) - f = tempfile.NamedTemporaryFile(delete=False) + # We deliberately save it in a temp file and move it after + # download is complete. This prevents a local working checkpoint + # being overriden by a broken download. + dst_dir = os.path.dirname(dst) + f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) + try: if hash_prefix is not None: sha256 = hashlib.sha256()