diff --git a/examples/mnist-pytorch/client/data.py b/examples/mnist-pytorch/client/data.py index 627f91b64..69aae55a7 100644 --- a/examples/mnist-pytorch/client/data.py +++ b/examples/mnist-pytorch/client/data.py @@ -12,16 +12,16 @@ def get_data(out_dir="data"): # Generate random int between 1 and 10 for split id, set seed for reproducibility split_id = np.random.randint(1, 11) - if not os.path.exists(f"{out_dir}/clients/{split_id}"): + if not os.path.exists(f"{abs_path}/{out_dir}/clients/{split_id}"): # create directory for data - os.makedirs(f"{out_dir}data/clients/{split_id}") + os.makedirs(f"{abs_path}/{out_dir}data/clients/{split_id}") # use requests to download the data from url url = f"https://storage.googleapis.com/public-scaleout/mnist-pytorch/data/clients/{split_id}/mnist.pt" # download into out_dir r = requests.get(url) if r.status_code == 200: - with open(f"{out_dir}/clients/{split_id}/mnist.pt", "wb") as f: + with open(f"{abs_path}/{out_dir}/clients/{split_id}/mnist.pt", "wb") as f: f.write(r.content) print(f"Downloaded data from {url}") else: