Skip to content

Commit

Permalink
Replace deprecated jax.tree_* functions with jax.tree.*
Browse files Browse the repository at this point in the history
The top-level `jax.tree_*` aliases have long been deprecated and will soon be removed. The canonical replacements live in `jax.tree_util`, with shorter aliases available in the `jax.tree` submodule (added in JAX version 0.4.25).

PiperOrigin-RevId: 633778219
Change-Id: I8942b020baf4ab93fae0ab3eb47a9fddb0f245ac
  • Loading branch information
Jake VanderPlas authored and copybara-github committed May 15, 2024
1 parent 020facd commit 8c167c3
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion saxml/server/jax/servable_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def _create_aval(x):
return jax.ShapeDtypeStruct(x.shape, dtype)

compiled = step_fn.lower(
jax.tree_map(_create_aval, train_state.mdl_vars),
jax.tree.map(_create_aval, train_state.mdl_vars),
inputs_shape_dtype,
).compile(compiler_options=compiler_options)
return compiled
Expand Down
2 changes: 1 addition & 1 deletion saxml/server/pax/lm/servable_lm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def _pad_fn(x):
padded = jnp.pad(x, paddings)
return padded

return jax.tree_map(_pad_fn, result)
return jax.tree.map(_pad_fn, result)

def post_process_branch_outputs(
self, outputs: NestedJTensor, branch_key: int
Expand Down
18 changes: 9 additions & 9 deletions saxml/server/pax/servable_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def jax_func(
) -> NestedJTensor:
if self._model.fprop_dtype == jnp.bfloat16:
# Convert float inputs/vars if fprop dtype is bfloat16.
batched_inputs, mdl_vars = jax.tree_map(
batched_inputs, mdl_vars = jax.tree.map(
(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x),
(batched_inputs, mdl_vars),
)
Expand Down Expand Up @@ -294,7 +294,7 @@ def maybe_to_float32(x):
return x.astype(jnp.float32)
return x

outputs = jax.tree_map(maybe_to_float32, outputs)
outputs = jax.tree.map(maybe_to_float32, outputs)
return self.fetch_output(outputs, batched_inputs)

def unload(self) -> None:
Expand Down Expand Up @@ -484,7 +484,7 @@ def maybe_to_bfloat16_dtype(x):
return jax.ShapeDtypeStruct(x.shape, jnp.bfloat16)
return x

train_state_global_shapes = jax.tree_map(
train_state_global_shapes = jax.tree.map(
maybe_to_bfloat16_dtype, train_state_global_shapes
)

Expand Down Expand Up @@ -538,7 +538,7 @@ def maybe_to_bfloat16_dtype(x):
vars_weight_params, discard_opt_states=discard_opt_states
)
).mdl_vars
mdl_var_unpadded_shapes = jax.tree_map(
mdl_var_unpadded_shapes = jax.tree.map(
lambda x: x.shape, mdl_var_unpadded_shapes
)

Expand Down Expand Up @@ -567,7 +567,7 @@ def quant_pspec_fn(mdl_vars_to_quant, prng_keys):
or model_p.dtype == jnp.bfloat16
):
# Convert float inputs/vars if fprop dtype is bfloat16.
mdl_vars_to_quant = jax.tree_map(convert, mdl_vars_to_quant)
mdl_vars_to_quant = jax.tree.map(convert, mdl_vars_to_quant)
k1, k2, prng_keys = jax.random.split(prng_keys, num=3)
return jax_task.model.apply(
mdl_vars_to_quant,
Expand All @@ -586,7 +586,7 @@ def quant_pspec_fn(mdl_vars_to_quant, prng_keys):
)
new_pspec, _ = pjit_quant_pspec_fn(mdl_vars, prng_key)
# pylint: disable=g-long-lambda
new_pspec = jax.tree_map(
new_pspec = jax.tree.map(
lambda x: x.meta
if isinstance(x, base_layer.BoxedPartitionSpec)
else x,
Expand All @@ -607,7 +607,7 @@ def quant_fn(mdl_vars_to_quant, prng_keys):
or model_p.dtype == jnp.bfloat16
):
# Convert float inputs/vars if fprop dtype is bfloat16.
mdl_vars_to_quant = jax.tree_map(convert, mdl_vars_to_quant)
mdl_vars_to_quant = jax.tree.map(convert, mdl_vars_to_quant)
k1, k2, prng_keys = jax.random.split(prng_keys, num=3)
return jax_task.model.apply(
mdl_vars_to_quant,
Expand All @@ -631,8 +631,8 @@ def quant_fn(mdl_vars_to_quant, prng_keys):
model = new_jax_task.model
task_p = new_task_p
# TODO(jianlijianli): Get unpadded_shapes properly.
mdl_var_unpadded_shapes = jax.tree_map(lambda x: x.shape, mdl_vars)
mdl_var_unpadded_types = jax.tree_map(lambda x: x.dtype, mdl_vars)
mdl_var_unpadded_shapes = jax.tree.map(lambda x: x.shape, mdl_vars)
mdl_var_unpadded_types = jax.tree.map(lambda x: x.dtype, mdl_vars)
logging.info('quantized vars pspec %s', new_pspec)
logging.info('quantized vars shapes %s', mdl_var_unpadded_shapes)
logging.info('quantized vars types %s', mdl_var_unpadded_types)
Expand Down

0 comments on commit 8c167c3

Please sign in to comment.