Skip to content

Commit

Permalink
Type annotate pyro.poutine.runtime (#3288)
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy authored Nov 1, 2023
1 parent 23e6470 commit ae725ef
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 34 deletions.
2 changes: 1 addition & 1 deletion pyro/poutine/enum_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

from pyro.distributions import Categorical
from pyro.distributions import Categorical # type: ignore[attr-defined]
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.ops.indexing import Vindex
from pyro.util import ignore_jit_warnings
Expand Down
2 changes: 1 addition & 1 deletion pyro/poutine/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def get_posterior(
"""
raise NotImplementedError

def upstream_value(self, name: str) -> torch.Tensor:
def upstream_value(self, name: str):
"""
For use in :meth:`get_posterior` .
Expand Down
82 changes: 54 additions & 28 deletions pyro/poutine/runtime.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,48 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import functools
from typing import Dict
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from typing_extensions import TypedDict

from pyro.params.param_store import ( # noqa: F401
_MODULE_NAMESPACE_DIVIDER,
ParamStoreDict,
)

if TYPE_CHECKING:
from pyro.poutine.indep_messenger import CondIndepStackFrame
from pyro.poutine.messenger import Messenger

# the global pyro stack
_PYRO_STACK = []
_PYRO_STACK: List[Messenger] = []

# the global ParamStore
_PYRO_PARAM_STORE = ParamStoreDict()


class Message(TypedDict, total=False):
type: Optional[str]
name: str
fn: Callable
is_observed: bool
args: Tuple
kwargs: Dict
value: Optional[torch.Tensor]
scale: float
mask: Union[bool, torch.Tensor, None]
cond_indep_stack: Tuple[CondIndepStackFrame, ...]
done: bool
stop: bool
continuation: Optional[Callable[[Message], None]]
infer: Optional[Dict[str, Union[str, bool]]]
obs: Optional[torch.Tensor]


class _DimAllocator:
"""
Dimension allocator for internal use by :class:`plate`.
Expand All @@ -24,26 +51,25 @@ class _DimAllocator:
Note that dimensions are indexed from the right, e.g. -1, -2.
"""

def __init__(self):
self._stack = [] # in reverse orientation of log_prob.shape
def __init__(self) -> None:
# in reverse orientation of log_prob.shape
self._stack: List[Optional[str]] = []

def allocate(self, name, dim):
def allocate(self, name: str, dim: Optional[int]) -> int:
"""
Allocate a dimension to an :class:`plate` with given name.
Dim should be either None for automatic allocation or a negative
integer for manual allocation.
"""
if name in self._stack:
raise ValueError('duplicate plate "{}"'.format(name))
raise ValueError(f"duplicate plate '{name}'")
if dim is None:
# Automatically designate the rightmost available dim for allocation.
dim = -1
while -dim <= len(self._stack) and self._stack[-1 - dim] is not None:
dim -= 1
elif dim >= 0:
raise ValueError(
"Expected dim < 0 to index from the right, actual {}".format(dim)
)
raise ValueError(f"Expected dim < 0 to index from the right, actual {dim}")

# Allocate the requested dimension.
while dim < -len(self._stack):
Expand All @@ -64,7 +90,7 @@ def allocate(self, name, dim):
self._stack[-1 - dim] = name
return dim

def free(self, name, dim):
def free(self, name: str, dim: int) -> None:
"""
Free a dimension.
"""
Expand All @@ -88,7 +114,7 @@ class _EnumAllocator:
Note that ids are simply nonnegative integers here.
"""

def set_first_available_dim(self, first_available_dim):
def set_first_available_dim(self, first_available_dim: int) -> None:
"""
Set the first available dim, which should be to the left of all
:class:`plate` dimensions, e.g. ``-1 - max_plate_nesting``. This should
Expand All @@ -98,9 +124,9 @@ def set_first_available_dim(self, first_available_dim):
assert first_available_dim < 0, first_available_dim
self.next_available_dim = first_available_dim
self.next_available_id = 0
self.dim_to_id = {} # only the global ids
self.dim_to_id: Dict[int, int] = {} # only the global ids

def allocate(self, scope_dims=None):
def allocate(self, scope_dims: Optional[Set[int]] = None) -> Tuple[int, int]:
"""
Allocate a new recyclable dim and a unique id.
Expand Down Expand Up @@ -146,26 +172,28 @@ class NonlocalExit(Exception):
Used by poutine.EscapeMessenger to return site information.
"""

def __init__(self, site, *args, **kwargs):
def __init__(self, site: Message, *args, **kwargs) -> None:
"""
:param site: message at a pyro site constructor.
Just stores the input site.
"""
super().__init__(*args, **kwargs)
self.site = site

def reset_stack(self):
def reset_stack(self) -> None:
"""
Reset the state of the frames remaining in the stack.
Necessary for multiple re-executions in poutine.queue.
"""
from pyro.poutine.block_messenger import BlockMessenger

for frame in reversed(_PYRO_STACK):
frame._reset()
if type(frame).__name__ == "BlockMessenger" and frame.hide_fn(self.site):
if isinstance(frame, BlockMessenger) and frame.hide_fn(self.site):
break


def default_process_message(msg):
def default_process_message(msg: Message) -> None:
"""
Default method for processing messages in inference.
Expand All @@ -174,15 +202,15 @@ def default_process_message(msg):
"""
if msg["done"] or msg["is_observed"] or msg["value"] is not None:
msg["done"] = True
return msg
return

msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])

