Skip to content

Commit

Permalink
types
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy committed Oct 30, 2023
1 parent 4284574 commit 13690ce
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
6 changes: 1 addition & 5 deletions pyro/params/param_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@
import torch
from torch.distributions import constraints, transform_to
from torch.serialization import MAP_LOCATION
from typing_extensions import TypedDict


class StateDict(TypedDict):
params: Dict[str, torch.Tensor]
constraints: Dict[str, constraints.Constraint]
from pyro.types import StateDict


class ParamStoreDict:
Expand Down
6 changes: 6 additions & 0 deletions pyro/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Callable, Dict, Optional, Tuple, Union

import torch
from torch.distributions import constraints
from typing_extensions import TypedDict

from pyro.poutine.indep_messenger import CondIndepStackFrame
Expand All @@ -27,3 +28,8 @@ class Message(TypedDict, total=False):
continuation: Optional[Callable[[Message], None]]
infer: Optional[Dict[str, Union[str, bool]]]
obs: Optional[torch.Tensor]


class StateDict(TypedDict):
params: Dict[str, torch.Tensor]
constraints: Dict[str, constraints.Constraint]

0 comments on commit 13690ce

Please sign in to comment.