Skip to content

Commit

Permalink
Update jax.tree_util.tree_map to jax.tree.map (#1821)
Browse files Browse the repository at this point in the history
* update jax.tree_util.tree_foo to jax.tree.foo

* bump minimal jax version to 0.4.25, which supports jax.tree

* fix lint issues

* also fix deprecation warning of using a_min, a_max in jnp.clip
  • Loading branch information
fehiepsi authored Jun 26, 2024
1 parent 2984b9b commit 5af9ebd
Show file tree
Hide file tree
Showing 51 changed files with 236 additions and 243 deletions.
2 changes: 1 addition & 1 deletion examples/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def main(args):
# is stored in `discrete_samples`. To merge those discrete samples into the `mcmc`
# instance, we can use the following pattern::
#
# chain_discrete_samples = jax.tree_util.tree_map(
# chain_discrete_samples = jax.tree.map(
# lambda x: x.reshape((args.num_chains, args.num_samples) + x.shape[1:]),
# discrete_samples)
# mcmc.get_samples().update(discrete_samples)
Expand Down
2 changes: 1 addition & 1 deletion examples/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def predict(rng_key, X, Y, X_test, var, length, noise, use_cholesky=True):
K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y))

sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), a_min=0.0)) * jax.random.normal(
sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), 0.0)) * jax.random.normal(
rng_key, X_test.shape[:1]
)

Expand Down
4 changes: 2 additions & 2 deletions notebooks/source/time_series_forecasting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@
" level, s, moving_sum = carry\n",
" season = s[0] * level**pow_season\n",
" exp_val = level + coef_trend * level**pow_trend + season\n",
" exp_val = jnp.clip(exp_val, a_min=0)\n",
" exp_val = jnp.clip(exp_val, 0)\n",
" # use expected vale when forecasting\n",
" y_t = jnp.where(t >= N, exp_val, y[t])\n",
"\n",
Expand All @@ -215,7 +215,7 @@
" )\n",
" level_p = jnp.where(t >= seasonality, moving_sum / seasonality, y_t - season)\n",
" level = level_sm * level_p + (1 - level_sm) * level\n",
" level = jnp.clip(level, a_min=0)\n",
" level = jnp.clip(level, 0)\n",
"\n",
" new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]\n",
" # repeat s when forecasting\n",
Expand Down
40 changes: 21 additions & 19 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from collections import OrderedDict
from functools import partial

import jax
from jax import device_put, lax, random
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_map, tree_unflatten

