Skip to content

Commit

Permalink
ContextManager
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy committed Nov 4, 2023
1 parent ae725ef commit 167695b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 29 deletions.
3 changes: 2 additions & 1 deletion pyro/poutine/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def __getstate__(self):
state.pop("trace")
return state

def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
# @override(TraceMessenger)
def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore[override]
"""
Draws posterior samples from the guide and replays the model against
those samples.
Expand Down
85 changes: 57 additions & 28 deletions pyro/poutine/messenger.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from contextlib import contextmanager
from __future__ import annotations

from contextlib import ContextDecorator, contextmanager
from functools import partial
from types import TracebackType
from typing import Any, Callable, Iterator, List, Optional, Type

from .runtime import _PYRO_STACK
from .runtime import _PYRO_STACK, Message


def _context_wrap(context, fn, *args, **kwargs):
def _context_wrap(
context: Messenger,
fn: Callable,
*args: Any,
**kwargs: Any,
) -> Any:
with context:
return fn(*args, **kwargs)

Expand All @@ -26,17 +35,25 @@ class _bound_partial(partial):
def __init__(self, func):
self.func = func

def __get__(self, instance, owner):
def __get__(
self,
instance: Optional[object],
owner: Optional[Type[object]] = None,
) -> object:
if instance is None:
return self
return partial(self.func, instance)


def unwrap(fn):
def unwrap(fn: Callable) -> Callable:
"""
Recursively unwraps poutines.
"""
# import pdb; pdb.set_trace()
while True:
if hasattr(fn, "__wrapped__"):
fn = fn.__wrapped__
continue
if isinstance(fn, _bound_partial):
fn = fn.func
continue
Expand All @@ -46,7 +63,7 @@ def unwrap(fn):
return fn


class Messenger:
class Messenger(ContextDecorator):
"""
Context manager class that modifies behavior
and adds side effects to stochastic functions
Expand All @@ -61,17 +78,15 @@ class Messenger:
Most inference operations are implemented in subclasses of this.
"""

def __call__(self, fn):
if not callable(fn):
raise ValueError(
"{} is not callable, did you mean to pass it as a keyword arg?".format(
fn
)
)
wraps = _bound_partial(partial(_context_wrap, self, fn))
return wraps

def __enter__(self):
# def __call__(self, fn: Callable) -> Callable:
# if not callable(fn):
# raise ValueError(
# f"{fn!r} is not callable, did you mean to pass it as a keyword arg?"
# )
# wraps = _bound_partial(partial(_context_wrap, self, fn))
# return wraps

def __enter__(self) -> Messenger:
"""
:returns: self
:rtype: pyro.poutine.Messenger
Expand Down Expand Up @@ -103,7 +118,12 @@ def __enter__(self):
# but it could in principle be enabled...
raise ValueError("cannot install a Messenger instance twice")

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[Exception]],
exc_value: Optional[Exception],
traceback: Optional[TracebackType],
) -> None:
"""
:param exc_type: exception type, e.g. ValueError
:param exc_value: exception instance?
Expand Down Expand Up @@ -146,30 +166,33 @@ def __exit__(self, exc_type, exc_value, traceback):
for i in range(loc, len(_PYRO_STACK)):
_PYRO_STACK.pop()

def _reset(self):
def _reset(self) -> None:
pass

def _process_message(self, msg):
def _process_message(self, msg: Message) -> Any:
"""
:param msg: current message at a trace site
:returns: None
Process the message by calling appropriate method of itself based
on message type. The message is updated in place.
"""
method = getattr(self, "_pyro_{}".format(msg["type"]), None)
method = getattr(self, f"_pyro_{msg['type']}", None)
if method is not None:
return method(msg)
return None

def _postprocess_message(self, msg):
method = getattr(self, "_pyro_post_{}".format(msg["type"]), None)
def _postprocess_message(self, msg: Message) -> Any:
method = getattr(self, f"_pyro_post_{msg['type']}", None)
if method is not None:
return method(msg)
return None

@classmethod
def register(cls, fn=None, type=None, post=None):
def register(
cls,
fn: Optional[Callable] = None,
type: Optional[str] = None,
post: Optional[bool] = None,
) -> Callable:
"""
:param fn: function implementing operation
:param str type: name of the operation
Expand Down Expand Up @@ -197,7 +220,11 @@ def some_function(msg)
return fn

@classmethod
def unregister(cls, fn=None, type=None):
def unregister(
cls,
fn: Optional[Callable] = None,
type: Optional[str] = None,
) -> Optional[Callable]:
"""
:param fn: function implementing operation
:param str type: name of the operation
Expand Down Expand Up @@ -227,7 +254,9 @@ def unregister(cls, fn=None, type=None):


@contextmanager
def block_messengers(predicate):
def block_messengers(
predicate: Callable[[Messenger], bool]
) -> Iterator[List[Messenger]]:
"""
EXPERIMENTAL Context manager to temporarily remove matching messengers from
the _PYRO_STACK. Note this does not call the ``.__exit__()`` and
Expand Down

0 comments on commit 167695b

Please sign in to comment.