set weights_only parameter of torch.load to False
- #48
aditya0by0 committed Oct 5, 2024
1 parent e17a9c0 commit 7fc96a9
Showing 16 changed files with 68 additions and 29 deletions.
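Context for the change: newer PyTorch releases are migrating the default of torch.load's weights_only argument from False to True. With weights_only=True, loading goes through a restricted unpickler that accepts tensors, primitive containers, and a small allowlist of types, and rejects arbitrary pickled Python objects. The processed .pt files and checkpoints in this repository contain such objects, so every call site now passes weights_only=False explicitly to keep the previous behaviour. A minimal sketch of the difference (the Hyperparams class and file name are hypothetical, for illustration only):

import pickle

import torch


class Hyperparams:  # hypothetical stand-in for non-tensor checkpoint content
    def __init__(self, lr: float):
        self.lr = lr


# A checkpoint that mixes tensors with an arbitrary Python object.
ckpt = {"state_dict": {"w": torch.zeros(2)}, "hparams": Hyperparams(1e-3)}
torch.save(ckpt, "ckpt.pt")

# The restricted unpickler behind weights_only=True rejects the custom class.
try:
    torch.load("ckpt.pt", weights_only=True)
except pickle.UnpicklingError as err:
    print(f"weights_only=True refused the file: {err}")

# weights_only=False restores full pickle semantics, as in this commit.
# Only appropriate for files you produced yourself: unpickling can run code.
restored = torch.load("ckpt.pt", weights_only=False)

Passing weights_only=False is reasonable here because all of these files are generated locally by the package itself.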
8 changes: 6 additions & 2 deletions chebai/models/electra.py
@@ -256,7 +256,9 @@ def __init__(
         # Load pretrained checkpoint if provided
         if pretrained_checkpoint:
             with open(pretrained_checkpoint, "rb") as fin:
-                model_dict = torch.load(fin, map_location=self.device)
+                model_dict = torch.load(
+                    fin, map_location=self.device, weights_only=False
+                )
                 if load_prefix:
                     state_dict = filter_dict(model_dict["state_dict"], load_prefix)
                 else:
@@ -414,7 +416,9 @@ def __init__(self, cone_dimensions=20, **kwargs):
         model_prefix = kwargs.get("load_prefix", None)
         if pretrained_checkpoint:
             with open(pretrained_checkpoint, "rb") as fin:
-                model_dict = torch.load(fin, map_location=self.device)
+                model_dict = torch.load(
+                    fin, map_location=self.device, weights_only=False
+                )
                 if model_prefix:
                     state_dict = {
                         str(k)[len(model_prefix) :]: v
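The load_prefix / model_prefix handling around these hunks strips a namespace prefix from checkpoint keys so that weights saved from a wrapper module can be loaded into the bare model. A rough sketch of the idea (strip_prefix is a hypothetical helper, not the repository's filter_dict):

from typing import Dict

import torch


def strip_prefix(
    state_dict: Dict[str, torch.Tensor], prefix: str
) -> Dict[str, torch.Tensor]:
    # Keep only entries under `prefix` and drop the prefix from their keys.
    return {
        k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)
    }


wrapped = {"electra.embeddings.weight": torch.zeros(1), "head.bias": torch.zeros(1)}
print(strip_prefix(wrapped, "electra."))  # {'embeddings.weight': tensor([0.])}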
14 changes: 10 additions & 4 deletions chebai/preprocessing/datasets/base.py
@@ -200,7 +200,9 @@ def load_processed_data(
             filename = self.processed_file_names_dict[kind]
         except NotImplementedError:
             filename = f"{kind}.pt"
-        return torch.load(os.path.join(self.processed_dir, filename))
+        return torch.load(
+            os.path.join(self.processed_dir, filename), weights_only=False
+        )

     def dataloader(self, kind: str, **kwargs) -> DataLoader:
         """
@@ -519,7 +521,7 @@ def dataloader(self, kind: str, **kwargs) -> DataLoader:
             DataLoader: DataLoader object for the specified subset.
         """
         subdatasets = [
-            torch.load(os.path.join(s.processed_dir, f"{kind}.pt"))
+            torch.load(os.path.join(s.processed_dir, f"{kind}.pt"), weights_only=False)
             for s in self.subsets
         ]
         dataset = [
@@ -1022,7 +1024,9 @@ def _retrieve_splits_from_csv(self) -> None:
         splits_df = pd.read_csv(self.splits_file_path)

         filename = self.processed_file_names_dict["data"]
-        data = torch.load(os.path.join(self.processed_dir, filename))
+        data = torch.load(
+            os.path.join(self.processed_dir, filename), weights_only=False
+        )
         df_data = pd.DataFrame(data)

         train_ids = splits_df[splits_df["split"] == "train"]["id"]
@@ -1081,7 +1085,9 @@ def load_processed_data(

         # If filename is provided
         try:
-            return torch.load(os.path.join(self.processed_dir, filename))
+            return torch.load(
+                os.path.join(self.processed_dir, filename), weights_only=False
+            )
         except FileNotFoundError:
             raise FileNotFoundError(f"File {filename} doesn't exist")
7 changes: 5 additions & 2 deletions chebai/preprocessing/datasets/chebi.py
@@ -407,7 +407,9 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         """
         try:
             filename = self.processed_file_names_dict["data"]
-            data_chebi_version = torch.load(os.path.join(self.processed_dir, filename))
+            data_chebi_version = torch.load(
+                os.path.join(self.processed_dir, filename), weights_only=False
+            )
         except FileNotFoundError:
             raise FileNotFoundError(
                 f"File data.pt doesn't exists. "
@@ -428,7 +430,8 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
             data_chebi_train_version = torch.load(
                 os.path.join(
                     self._chebi_version_train_obj.processed_dir, filename_train
-                )
+                ),
+                weights_only=False,
             )
         except FileNotFoundError:
             raise FileNotFoundError(
4 changes: 3 additions & 1 deletion chebai/preprocessing/datasets/go_uniprot.py
@@ -508,7 +508,9 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         """
         try:
             filename = self.processed_file_names_dict["data"]
-            data_go = torch.load(os.path.join(self.processed_dir, filename))
+            data_go = torch.load(
+                os.path.join(self.processed_dir, filename), weights_only=False
+            )
         except FileNotFoundError:
             raise FileNotFoundError(
                 f"File data.pt doesn't exists. "
4 changes: 2 additions & 2 deletions chebai/preprocessing/datasets/pubchem.py
@@ -891,10 +891,10 @@ def dataloader(self, kind: str, **kwargs) -> DataLoader:
             DataLoader: DataLoader instance.
         """
         labeled_data = torch.load(
-            os.path.join(self.labeled.processed_dir, f"{kind}.pt")
+            os.path.join(self.labeled.processed_dir, f"{kind}.pt"), weights_only=False
         )
         unlabeled_data = torch.load(
-            os.path.join(self.unlabeled.processed_dir, f"{kind}.pt")
+            os.path.join(self.unlabeled.processed_dir, f"{kind}.pt"), weights_only=False
         )
         if self.data_limit is not None:
             labeled_data = labeled_data[: self.data_limit]
2 changes: 1 addition & 1 deletion chebai/preprocessing/migration/chebi_data_migration.py
@@ -168,7 +168,7 @@ def _combine_pt_splits(
         df_list: List[pd.DataFrame] = []
         for split, file_name in old_splits_file_names.items():
             file_path = os.path.join(old_dir, file_name)
-            file_df = pd.DataFrame(torch.load(file_path))
+            file_df = pd.DataFrame(torch.load(file_path, weights_only=False))
             df_list.append(file_df)

         return pd.concat(df_list, ignore_index=True)
4 changes: 3 additions & 1 deletion chebai/result/analyse_sem.py
@@ -427,7 +427,9 @@ def run_all(
             os.path.join(buffer_dir_smoothed, "preds000.pt")
         ):
             preds = torch.load(
-                os.path.join(buffer_dir_smoothed, "preds000.pt"), DEVICE
+                os.path.join(buffer_dir_smoothed, "preds000.pt"),
+                DEVICE,
+                weights_only=False,
             )
             labels = None
         else:
2 changes: 1 addition & 1 deletion chebai/result/base.py
@@ -54,7 +54,7 @@ def _generate_predictions(self, data_path, raw=False, **kwargs):
         else:
             data_tuples = [
                 (x.get("raw_features", x["ident"]), x["ident"], x)
-                for x in torch.load(data_path)
+                for x in torch.load(data_path, weights_only=False)
             ]

         for raw_features, ident, row in tqdm.tqdm(data_tuples):
2 changes: 1 addition & 1 deletion chebai/result/pretraining.py
@@ -34,7 +34,7 @@ def evaluate_model(logs_base_path, model_filename, data_module):
     collate = data_module.reader.COLLATOR()
     test_file = "test.pt"
     data_path = os.path.join(data_module.processed_dir, test_file)
-    data_list = torch.load(data_path)
+    data_list = torch.load(data_path, weights_only=False)
     preds_list = []
     labels_list = []
2 changes: 2 additions & 0 deletions chebai/result/utils.py
@@ -182,6 +182,7 @@ def load_results_from_buffer(
                 torch.load(
                     os.path.join(buffer_dir, filename),
                     map_location=torch.device(device),
+                    weights_only=False,
                 )
             )
             i += 1
@@ -194,6 +195,7 @@ def load_results_from_buffer(
                 torch.load(
                     os.path.join(buffer_dir, filename),
                     map_location=torch.device(device),
+                    weights_only=False,
                 )
             )
             i += 1
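For context, load_results_from_buffer iterates over zero-padded buffer files until one is missing; a condensed sketch of that loop under the same assumptions (illustrative only, not the repository's exact code):

import os
from typing import List, Optional

import torch


def load_buffer_parts(buffer_dir: str, stem: str, device: str) -> Optional[torch.Tensor]:
    # Read stem000.pt, stem001.pt, ... until a file is missing, then
    # concatenate the parts; weights_only=False matches the change above.
    parts: List[torch.Tensor] = []
    i = 0
    while os.path.isfile(path := os.path.join(buffer_dir, f"{stem}{i:03d}.pt")):
        parts.append(
            torch.load(path, map_location=torch.device(device), weights_only=False)
        )
        i += 1
    return torch.cat(parts) if parts else None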
6 changes: 5 additions & 1 deletion tests/testCustomBalancedAccuracyMetric.py
@@ -49,7 +49,9 @@ def test_metric_against_realistic_data(self) -> None:

         # load single file to get the num of labels for metric class instantiation
         labels = torch.load(
-            f"{directory_path}/labels{0:03d}.pt", map_location=torch.device(self.device)
+            f"{directory_path}/labels{0:03d}.pt",
+            map_location=torch.device(self.device),
+            weights_only=False,
         )
         num_labels = labels.shape[1]
         balanced_acc_custom = BalancedAccuracy(num_labels=num_labels)
@@ -58,10 +60,12 @@ def test_metric_against_realistic_data(self) -> None:
             labels = torch.load(
                 f"{directory_path}/labels{i:03d}.pt",
                 map_location=torch.device(self.device),
+                weights_only=False,
             )
             preds = torch.load(
                 f"{directory_path}/preds{i:03d}.pt",
                 map_location=torch.device(self.device),
+                weights_only=False,
             )
             balanced_acc_custom.update(preds, labels)
6 changes: 5 additions & 1 deletion tests/testCustomMacroF1Metric.py
@@ -119,7 +119,9 @@ def test_metric_against_realistic_data(self) -> None:

         # Load single file to get the number of labels for metric class instantiation
         labels = torch.load(
-            f"{directory_path}/labels{0:03d}.pt", map_location=torch.device(self.device)
+            f"{directory_path}/labels{0:03d}.pt",
+            map_location=torch.device(self.device),
+            weights_only=False,
         )
         num_labels = labels.shape[1]
         macro_f1_custom = MacroF1(num_labels=num_labels)
@@ -130,10 +132,12 @@ def test_metric_against_realistic_data(self) -> None:
             labels = torch.load(
                 f"{directory_path}/labels{i:03d}.pt",
                 map_location=torch.device(self.device),
+                weights_only=False,
             )
             preds = torch.load(
                 f"{directory_path}/preds{i:03d}.pt",
                 map_location=torch.device(self.device),
+                weights_only=False,
             )
             macro_f1_standard.update(preds, labels)
             macro_f1_custom.update(preds, labels)
12 changes: 9 additions & 3 deletions tests/testPubChemData.py
@@ -37,9 +37,15 @@ def getDataSplitsOverlaps(cls) -> None:
         processed_path = os.path.join(os.getcwd(), cls.pubChem.processed_dir)
         print(f"Checking Data from - {processed_path}")

-        train_set = torch.load(os.path.join(processed_path, "train.pt"))
-        val_set = torch.load(os.path.join(processed_path, "validation.pt"))
-        test_set = torch.load(os.path.join(processed_path, "test.pt"))
+        train_set = torch.load(
+            os.path.join(processed_path, "train.pt"), weights_only=False
+        )
+        val_set = torch.load(
+            os.path.join(processed_path, "validation.pt"), weights_only=False
+        )
+        test_set = torch.load(
+            os.path.join(processed_path, "test.pt"), weights_only=False
+        )

         train_smiles, train_smiles_ids = cls.get_features_ids(train_set)
         val_smiles, val_smiles_ids = cls.get_features_ids(val_set)
12 changes: 9 additions & 3 deletions tests/testTox21MolNetData.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,15 @@ def getDataSplitsOverlaps(cls) -> None:
processed_path = os.path.join(os.getcwd(), cls.tox21.processed_dir)
print(f"Checking Data from - {processed_path}")

train_set = torch.load(os.path.join(processed_path, "train.pt"))
val_set = torch.load(os.path.join(processed_path, "validation.pt"))
test_set = torch.load(os.path.join(processed_path, "test.pt"))
train_set = torch.load(
os.path.join(processed_path, "train.pt"), weights_only=False
)
val_set = torch.load(
os.path.join(processed_path, "validation.pt"), weights_only=False
)
test_set = torch.load(
os.path.join(processed_path, "test.pt"), weights_only=False
)

train_smiles, train_smiles_ids = cls.get_features_ids(train_set)
val_smiles, val_smiles_ids = cls.get_features_ids(val_set)
Expand Down
10 changes: 5 additions & 5 deletions tutorials/demo_process_results.ipynb
@@ -248,9 +248,9 @@
     "# check if pretraining datasets overlap\n",
     "dm = PubChemDeepSMILES()\n",
     "processed_path = dm.processed_dir\n",
-    "test_set = torch.load(os.path.join(processed_path, \"test.pt\"))\n",
-    "val_set = torch.load(os.path.join(processed_path, \"validation.pt\"))\n",
-    "train_set = torch.load(os.path.join(processed_path, \"train.pt\"))\n",
+    "test_set = torch.load(os.path.join(processed_path, \"test.pt\"), weights_only=False)\n",
+    "val_set = torch.load(os.path.join(processed_path, \"validation.pt\"), weights_only=False)\n",
+    "train_set = torch.load(os.path.join(processed_path, \"train.pt\"), weights_only=False)\n",
     "print(processed_path)\n",
     "test_smiles = [entry[\"features\"] for entry in test_set]\n",
     "val_smiles = [entry[\"features\"] for entry in val_set]\n",
@@ -320,7 +320,7 @@
     "data_module_v200 = ChEBIOver100()\n",
     "data_module_v148 = ChEBIOver100(chebi_version_train=148)\n",
     "data_module_v227 = ChEBIOver100(chebi_version_train=227)\n",
-    "# dataset = torch.load(data_path)\n",
+    "# dataset = torch.load(data_path, weights_only=False)\n",
     "# processors = [CustomResultsProcessor()]\n",
     "# factory = ResultFactory(model, data_module, processors)\n",
     "# factory.execute(data_path)"
@@ -653,7 +653,7 @@
     "    if test_file is None:\n",
     "        test_file = data_module.processed_file_names_dict[\"test\"]\n",
     "    data_path = os.path.join(data_module.processed_dir, test_file)\n",
-    "    data_list = torch.load(data_path)\n",
+    "    data_list = torch.load(data_path, weights_only=False)\n",
     "    preds_list = []\n",
     "    labels_list = []\n",
     "    # if common_classes_mask is not N\n",
2 changes: 1 addition & 1 deletion tutorials/process_results_old_chebi.ipynb
@@ -167,7 +167,7 @@
     "    if test_file is None:\n",
     "        test_file = data_module.processed_file_names_dict[\"test\"]\n",
     "    data_path = os.path.join(data_module.processed_dir, test_file)\n",
-    "    data_list = torch.load(data_path)\n",
+    "    data_list = torch.load(data_path, weights_only=False)\n",
     "    preds_list = []\n",
     "    labels_list = []\n",
     "\n",