from numpyro import handlers
from numpyro.distributions.batch_util import promote_batch_shape
Expand Down Expand Up @@ -98,7 +98,7 @@ def postprocess_message(self, msg):
fn_batch_ndim = len(fn.batch_shape)
if fn_batch_ndim < value_batch_ndims:
prepend_shapes = (1,) * (value_batch_ndims - fn_batch_ndim)
msg["fn"] = tree_map(
msg["fn"] = jax.tree.map(
lambda x: jnp.reshape(x, prepend_shapes + jnp.shape(x)), fn
)

Expand Down Expand Up @@ -140,11 +140,11 @@ def scan_enum(
history = min(history, length)
unroll_steps = min(2 * history - 1, length)
if reverse:
x0 = tree_map(lambda x: x[-unroll_steps:][::-1], xs)
xs_ = tree_map(lambda x: x[:-unroll_steps], xs)
x0 = jax.tree.map(lambda x: x[-unroll_steps:][::-1], xs)
xs_ = jax.tree.map(lambda x: x[:-unroll_steps], xs)
else:
x0 = tree_map(lambda x: x[:unroll_steps], xs)
xs_ = tree_map(lambda x: x[unroll_steps:], xs)
x0 = jax.tree.map(lambda x: x[:unroll_steps], xs)
xs_ = jax.tree.map(lambda x: x[unroll_steps:], xs)

carry_shapes = []

Expand Down Expand Up @@ -187,10 +187,12 @@ def body_fn(wrapped_carry, x, prefix=None):

# store shape of new_carry at a global variable
if len(carry_shapes) < (history + 1):
carry_shapes.append([jnp.shape(x) for x in tree_flatten(new_carry)[0]])
carry_shapes.append(
[jnp.shape(x) for x in jax.tree.flatten(new_carry)[0]]
)
# make new_carry have the same shape as carry
# FIXME: is this rigorous?
new_carry = tree_map(
new_carry = jax.tree.map(
lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry
)
return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)
Expand All @@ -204,27 +206,27 @@ def body_fn(wrapped_carry, x, prefix=None):
for i in markov(range(unroll_steps + 1), history=history):
if i < unroll_steps:
wrapped_carry, (_, y0) = body_fn(
wrapped_carry, tree_map(lambda z: z[i], x0)
wrapped_carry, jax.tree.map(lambda z: z[i], x0)
)
if i > 0:
# reshape y1, y2,... to have the same shape as y0
y0 = tree_map(
y0 = jax.tree.map(
lambda z0, z: jnp.reshape(z, jnp.shape(z0)), y0s[0], y0
)
y0s.append(y0)
# shapes of the first `history - 1` steps are not useful to interpret the last carry
# shape so we don't need to record them here
if (i >= history - 1) and (len(carry_shapes) < history + 1):
carry_shapes.append(
jnp.shape(x) for x in tree_flatten(wrapped_carry[-1])[0]
jnp.shape(x) for x in jax.tree.flatten(wrapped_carry[-1])[0]
)
else:
# this is the last rolling step
y0s = tree_map(lambda *z: jnp.stack(z, axis=0), *y0s)
y0s = jax.tree.map(lambda *z: jnp.stack(z, axis=0), *y0s)
# return early if length = unroll_steps
if length == unroll_steps:
return wrapped_carry, (PytreeTrace({}), y0s)
wrapped_carry = tree_map(device_put, wrapped_carry)
wrapped_carry = jax.tree.map(device_put, wrapped_carry)
wrapped_carry, (pytree_trace, ys) = lax.scan(
body_fn, wrapped_carry, xs_, length - unroll_steps, reverse
)
Expand All @@ -251,20 +253,20 @@ def body_fn(wrapped_carry, x, prefix=None):
site["infer"]["dim_to_name"][time_dim] = "_time_{}".format(first_var)

# similar to carry, we need to reshape due to shape alternating in markov
ys = tree_map(
ys = jax.tree.map(
lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)[1:]), y0s, ys
)
# then join with y0s
ys = tree_map(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
ys = jax.tree.map(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
# we also need to reshape `carry` to match sequential behavior
i = (length + 1) % (history + 1)
t, rng_key, carry = wrapped_carry
carry_shape = carry_shapes[i]
flatten_carry, treedef = tree_flatten(carry)
flatten_carry, treedef = jax.tree.flatten(carry)
flatten_carry = [
jnp.reshape(x, t1_shape) for x, t1_shape in zip(flatten_carry, carry_shape)
]
carry = tree_unflatten(treedef, flatten_carry)
carry = jax.tree.unflatten(treedef, flatten_carry)
wrapped_carry = (t, rng_key, carry)
return wrapped_carry, (pytree_trace, ys)

Expand All @@ -282,7 +284,7 @@ def scan_wrapper(
first_available_dim=None,
):
if length is None:
length = jnp.shape(tree_flatten(xs)[0][0])[0]
length = jnp.shape(jax.tree.flatten(xs)[0][0])[0]

if enum and history > 0:
return scan_enum( # TODO: replay for enum
Expand Down Expand Up @@ -324,7 +326,7 @@ def body_fn(wrapped_carry, x):

return (i + 1, rng_key, carry), (PytreeTrace(trace), y)

wrapped_carry = tree_map(device_put, (0, rng_key, init))
wrapped_carry = jax.tree.map(device_put, (0, rng_key, init))
last_carry, (pytree_trace, ys) = lax.scan(
body_fn, wrapped_carry, xs, length=length, reverse=reverse
)
Expand Down
6 changes: 3 additions & 3 deletions numpyro/contrib/einstein/mixture_guide_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from functools import partial
from typing import Optional

import jax
from jax import numpy as jnp, random, vmap
from jax.tree_util import tree_flatten, tree_map

from numpyro.handlers import substitute
from numpyro.infer import Predictive
Expand Down Expand Up @@ -63,7 +63,7 @@ def __init__(

self.guide = guide
self.return_sites = return_sites
self.num_mixture_components = jnp.shape(tree_flatten(params)[0][0])[0]
self.num_mixture_components = jnp.shape(jax.tree.flatten(params)[0][0])[0]
self.mixture_assignment_sitename = mixture_assignment_sitename

def _call_with_params(self, rng_key, params, args, kwargs):
Expand Down Expand Up @@ -99,7 +99,7 @@ def __call__(self, rng_key, *args, **kwargs):
minval=0,
maxval=self.num_mixture_components,
)
predictive_assign = tree_map(
predictive_assign = jax.tree.map(
lambda arr: vmap(lambda i, assign: arr[i, assign])(
jnp.arange(self._batch_shape[0]), assigns
),
Expand Down
10 changes: 5 additions & 5 deletions numpyro/contrib/einstein/stein_util.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import jax
from jax import numpy as jnp, vmap
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_map

from numpyro.distributions import biject_to
from numpyro.distributions.constraints import real
Expand Down Expand Up @@ -64,14 +64,14 @@ def batch_ravel_pytree(pytree, nbatch_dims=0):
flat, unravel_fn = ravel_pytree(pytree)
return flat, unravel_fn, unravel_fn

shapes = tree_map(lambda x: x.shape, pytree)
flat_pytree = tree_map(lambda x: x.reshape(*x.shape[:-nbatch_dims], -1), pytree)
shapes = jax.tree.map(lambda x: x.shape, pytree)
flat_pytree = jax.tree.map(lambda x: x.reshape(*x.shape[:-nbatch_dims], -1), pytree)
flat = vmap(lambda x: ravel_pytree(x)[0])(flat_pytree)
unravel_fn = ravel_pytree(tree_map(lambda x: x[0], flat_pytree))[1]
unravel_fn = ravel_pytree(jax.tree.map(lambda x: x[0], flat_pytree))[1]
return (
flat,
unravel_fn,
lambda _flat: tree_map(
lambda _flat: jax.tree.map(
lambda x, shape: x.reshape(shape), vmap(unravel_fn)(_flat), shapes
),
)
10 changes: 5 additions & 5 deletions numpyro/contrib/einstein/steinvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from itertools import chain
import operator

import jax
from jax import grad, jacfwd, numpy as jnp, random, vmap
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_map

from numpyro import handlers
from numpyro.contrib.einstein.stein_kernels import SteinKernel
Expand Down Expand Up @@ -340,10 +340,10 @@ def _update_force(attr_force, rep_force, jac):
return force.reshape(attr_force.shape)

reparam_jac = {
name: tree_map(lambda var: _nontrivial_jac(name, var), variables)
name: jax.tree.map(lambda var: _nontrivial_jac(name, var), variables)
for name, variables in unravel_pytree(particle).items()
}
jac_params = tree_map(
jac_params = jax.tree.map(
_update_force,
unravel_pytree(attr_forces),
unravel_pytree(rep_forces),
Expand All @@ -363,7 +363,7 @@ def _update_force(attr_force, rep_force, jac):
stein_param_grads = unravel_pytree_batched(particle_grads)

# 6. Return loss and gradients (based on parameter forces)
res_grads = tree_map(
res_grads = jax.tree.map(
lambda x: -x, {**non_mixture_param_grads, **stein_param_grads}
)
return jnp.linalg.norm(particle_grads), res_grads
Expand Down Expand Up @@ -427,7 +427,7 @@ def init(self, rng_key, *args, **kwargs):
if site["name"] in guide_init_params:
pval = guide_init_params[site["name"]]
if self.non_mixture_params_fn(site["name"]):
pval = tree_map(lambda x: x[0], pval)
pval = jax.tree.map(lambda x: x[0], pval)
else:
pval = site["value"]
params[site["name"]] = transform.inv(pval)
Expand Down
11 changes: 6 additions & 5 deletions numpyro/contrib/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from copy import deepcopy
from functools import partial

import jax
from jax import random
import jax.numpy as jnp
from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten
from jax.tree_util import register_pytree_node

import numpyro
import numpyro.distributions as dist
Expand Down Expand Up @@ -106,8 +107,8 @@ def flax_module(
assert set(mutable) == set(nn_state)
numpyro_mutable(name + "$state", nn_state)
# make sure that nn_params keep the same order after unflatten
params_flat, tree_def = tree_flatten(nn_params)
nn_params = tree_unflatten(tree_def, params_flat)
params_flat, tree_def = jax.tree.flatten(nn_params)
nn_params = jax.tree.unflatten(tree_def, params_flat)
numpyro.param(module_key, nn_params)

def apply_with_state(params, *args, **kwargs):
Expand Down Expand Up @@ -195,8 +196,8 @@ def haiku_module(name, nn_module, *args, input_shape=None, apply_rng=False, **kw
nn_params = hk.data_structures.to_mutable_dict(nn_params)
# we cast it to a mutable one to be able to set priors for parameters
# make sure that nn_params keep the same order after unflatten
params_flat, tree_def = tree_flatten(nn_params)
nn_params = tree_unflatten(tree_def, params_flat)
params_flat, tree_def = jax.tree.flatten(nn_params)
nn_params = jax.tree.unflatten(tree_def, params_flat)
numpyro.param(module_key, nn_params)

def apply_with_state(params, *args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions numpyro/contrib/tfp/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,11 @@ def is_discrete(self):
return self.support is None

def tree_flatten(self):
return jax.tree_util.tree_flatten(self.tfp_dist)
return jax.tree.flatten(self.tfp_dist)

@classmethod
def tree_unflatten(cls, aux_data, params):
fn = jax.tree_util.tree_unflatten(aux_data, params)
fn = jax.tree.unflatten(aux_data, params)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=FutureWarning)
return TFPDistribution[fn.__class__](**fn.parameters)
Expand Down
4 changes: 2 additions & 2 deletions numpyro/contrib/tfp/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from collections import namedtuple
import inspect

import jax
from jax import random, vmap
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
from jax.tree_util import tree_map
import tensorflow_probability.substrates.jax as tfp

from numpyro.infer import init_to_uniform
Expand Down Expand Up @@ -44,7 +44,7 @@ def log_prob_fn(x):
flatten_result = vmap(lambda a: -potential_fn(unravel_fn(a)))(
jnp.reshape(x, (-1,) + jnp.shape(x)[-1:])
)
return tree_map(
return jax.tree.map(
lambda a: jnp.reshape(a, batch_shape + jnp.shape(a)[1:]), flatten_result
)
else:
Expand Down
12 changes: 6 additions & 6 deletions numpyro/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

import numpy as np

import jax
from jax import device_get
from jax.tree_util import tree_flatten, tree_map

__all__ = [
"autocorrelation",
Expand Down Expand Up @@ -182,7 +182,7 @@ def effective_sample_size(x):
Rho_k = np.concatenate(
[
Rho_init,
np.minimum.accumulate(np.clip(Rho_k[1:, ...], a_min=0, a_max=None), axis=0),
np.minimum.accumulate(np.clip(Rho_k[1:, ...], 0, None), axis=0),
],
axis=0,
)
Expand Down Expand Up @@ -238,10 +238,10 @@ def summary(samples, prob=0.90, group_by_chain=True):
chain dimension).
"""
if not group_by_chain:
samples = tree_map(lambda x: x[None, ...], samples)
samples = jax.tree.map(lambda x: x[None, ...], samples)
if not isinstance(samples, dict):
samples = {
"Param:{}".format(i): v for i, v in enumerate(tree_flatten(samples)[0])
"Param:{}".format(i): v for i, v in enumerate(jax.tree.flatten(samples)[0])
}

summary_dict = {}
Expand Down Expand Up @@ -288,10 +288,10 @@ def print_summary(samples, prob=0.90, group_by_chain=True):
chain dimension).
"""
if not group_by_chain:
samples = tree_map(lambda x: x[None, ...], samples)
samples = jax.tree.map(lambda x: x[None, ...], samples)
if not isinstance(samples, dict):
samples = {
"Param:{}".format(i): v for i, v in enumerate(tree_flatten(samples)[0])
"Param:{}".format(i): v for i, v in enumerate(jax.tree.flatten(samples)[0])
}
summary_dict = summary(samples, prob, group_by_chain=True)

Expand Down
4 changes: 2 additions & 2 deletions numpyro/distributions/batch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from functools import singledispatch
from typing import Union

import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

from numpyro.distributions import constraints
from numpyro.distributions.conjugate import (
Expand Down Expand Up @@ -547,7 +547,7 @@ def _promote_batch_shape_expanded(d: ExpandedDistribution):
len(new_shapes_elems),
len(new_shapes_elems) + len(orig_delta_batch_shape),
)
new_base_dist = tree_map(
new_base_dist = jax.tree.map(
lambda x: jnp.expand_dims(x, axis=new_axes_locs), new_self.base_dist
)

Expand Down
Loading

0 comments on commit 5af9ebd

Please sign in to comment.