From 13690ceb6882507ff87a4d20a0f5d0489cc2a8cd Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Mon, 30 Oct 2023 02:06:01 +0000 Subject: [PATCH] types --- pyro/params/param_store.py | 6 +----- pyro/types.py | 6 ++++++ 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pyro/params/param_store.py b/pyro/params/param_store.py index 4e2ba9e74d..c191227f67 100644 --- a/pyro/params/param_store.py +++ b/pyro/params/param_store.py @@ -19,12 +19,8 @@ import torch from torch.distributions import constraints, transform_to from torch.serialization import MAP_LOCATION -from typing_extensions import TypedDict - -class StateDict(TypedDict): - params: Dict[str, torch.Tensor] - constraints: Dict[str, constraints.Constraint] +from pyro.types import StateDict class ParamStoreDict: diff --git a/pyro/types.py b/pyro/types.py index 9010178efb..3a62caf8e3 100644 --- a/pyro/types.py +++ b/pyro/types.py @@ -6,6 +6,7 @@ from typing import Callable, Dict, Optional, Tuple, Union import torch +from torch.distributions import constraints from typing_extensions import TypedDict from pyro.poutine.indep_messenger import CondIndepStackFrame @@ -27,3 +28,8 @@ class Message(TypedDict, total=False): continuation: Optional[Callable[[Message], None]] infer: Optional[Dict[str, Union[str, bool]]] obs: Optional[torch.Tensor] + + +class StateDict(TypedDict): + params: Dict[str, torch.Tensor] + constraints: Dict[str, constraints.Constraint]