Skip to content

Commit

Permalink
Compat option for jax._src.core.canonicalize_shape (SSAGESLabs#340)
Browse files Browse the repository at this point in the history
This prevents `pysages` from being imported when `jax>=0.4.31`.
  • Loading branch information
pabloferz authored Sep 27, 2024
1 parent ed8d53b commit 00d4bca
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
14 changes: 7 additions & 7 deletions pysages/ml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from jax import numpy as np
from jax import random, vmap
from jax._src.nn import initializers
from jax.core import as_named_shape
from jax.numpy.linalg import norm
from jax.tree_util import PyTreeDef, tree_flatten
from numpy import cumsum
from plum import Dispatcher

from pysages.typing import NamedTuple
from pysages.utils import identity, prod
from pysages.utils.compat import canonicalize_shape

# Dispatcher for the `ml` submodule
dispatch = Dispatcher()
Expand Down Expand Up @@ -86,17 +86,17 @@ def uniform_scaling(
raise ValueError(f"invalid mode for variance scaling initializer: {mode}")

if bias_like:
trim_named_shape = idem(lambda named_shp, shp, axis: as_named_shape(shp[axis:]))
trim_shape = idem(lambda cshp, shp, axis: canonicalize_shape(shp[axis:]))
else:
trim_named_shape = idem(lambda named_shp, shp, axis: named_shp)
trim_shape = idem(lambda cshp, shp, axis: cshp)

def init(key, shape, dtype=dtype):
args_named_shape = as_named_shape(shape)
named_shape = trim_named_shape(args_named_shape, shape, out_axis)
canonical_shape = canonicalize_shape(shape)
shape = trim_shape(canonical_shape, shape, out_axis)
# pylint: disable-next=W0212
fan_in, fan_out = initializers._compute_fans(args_named_shape, in_axis, out_axis)
fan_in, fan_out = initializers._compute_fans(canonical_shape, in_axis, out_axis)
s = np.array(scale / denominator(fan_in, fan_out), dtype=dtype)
return random.uniform(key, named_shape, dtype, -1) * transform(s)
return random.uniform(key, shape, dtype, -1) * transform(s)

return init

Expand Down
16 changes: 16 additions & 0 deletions pysages/utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,22 @@ def prod(iterable, start=1):
return result


# Compatibility for jax >=0.4.31

# https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0431-july-29-2024
if _jax_version_tuple < (0, 4, 31):
_jax_core = import_module("jax.core")

def canonicalize_shape(shape):
return _jax_core.as_named_shape(shape)

else:
_jax_core = import_module("jax._src.core")

def canonicalize_shape(shape):
return _jax_core.canonicalize_shape(shape)


# Compatibility for jax >=0.4.22

# https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0422-dec-13-2023
Expand Down

0 comments on commit 00d4bca

Please sign in to comment.