
Commit bd58005
drop torch<2.0
ordabayevy committed Oct 4, 2023
1 parent f84044d commit bd58005
Showing 3 changed files with 3 additions and 17 deletions.
10 changes: 0 additions & 10 deletions pyro/ops/provenance.py
@@ -93,11 +93,6 @@ def _track_provenance_set(x, provenance: frozenset):
 @track_provenance.register(tuple)
 @track_provenance.register(dict)
 def _track_provenance_pytree(x, provenance: frozenset):
-    # avoid max-recursion depth error for torch<=2.0
-    flat_args, _ = tree_flatten(x)
-    if not flat_args or flat_args[0] is x:
-        return x
-
     return tree_map(partial(track_provenance, provenance=provenance), x)


@@ -143,11 +138,6 @@ def _extract_provenance_set(x):
 @extract_provenance.register(tuple)
 @extract_provenance.register(dict)
 def _extract_provenance_pytree(x):
-    # avoid max-recursion depth error for torch<=2.0
-    flat_args, _ = tree_flatten(x)
-    if not flat_args or flat_args[0] is x:
-        return x, frozenset()
-
     flat_args, spec = tree_flatten(x)
     xs = []
     provenance = frozenset()
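For context, the deleted guard worked around older torch pytree behavior; with torch>=2.0 the tuple/dict handlers can rely on tree_flatten and tree_map recursing into nested containers directly. A minimal sketch of that behavior (not part of this commit; the example data is made up):

    import torch
    from torch.utils._pytree import tree_flatten, tree_map

    # A nested pytree of tensors, analogous to what the registered
    # tuple/dict handlers receive.
    x = {"a": torch.zeros(2), "b": (torch.ones(3), torch.ones(1))}

    # tree_flatten returns the tensor leaves plus the container structure;
    # tree_map rebuilds the same structure with transformed leaves.
    leaves, spec = tree_flatten(x)
    doubled = tree_map(lambda t: t * 2, x)

    assert len(leaves) == 3
    assert float(doubled["b"][0].sum()) == 6.0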
6 changes: 1 addition & 5 deletions pyro/optim/pytorch_optimizers.py
@@ -34,11 +34,7 @@
 del _PyroOptim

 # Load all schedulers from PyTorch
-# breaking change in torch >= 1.14: LRScheduler is new base class
-if hasattr(torch.optim.lr_scheduler, "LRScheduler"):
-    _torch_scheduler_base = torch.optim.lr_scheduler.LRScheduler  # type: ignore
-else:  # for torch < 1.13, _LRScheduler is base class
-    _torch_scheduler_base = torch.optim.lr_scheduler._LRScheduler  # type: ignore
+_torch_scheduler_base = torch.optim.lr_scheduler.LRScheduler  # type: ignore

 for _name, _Optim in torch.optim.lr_scheduler.__dict__.items():
     if not isinstance(_Optim, type):
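With torch>=2.0 the public LRScheduler base class always exists, so the hasattr fallback to the private _LRScheduler is no longer needed. A quick sketch (assumed, not part of this commit) of the subclass check the loop above applies to each entry of torch.optim.lr_scheduler.__dict__:

    import torch

    base = torch.optim.lr_scheduler.LRScheduler

    # Standard schedulers derive from the public base class in torch>=2.0,
    # which is what the loop above tests before wrapping a scheduler.
    assert issubclass(torch.optim.lr_scheduler.StepLR, base)
    assert issubclass(torch.optim.lr_scheduler.ExponentialLR, base)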
4 changes: 2 additions & 2 deletions setup.py
@@ -71,7 +71,7 @@
     "torchvision>=0.12.0",
     "visdom>=0.1.4,<0.2.2",  # FIXME visdom.utils is unavailable >=0.2.2
     "pandas",
-    "pillow==8.2.0",  # https://github.com/pytorch/pytorch/issues/61125
+    "pillow>=8.3.1",  # https://github.com/pytorch/pytorch/issues/61125
     "scikit-learn",
     "seaborn>=0.11.0",
     "wget",
@@ -102,7 +102,7 @@
         "numpy>=1.7",
         "opt_einsum>=2.3.2",
         "pyro-api>=0.1.1",
-        "torch>=1.11.0",
+        "torch>=2.0",
         "tqdm>=4.36",
     ],
     extras_require={
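The dependency floors move in lockstep with the code changes above. A small sketch (assumed, not part of this commit) for verifying that an environment already satisfies the new torch floor:

    import torch
    from packaging.version import Version

    # setup.py now requires torch>=2.0, so older installs must be upgraded
    # (e.g. pip install --upgrade "torch>=2.0") before this pyro version.
    assert Version(torch.__version__) >= Version("2.0")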
