# Compatibility for jax >=0.4.31

# https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0431-july-29-2024
if _jax_version_tuple >= (0, 4, 31):
    # From 0.4.31 onwards the helper is looked up in the private
    # `jax._src.core` module.
    _jax_core = import_module("jax._src.core")

    def canonicalize_shape(shape):
        """Return `shape` as processed by `jax._src.core.canonicalize_shape`."""
        canonical = _jax_core.canonicalize_shape(shape)
        return canonical

else:
    # Older releases: delegate to `jax.core.as_named_shape` instead.
    # NOTE(review): presumably `as_named_shape` is unavailable from 0.4.31 on
    # (see the changelog linked above) -- confirm before changing the cutoff.
    _jax_core = import_module("jax.core")

    def canonicalize_shape(shape):
        """Return `shape` as processed by `jax.core.as_named_shape`."""
        named = _jax_core.as_named_shape(shape)
        return named