Skip to content

Commit

Permalink
Fixes to prep for weights_only default flip (#2514)
Browse files Browse the repository at this point in the history
Summary:
Some fixes for pytorch/pytorch#137602


Reviewed By: xuzhao9

Differential Revision: D64628614

Pulled By: mikaylagawarecki
  • Loading branch information
mikaylagawarecki authored and facebook-github-bot committed Oct 21, 2024
1 parent a21b30e commit 5cd7016
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 2 deletions.
13 changes: 12 additions & 1 deletion torchbenchmark/models/functorch_maml_omniglot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from pathlib import Path
from typing import Tuple

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -73,7 +75,16 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
self.model = net

root = str(Path(__file__).parent.parent)
self.meta_inputs = torch.load(f"{root}/maml_omniglot/batch.pt")
with torch.serialization.safe_globals(
[
np.core.multiarray._reconstruct,
np.ndarray,
np.dtype,
np.dtypes.Float32DType,
np.dtypes.Int64DType,
]
):
self.meta_inputs = torch.load(f"{root}/maml_omniglot/batch.pt")
self.meta_inputs = tuple(
[torch.from_numpy(i).to(self.device) for i in self.meta_inputs]
)
Expand Down
12 changes: 11 additions & 1 deletion torchbenchmark/models/maml_omniglot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Tuple

import higher
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -79,7 +80,16 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
self.model = net

root = str(Path(__file__).parent)
self.meta_inputs = torch.load(f"{root}/batch.pt")
with torch.serialization.safe_globals(
[
np.core.multiarray._reconstruct,
np.ndarray,
np.dtype,
np.dtypes.Float32DType,
np.dtypes.Int64DType,
]
):
self.meta_inputs = torch.load(f"{root}/batch.pt")
self.meta_inputs = tuple(
[torch.from_numpy(i).to(self.device) for i in self.meta_inputs]
)
Expand Down
7 changes: 7 additions & 0 deletions torchbenchmark/models/opacus_cifar10/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Tuple

import torch
Expand Down Expand Up @@ -29,7 +30,13 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
)

self.model = models.resnet18(num_classes=10)
prev_wo_envvar = os.environ.get("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", None)
os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
self.model = ModuleValidator.fix(self.model)
if prev_wo_envvar is None:
del os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"]
else:
os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = prev_wo_envvar
self.model = self.model.to(device)

# Cifar10 images are 32x32 and have 10 classes
Expand Down

0 comments on commit 5cd7016

Please sign in to comment.