Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multithreading for PYRO_STACK #1343

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions examples/threading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
Examples to show how to use numpyro with multithreading.
"""
from collections import defaultdict
import threading

import jax
import jax.numpy as jnp

import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.primitives import set_pyro_stack


class _StackThreadDict(defaultdict):
def current_stack(self):
thread_id = threading.get_native_id()
return self[thread_id]


_PYRO_THREAD_STACK = _StackThreadDict(list)
set_pyro_stack(_PYRO_THREAD_STACK)


def model():
numpyro.sample("a", dist.Normal(0, 1))


rng_keys = jax.random.split(jax.random.PRNGKey(0), 2)
for rng_key in rng_keys:
# creat a thread and trace
pass
4 changes: 2 additions & 2 deletions numpyro/contrib/control_flow/cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from numpyro import handlers
from numpyro.ops.pytree import PytreeTrace
from numpyro.primitives import _PYRO_STACK, apply_stack
from numpyro.primitives import apply_stack, get_pyro_stack


def _subs_wrapper(subs_map, site):
Expand Down Expand Up @@ -141,7 +141,7 @@ def cond(pred, true_fun, false_fun, operand):
be any JAX PyTree (e.g. list / dict of arrays).
:return: Output of the applied branch function.
"""
if not _PYRO_STACK:
if not get_pyro_stack():
value, _ = cond_wrapper(pred, true_fun, false_fun, operand)
return value

Expand Down
4 changes: 2 additions & 2 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from numpyro import handlers
from numpyro.ops.pytree import PytreeTrace
from numpyro.primitives import _PYRO_STACK, Messenger, apply_stack
from numpyro.primitives import Messenger, apply_stack, get_pyro_stack
from numpyro.util import not_jax_tracer


Expand Down Expand Up @@ -415,7 +415,7 @@ def g(*args, **kwargs):
second output of f when scanned over the leading axis of the inputs".
"""
# if there are no active Messengers, we just run and return it as expected:
if not _PYRO_STACK:
if not get_pyro_stack():
(length, rng_key, carry), (pytree_trace, ys) = scan_wrapper(
f, init, xs, length=length, reverse=reverse
)
Expand Down
4 changes: 2 additions & 2 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@
import numpyro
from numpyro.distributions.distribution import COERCIONS
from numpyro.primitives import (
_PYRO_STACK,
CondIndepStackFrame,
Messenger,
apply_stack,
get_pyro_stack,
plate,
)
from numpyro.util import find_stack_level, not_jax_tracer
Expand Down Expand Up @@ -291,7 +291,7 @@ def process_message(self, msg):

def __enter__(self):
self.preserved_plates = frozenset(
h.name for h in _PYRO_STACK if isinstance(h, plate)
h.name for h in get_pyro_stack() if isinstance(h, plate)
)
COERCIONS.append(self._coerce)
return super().__enter__()
Expand Down
2 changes: 1 addition & 1 deletion numpyro/infer/hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,7 @@ def __init__(self, fn=None, method=None):
self.gibbs_state = None

def __enter__(self):
for handler in numpyro.primitives._PYRO_STACK[::-1]:
for handler in numpyro.primitives.get_pyro_stack()[::-1]:
# the potential_fn in HMC makes the PYRO_STACK nested like trace(...); so we can extract the
# unconstrained_params from the _unconstrain_reparam substitute_fn
if (
Expand Down
48 changes: 32 additions & 16 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,29 @@
import numpyro
from numpyro.util import find_stack_level, identity

_PYRO_STACK = []

class _StackList(list):
def current_stack(self):
return self


_PYRO_STACK = _StackList()
CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "dim", "size"])


def get_pyro_stack():
return _PYRO_STACK.current_stack()


def set_pyro_stack(stack):
global _PYRO_STACK
_PYRO_STACK = stack


def apply_stack(msg):
pointer = 0
for pointer, handler in enumerate(reversed(_PYRO_STACK)):
pyro_stack = get_pyro_stack()
for pointer, handler in enumerate(reversed(pyro_stack)):
handler.process_message(msg)
# When a Messenger sets the "stop" field of a message,
# it prevents any Messengers above it on the stack from being applied.
Expand All @@ -37,7 +52,7 @@ def apply_stack(msg):
# A Messenger that sets msg["stop"] == True also prevents application
# of postprocess_message by Messengers above it on the stack
# via the pointer variable from the process_message loop
for handler in _PYRO_STACK[-pointer - 1 :]:
for handler in pyro_stack[-pointer - 1 :]:
handler.postprocess_message(msg)
return msg

Expand All @@ -53,23 +68,24 @@ def __init__(self, fn=None):
functools.update_wrapper(self, fn, updated=[])

def __enter__(self):
_PYRO_STACK.append(self)
get_pyro_stack().append(self)

def __exit__(self, exc_type, exc_value, traceback):
pyro_stack = get_pyro_stack()
if exc_type is None:
assert _PYRO_STACK[-1] is self
_PYRO_STACK.pop()
assert pyro_stack[-1] is self
pyro_stack.pop()
else:
# NB: this mimics Pyro exception handling
# the wrapped function or block raised an exception
# handler exception handling:
# when the callee or enclosed block raises an exception,
# find this handler's position in the stack,
# then remove it and everything below it in the stack.
if self in _PYRO_STACK:
loc = _PYRO_STACK.index(self)
for i in range(loc, len(_PYRO_STACK)):
_PYRO_STACK.pop()
if self in pyro_stack:
loc = pyro_stack.index(self)
for i in range(loc, len(pyro_stack)):
pyro_stack.pop()

def process_message(self, msg):
pass
Expand Down Expand Up @@ -174,7 +190,7 @@ def sample(
raise type_error

# if there are no active Messengers, we just draw a sample and return it as expected:
if not _PYRO_STACK:
if not get_pyro_stack():
return fn(rng_key=rng_key, sample_shape=sample_shape)

if obs_mask is not None:
Expand Down Expand Up @@ -228,7 +244,7 @@ def param(name, init_value=None, **kwargs):
return the initial value.
"""
# if there are no active Messengers, we just draw a sample and return it as expected:
if not _PYRO_STACK:
if not get_pyro_stack():
assert not callable(
init_value
), "A callable init_value needs to be put inside a numpyro.handlers.seed handler."
Expand Down Expand Up @@ -270,7 +286,7 @@ def deterministic(name, value):
:param str name: name of the deterministic site.
:param numpy.ndarray value: deterministic value to record in the trace.
"""
if not _PYRO_STACK:
if not get_pyro_stack():
return value

initial_msg = {"type": "deterministic", "name": name, "value": value}
Expand All @@ -295,7 +311,7 @@ def mutable(name, init_value=None):
:param str name: name of the mutable site.
:param init_value: mutable value to record in the trace.
"""
if not _PYRO_STACK:
if not get_pyro_stack():
return init_value

initial_msg = {
Expand Down Expand Up @@ -593,7 +609,7 @@ def prng_key():

:return: a PRNG key of shape (2,) and dtype unit32.
"""
if not _PYRO_STACK:
if not get_pyro_stack():
return

initial_msg = {
Expand Down Expand Up @@ -635,7 +651,7 @@ def model(data):
:returns: A subsampled version of ``data``
:rtype: ~numpy.ndarray
"""
if not _PYRO_STACK:
if not get_pyro_stack():
return data

assert isinstance(event_dim, int) and event_dim >= 0
Expand Down