From a5f52c508bfdb14a5d96d57b9bb2ee906b85e730 Mon Sep 17 00:00:00 2001 From: Nikhil Shenoy Date: Wed, 27 Sep 2023 17:23:54 +0000 Subject: [PATCH] Modified name and subset npz assignment to fix mp --- src/openqdc/datasets/base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/openqdc/datasets/base.py b/src/openqdc/datasets/base.py index 1c1a2c6..392144d 100644 --- a/src/openqdc/datasets/base.py +++ b/src/openqdc/datasets/base.py @@ -159,10 +159,12 @@ def read_preprocess(self): for key in ["name", "subset"]: filename = p_join(self.preprocess_path, f"{key}.npz") pull_locally(filename) - # with open(filename, "rb") as f: - self.data[key] = np.load(open(filename, "rb")) - for k in self.data[key]: - print(f"Loaded {key}_{k} with shape {self.data[key][k].shape}, dtype {self.data[key][k].dtype}") + self.data[key] = dict() + with open(filename, "rb") as f: + tmp = np.load(f) + for k in tmp: + self.data[key][k] = tmp[k] + print(f"Loaded {key}_{k} with shape {self.data[key][k].shape}, dtype {self.data[key][k].dtype}") def is_preprocessed(self): predicats = [copy_exists(p_join(self.preprocess_path, f"{key}.mmap")) for key in self.data_keys]