Skip to content

Commit

Permalink
overload
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy committed Nov 12, 2023
1 parent d8bbfcd commit 76415f0
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 18 deletions.
1 change: 1 addition & 0 deletions pyro/poutine/block_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def _make_default_hide_fn(


def _negate_fn(fn: Callable[[Message], Optional[bool]]) -> Callable[[Message], bool]:
# typed version of lambda msg: not fn(msg)
def negated_fn(msg: Message) -> bool:
return not fn(msg)

Expand Down
26 changes: 13 additions & 13 deletions pyro/poutine/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
Tuple,
TypeVar,
Union,
overload,
)

# overload,
import torch
from typing_extensions import ParamSpec, TypedDict

Expand Down Expand Up @@ -274,18 +274,18 @@ def am_i_wrapped() -> bool:
return len(_PYRO_STACK) > 0


# @overload
# def effectful(
# fn: None = ..., type: Optional[str] = ...
# ) -> Callable[[Optional[Callable[P, T]]], Callable[P, Union[T, torch.Tensor, None]]]:
# ...
#
#
# @overload
# def effectful(
# fn: Callable[P, T] = ..., type: Optional[str] = ...
# ) -> Callable[P, Union[T, torch.Tensor, None]]:
# ...
@overload
def effectful(
fn: None = ..., type: Optional[str] = ...
) -> Callable[[Callable[P, T]], Callable[..., Union[T, torch.Tensor, None]]]:
...


@overload
def effectful(
fn: Callable[P, T] = ..., type: Optional[str] = ...
) -> Callable[..., Union[T, torch.Tensor, None]]:
...


def effectful(
Expand Down
10 changes: 5 additions & 5 deletions pyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from typing import Callable, Iterator, Optional, Sequence, Union

import torch
from torch.distributions import constraints

import pyro.distributions as dist
import pyro.infer as infer
import pyro.poutine as poutine
from pyro.distributions import constraints
from pyro.params import param_with_module_name
from pyro.params.param_store import ParamStoreDict
from pyro.poutine.plate_messenger import PlateMessenger
Expand Down Expand Up @@ -48,9 +48,7 @@ def clear_param_store() -> None:
_PYRO_PARAM_STORE.clear()


_param: Callable[..., torch.Tensor] = effectful(
_PYRO_PARAM_STORE.get_param, type="param"
)
_param = effectful(_PYRO_PARAM_STORE.get_param, type="param")


def param(
Expand Down Expand Up @@ -84,9 +82,11 @@ def param(
:rtype: torch.Tensor
"""
# Note effectful(-) requires the double passing of name below.
return _param(
value = _param(
name, init_tensor, constraint=constraint, event_dim=event_dim, name=name
)
assert isinstance(value, torch.Tensor)
return value


def _masked_observe(
Expand Down

0 comments on commit 76415f0

Please sign in to comment.