diff --git a/pyro/params/param_store.py b/pyro/params/param_store.py index 4e2ba9e74d..c191227f67 100644 --- a/pyro/params/param_store.py +++ b/pyro/params/param_store.py @@ -19,12 +19,8 @@ import torch from torch.distributions import constraints, transform_to from torch.serialization import MAP_LOCATION -from typing_extensions import TypedDict - -class StateDict(TypedDict): - params: Dict[str, torch.Tensor] - constraints: Dict[str, constraints.Constraint] +from pyro.types import StateDict class ParamStoreDict: diff --git a/pyro/types.py b/pyro/types.py index 9010178efb..3a62caf8e3 100644 --- a/pyro/types.py +++ b/pyro/types.py @@ -6,6 +6,7 @@ from typing import Callable, Dict, Optional, Tuple, Union import torch +from torch.distributions import constraints from typing_extensions import TypedDict from pyro.poutine.indep_messenger import CondIndepStackFrame @@ -27,3 +28,8 @@ class Message(TypedDict, total=False): continuation: Optional[Callable[[Message], None]] infer: Optional[Dict[str, Union[str, bool]]] obs: Optional[torch.Tensor] + + +class StateDict(TypedDict): + params: Dict[str, torch.Tensor] + constraints: Dict[str, constraints.Constraint]