From 167695b702ab523f31d970fdb5188b48e9bd5128 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 4 Nov 2023 05:14:00 +0000 Subject: [PATCH 1/3] ContextManager --- pyro/poutine/guide.py | 3 +- pyro/poutine/messenger.py | 85 ++++++++++++++++++++++++++------------- 2 files changed, 59 insertions(+), 29 deletions(-) diff --git a/pyro/poutine/guide.py b/pyro/poutine/guide.py index d8dcc2b572..56a6c79484 100644 --- a/pyro/poutine/guide.py +++ b/pyro/poutine/guide.py @@ -36,7 +36,8 @@ def __getstate__(self): state.pop("trace") return state - def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: + # @override(TraceMessenger) + def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore[override] """ Draws posterior samples from the guide and replays the model against those samples. diff --git a/pyro/poutine/messenger.py b/pyro/poutine/messenger.py index 3cbb4fba4a..8efb2f1f64 100644 --- a/pyro/poutine/messenger.py +++ b/pyro/poutine/messenger.py @@ -1,13 +1,22 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from contextlib import contextmanager +from __future__ import annotations + +from contextlib import ContextDecorator, contextmanager from functools import partial +from types import TracebackType +from typing import Any, Callable, Iterator, List, Optional, Type -from .runtime import _PYRO_STACK +from .runtime import _PYRO_STACK, Message -def _context_wrap(context, fn, *args, **kwargs): +def _context_wrap( + context: Messenger, + fn: Callable, + *args: Any, + **kwargs: Any, +) -> Any: with context: return fn(*args, **kwargs) @@ -26,17 +35,25 @@ class _bound_partial(partial): def __init__(self, func): self.func = func - def __get__(self, instance, owner): + def __get__( + self, + instance: Optional[object], + owner: Optional[Type[object]] = None, + ) -> object: if instance is None: return self return partial(self.func, instance) -def unwrap(fn): +def unwrap(fn: Callable) -> Callable: """ Recursively unwraps poutines. """ + # import pdb; pdb.set_trace() while True: + if hasattr(fn, "__wrapped__"): + fn = fn.__wrapped__ + continue if isinstance(fn, _bound_partial): fn = fn.func continue @@ -46,7 +63,7 @@ def unwrap(fn): return fn -class Messenger: +class Messenger(ContextDecorator): """ Context manager class that modifies behavior and adds side effects to stochastic functions @@ -61,17 +78,15 @@ class Messenger: Most inference operations are implemented in subclasses of this. """ - def __call__(self, fn): - if not callable(fn): - raise ValueError( - "{} is not callable, did you mean to pass it as a keyword arg?".format( - fn - ) - ) - wraps = _bound_partial(partial(_context_wrap, self, fn)) - return wraps - - def __enter__(self): + # def __call__(self, fn: Callable) -> Callable: + # if not callable(fn): + # raise ValueError( + # f"{fn!r} is not callable, did you mean to pass it as a keyword arg?" + # ) + # wraps = _bound_partial(partial(_context_wrap, self, fn)) + # return wraps + + def __enter__(self) -> Messenger: """ :returns: self :rtype: pyro.poutine.Messenger @@ -103,7 +118,12 @@ def __enter__(self): # but it could in principle be enabled... raise ValueError("cannot install a Messenger instance twice") - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: Optional[Type[Exception]], + exc_value: Optional[Exception], + traceback: Optional[TracebackType], + ) -> None: """ :param exc_type: exception type, e.g. ValueError :param exc_value: exception instance? @@ -146,10 +166,10 @@ def __exit__(self, exc_type, exc_value, traceback): for i in range(loc, len(_PYRO_STACK)): _PYRO_STACK.pop() - def _reset(self): + def _reset(self) -> None: pass - def _process_message(self, msg): + def _process_message(self, msg: Message) -> Any: """ :param msg: current message at a trace site :returns: None @@ -157,19 +177,22 @@ def _process_message(self, msg): Process the message by calling appropriate method of itself based on message type. The message is updated in place. """ - method = getattr(self, "_pyro_{}".format(msg["type"]), None) + method = getattr(self, f"_pyro_{msg['type']}", None) if method is not None: return method(msg) - return None - def _postprocess_message(self, msg): - method = getattr(self, "_pyro_post_{}".format(msg["type"]), None) + def _postprocess_message(self, msg: Message) -> Any: + method = getattr(self, f"_pyro_post_{msg['type']}", None) if method is not None: return method(msg) - return None @classmethod - def register(cls, fn=None, type=None, post=None): + def register( + cls, + fn: Optional[Callable] = None, + type: Optional[str] = None, + post: Optional[bool] = None, + ) -> Callable: """ :param fn: function implementing operation :param str type: name of the operation @@ -197,7 +220,11 @@ def some_function(msg) return fn @classmethod - def unregister(cls, fn=None, type=None): + def unregister( + cls, + fn: Optional[Callable] = None, + type: Optional[str] = None, + ) -> Optional[Callable]: """ :param fn: function implementing operation :param str type: name of the operation @@ -227,7 +254,9 @@ def unregister(cls, fn=None, type=None): @contextmanager -def block_messengers(predicate): +def block_messengers( + predicate: Callable[[Messenger], bool] +) -> Iterator[List[Messenger]]: """ EXPERIMENTAL Context manager to temporarily remove matching messengers from the _PYRO_STACK. Note this does not call the ``.__exit__()`` and From 5376be62ac6fed072474c09f5a2ab6c5659e1c98 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 5 Nov 2023 01:30:38 +0000 Subject: [PATCH 2/3] rm ContextDecorator --- pyro/poutine/guide.py | 1 - pyro/poutine/messenger.py | 30 +++++++++++++----------------- 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/pyro/poutine/guide.py b/pyro/poutine/guide.py index 56a6c79484..22889caa3f 100644 --- a/pyro/poutine/guide.py +++ b/pyro/poutine/guide.py @@ -36,7 +36,6 @@ def __getstate__(self): state.pop("trace") return state - # @override(TraceMessenger) def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore[override] """ Draws posterior samples from the guide and replays the model against diff --git a/pyro/poutine/messenger.py b/pyro/poutine/messenger.py index 8efb2f1f64..7ecf383db4 100644 --- a/pyro/poutine/messenger.py +++ b/pyro/poutine/messenger.py @@ -3,7 +3,7 @@ from __future__ import annotations -from contextlib import ContextDecorator, contextmanager +from contextlib import contextmanager from functools import partial from types import TracebackType from typing import Any, Callable, Iterator, List, Optional, Type @@ -49,11 +49,7 @@ def unwrap(fn: Callable) -> Callable: """ Recursively unwraps poutines. """ - # import pdb; pdb.set_trace() while True: - if hasattr(fn, "__wrapped__"): - fn = fn.__wrapped__ - continue if isinstance(fn, _bound_partial): fn = fn.func continue @@ -63,7 +59,7 @@ def unwrap(fn: Callable) -> Callable: return fn -class Messenger(ContextDecorator): +class Messenger: """ Context manager class that modifies behavior and adds side effects to stochastic functions @@ -78,13 +74,13 @@ class Messenger(ContextDecorator): Most inference operations are implemented in subclasses of this. """ - # def __call__(self, fn: Callable) -> Callable: - # if not callable(fn): - # raise ValueError( - # f"{fn!r} is not callable, did you mean to pass it as a keyword arg?" - # ) - # wraps = _bound_partial(partial(_context_wrap, self, fn)) - # return wraps + def __call__(self, fn: Callable) -> Callable: + if not callable(fn): + raise ValueError( + f"{fn!r} is not callable, did you mean to pass it as a keyword arg?" + ) + wraps = _bound_partial(partial(_context_wrap, self, fn)) + return wraps def __enter__(self) -> Messenger: """ @@ -169,7 +165,7 @@ def __exit__( def _reset(self) -> None: pass - def _process_message(self, msg: Message) -> Any: + def _process_message(self, msg: Message) -> None: """ :param msg: current message at a trace site :returns: None @@ -179,12 +175,12 @@ def _process_message(self, msg: Message) -> Any: """ method = getattr(self, f"_pyro_{msg['type']}", None) if method is not None: - return method(msg) + method(msg) - def _postprocess_message(self, msg: Message) -> Any: + def _postprocess_message(self, msg: Message) -> None: method = getattr(self, f"_pyro_post_{msg['type']}", None) if method is not None: - return method(msg) + method(msg) @classmethod def register( From 287f5c1c4483b2f81b3c8135f5c1fb625e9605ad Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 5 Nov 2023 20:51:34 +0000 Subject: [PATCH 3/3] address comments --- pyro/poutine/messenger.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pyro/poutine/messenger.py b/pyro/poutine/messenger.py index 7ecf383db4..35245bb205 100644 --- a/pyro/poutine/messenger.py +++ b/pyro/poutine/messenger.py @@ -6,10 +6,12 @@ from contextlib import contextmanager from functools import partial from types import TracebackType -from typing import Any, Callable, Iterator, List, Optional, Type +from typing import Any, Callable, Iterator, List, Optional, Type, TypeVar, cast from .runtime import _PYRO_STACK, Message +_F = TypeVar("_F", bound=Callable) + def _context_wrap( context: Messenger, @@ -74,13 +76,13 @@ class Messenger: Most inference operations are implemented in subclasses of this. """ - def __call__(self, fn: Callable) -> Callable: + def __call__(self, fn: _F) -> _F: if not callable(fn): raise ValueError( f"{fn!r} is not callable, did you mean to pass it as a keyword arg?" ) wraps = _bound_partial(partial(_context_wrap, self, fn)) - return wraps + return cast(_F, wraps) def __enter__(self) -> Messenger: """