diff --git a/examples/threading.py b/examples/threading.py new file mode 100644 index 000000000..491cfd173 --- /dev/null +++ b/examples/threading.py @@ -0,0 +1,33 @@ +""" +Examples to show how to use numpyro with multithreading. +""" +from collections import defaultdict +import threading + +import jax +import jax.numpy as jnp + +import numpyro +from numpyro import handlers +import numpyro.distributions as dist +from numpyro.primitives import set_pyro_stack + + +class _StackThreadDict(defaultdict): + def current_stack(self): + thread_id = threading.get_native_id() + return self[thread_id] + + +_PYRO_THREAD_STACK = _StackThreadDict(list) +set_pyro_stack(_PYRO_THREAD_STACK) + + +def model(): + numpyro.sample("a", dist.Normal(0, 1)) + + +rng_keys = jax.random.split(jax.random.PRNGKey(0), 2) +for rng_key in rng_keys: + # creat a thread and trace + pass diff --git a/numpyro/contrib/control_flow/cond.py b/numpyro/contrib/control_flow/cond.py index 7cba2af74..c83aee2c6 100644 --- a/numpyro/contrib/control_flow/cond.py +++ b/numpyro/contrib/control_flow/cond.py @@ -7,7 +7,7 @@ from numpyro import handlers from numpyro.ops.pytree import PytreeTrace -from numpyro.primitives import _PYRO_STACK, apply_stack +from numpyro.primitives import apply_stack, get_pyro_stack def _subs_wrapper(subs_map, site): @@ -141,7 +141,7 @@ def cond(pred, true_fun, false_fun, operand): be any JAX PyTree (e.g. list / dict of arrays). :return: Output of the applied branch function. """ - if not _PYRO_STACK: + if not get_pyro_stack(): value, _ = cond_wrapper(pred, true_fun, false_fun, operand) return value diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index 28579119a..8e9f1f5cb 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -17,7 +17,7 @@ from numpyro import handlers from numpyro.ops.pytree import PytreeTrace -from numpyro.primitives import _PYRO_STACK, Messenger, apply_stack +from numpyro.primitives import Messenger, apply_stack, get_pyro_stack from numpyro.util import not_jax_tracer @@ -415,7 +415,7 @@ def g(*args, **kwargs): second output of f when scanned over the leading axis of the inputs". """ # if there are no active Messengers, we just run and return it as expected: - if not _PYRO_STACK: + if not get_pyro_stack(): (length, rng_key, carry), (pytree_trace, ys) = scan_wrapper( f, init, xs, length=length, reverse=reverse ) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index e38ae30ea..6dd937843 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -87,10 +87,10 @@ import numpyro from numpyro.distributions.distribution import COERCIONS from numpyro.primitives import ( - _PYRO_STACK, CondIndepStackFrame, Messenger, apply_stack, + get_pyro_stack, plate, ) from numpyro.util import find_stack_level, not_jax_tracer @@ -291,7 +291,7 @@ def process_message(self, msg): def __enter__(self): self.preserved_plates = frozenset( - h.name for h in _PYRO_STACK if isinstance(h, plate) + h.name for h in get_pyro_stack() if isinstance(h, plate) ) COERCIONS.append(self._coerce) return super().__enter__() diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 44c8b72e0..0f13f5221 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -918,7 +918,7 @@ def __init__(self, fn=None, method=None): self.gibbs_state = None def __enter__(self): - for handler in numpyro.primitives._PYRO_STACK[::-1]: + for handler in numpyro.primitives.get_pyro_stack()[::-1]: # the potential_fn in HMC makes the PYRO_STACK nested like trace(...); so we can extract the # unconstrained_params from the _unconstrain_reparam substitute_fn if ( diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 67d26b322..17ad660a8 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -13,14 +13,29 @@ import numpyro from numpyro.util import find_stack_level, identity -_PYRO_STACK = [] +class _StackList(list): + def current_stack(self): + return self + + +_PYRO_STACK = _StackList() CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "dim", "size"]) +def get_pyro_stack(): + return _PYRO_STACK.current_stack() + + +def set_pyro_stack(stack): + global _PYRO_STACK + _PYRO_STACK = stack + + def apply_stack(msg): pointer = 0 - for pointer, handler in enumerate(reversed(_PYRO_STACK)): + pyro_stack = get_pyro_stack() + for pointer, handler in enumerate(reversed(pyro_stack)): handler.process_message(msg) # When a Messenger sets the "stop" field of a message, # it prevents any Messengers above it on the stack from being applied. @@ -37,7 +52,7 @@ def apply_stack(msg): # A Messenger that sets msg["stop"] == True also prevents application # of postprocess_message by Messengers above it on the stack # via the pointer variable from the process_message loop - for handler in _PYRO_STACK[-pointer - 1 :]: + for handler in pyro_stack[-pointer - 1 :]: handler.postprocess_message(msg) return msg @@ -53,12 +68,13 @@ def __init__(self, fn=None): functools.update_wrapper(self, fn, updated=[]) def __enter__(self): - _PYRO_STACK.append(self) + get_pyro_stack().append(self) def __exit__(self, exc_type, exc_value, traceback): + pyro_stack = get_pyro_stack() if exc_type is None: - assert _PYRO_STACK[-1] is self - _PYRO_STACK.pop() + assert pyro_stack[-1] is self + pyro_stack.pop() else: # NB: this mimics Pyro exception handling # the wrapped function or block raised an exception @@ -66,10 +82,10 @@ def __exit__(self, exc_type, exc_value, traceback): # when the callee or enclosed block raises an exception, # find this handler's position in the stack, # then remove it and everything below it in the stack. - if self in _PYRO_STACK: - loc = _PYRO_STACK.index(self) - for i in range(loc, len(_PYRO_STACK)): - _PYRO_STACK.pop() + if self in pyro_stack: + loc = pyro_stack.index(self) + for i in range(loc, len(pyro_stack)): + pyro_stack.pop() def process_message(self, msg): pass @@ -174,7 +190,7 @@ def sample( raise type_error # if there are no active Messengers, we just draw a sample and return it as expected: - if not _PYRO_STACK: + if not get_pyro_stack(): return fn(rng_key=rng_key, sample_shape=sample_shape) if obs_mask is not None: @@ -228,7 +244,7 @@ def param(name, init_value=None, **kwargs): return the initial value. """ # if there are no active Messengers, we just draw a sample and return it as expected: - if not _PYRO_STACK: + if not get_pyro_stack(): assert not callable( init_value ), "A callable init_value needs to be put inside a numpyro.handlers.seed handler." @@ -270,7 +286,7 @@ def deterministic(name, value): :param str name: name of the deterministic site. :param numpy.ndarray value: deterministic value to record in the trace. """ - if not _PYRO_STACK: + if not get_pyro_stack(): return value initial_msg = {"type": "deterministic", "name": name, "value": value} @@ -295,7 +311,7 @@ def mutable(name, init_value=None): :param str name: name of the mutable site. :param init_value: mutable value to record in the trace. """ - if not _PYRO_STACK: + if not get_pyro_stack(): return init_value initial_msg = { @@ -593,7 +609,7 @@ def prng_key(): :return: a PRNG key of shape (2,) and dtype unit32. """ - if not _PYRO_STACK: + if not get_pyro_stack(): return initial_msg = { @@ -635,7 +651,7 @@ def model(data): :returns: A subsampled version of ``data`` :rtype: ~numpy.ndarray """ - if not _PYRO_STACK: + if not get_pyro_stack(): return data assert isinstance(event_dim, int) and event_dim >= 0