# after fn has been called, update msg to prevent it from being called again.
msg["done"] = True


def apply_stack(initial_msg):
def apply_stack(initial_msg: Message) -> None:
"""
Execute the effect stack at a single site according to the following scheme:
Expand Down Expand Up @@ -223,8 +251,6 @@ def apply_stack(initial_msg):
if cont is not None:
cont(msg)

return None


def am_i_wrapped():
"""
Expand All @@ -234,7 +260,7 @@ def am_i_wrapped():
return len(_PYRO_STACK) > 0


def effectful(fn=None, type=None):
def effectful(fn: Optional[Callable] = None, type: Optional[str] = None) -> Callable:
"""
:param fn: function or callable that performs an effectful computation
:param str type: the type label of the operation, e.g. `"sample"`
Expand All @@ -247,7 +273,7 @@ def effectful(fn=None, type=None):
if getattr(fn, "_is_effectful", None):
return fn

assert type is not None, "must provide a type label for operation {}".format(fn)
assert type is not None, f"must provide a type label for operation {fn}"
assert type != "message", "cannot use 'message' as keyword"

@functools.wraps(fn)
Expand Down Expand Up @@ -281,11 +307,11 @@ def _fn(*args, **kwargs):
apply_stack(msg)
return msg["value"]

_fn._is_effectful = True
_fn._is_effectful = True # type: ignore[attr-defined]
return _fn


def _inspect() -> Dict[str, object]:
def _inspect() -> Message:
"""
EXPERIMENTAL Inspect the Pyro stack.
Expand All @@ -295,7 +321,7 @@ def _inspect() -> Dict[str, object]:
:returns: A message with all effects applied.
:rtype: dict
"""
msg = {
msg: Message = {
"type": "inspect",
"name": "_pyro_inspect",
"fn": lambda: True,
Expand All @@ -315,7 +341,7 @@ def _inspect() -> Dict[str, object]:
return msg


def get_mask():
def get_mask() -> Union[bool, torch.Tensor, None]:
"""
Records the effects of enclosing ``poutine.mask`` handlers.
Expand All @@ -335,7 +361,7 @@ def model():
return _inspect()["mask"]


def get_plates() -> tuple:
def get_plates() -> Tuple[CondIndepStackFrame, ...]:
"""
Records the effects of enclosing ``pyro.plate`` contexts.
Expand Down
7 changes: 3 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ column_limit = 120

[mypy]
python_version = 3.8
explicit_package_bases = True
warn_return_any = True
warn_unused_configs = True
warn_incomplete_stub = True
Expand Down Expand Up @@ -77,11 +78,9 @@ warn_unused_ignores = True
ignore_errors = True
warn_unused_ignores = True

[mypy-pyro.optm.*]
warn_unused_ignores = True

[mypy-pyro.poutine.*]
[mypy-pyro.optim.*]
ignore_errors = True
warn_unused_ignores = True

[mypy-pyro.util.*]
ignore_errors = True
Expand Down

0 comments on commit ae725ef

Please sign in to comment.