From 65b1c7e4eb870af81ba6488b05088a726a8dcd02 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Wed, 18 Dec 2024 10:19:18 -0500 Subject: [PATCH] [nnx] jit cache --- benchmarks/nnx_simple_training.py | 63 +++++--- flax/nnx/graph.py | 248 +++++++++++++++++++++++++++-- flax/nnx/nn/stochastic.py | 3 + flax/nnx/rnglib.py | 6 +- flax/nnx/transforms/compilation.py | 33 +++- tests/nnx/graph_utils_test.py | 39 +++++ tests/nnx/transforms_test.py | 16 +- uv.lock | 86 +++++----- 8 files changed, 406 insertions(+), 88 deletions(-) diff --git a/benchmarks/nnx_simple_training.py b/benchmarks/nnx_simple_training.py index 0cb08066fe..cde86c4e96 100644 --- a/benchmarks/nnx_simple_training.py +++ b/benchmarks/nnx_simple_training.py @@ -25,7 +25,9 @@ from absl import app FLAGS = flags.FLAGS -flags.DEFINE_enum('mode', 'nnx', ['nnx', 'jax'], 'Mode to run the script in') +flags.DEFINE_enum( + 'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in' +) flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps') flags.DEFINE_integer('batch_size', 32, 'Batch size') flags.DEFINE_integer('width', 32, 'Hidden layer size') @@ -46,6 +48,13 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): def __call__(self, x): return x @ self.w + self.b +class Block(nnx.Module): + def __init__(self, din, dhidden, *, rngs: nnx.Rngs): + self.linear = Linear(din, dhidden, rngs=rngs) + self.bn = nnx.BatchNorm(dhidden, rngs=rngs) + + def __call__(self, x): + return nnx.relu(self.bn(self.linear(x))) class Count(nnx.Variable): pass @@ -54,11 +63,11 @@ class Count(nnx.Variable): class MLP(nnx.Module): def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs): self.count = Count(jnp.array(0)) - self.linear_in = Linear(din, dhidden, rngs=rngs) + self.linear_in = Block(din, dhidden, rngs=rngs) self.intermediates = [ - Linear(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2) + Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2) ] - self.linear_out = Linear(dhidden, dout, rngs=rngs) + self.linear_out = Block(dhidden, dout, rngs=rngs) def __call__(self, x): self.count.value += 1 @@ -79,18 +88,14 @@ def main(argv): print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}') - if mode not in ['nnx', 'jax']: - raise ValueError(f'Invalid mode: {mode}') - X = np.linspace(0, 1, 100)[:, None] Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) - model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) - tx = optax.sgd(1e-3) - optimizer = nnx.Optimizer(model, tx) - t0 = time() - - if mode == 'nnx': + if mode == 'nnx' or mode == 'all': + model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer = nnx.Optimizer(model, tx) + t0 = time() @nnx.jit def train_step_nnx(model: MLP, optimizer: nnx.Optimizer, batch): @@ -110,16 +115,30 @@ def test_step_nnx(model: MLP, batch): loss = jnp.mean((y - y_pred) ** 2) return {'loss': loss} + print('### NNX ###') for step, batch in enumerate(dataset(X, Y, batch_size)): train_step_nnx(model, optimizer, batch) if step % 1000 == 0: logs = test_step_nnx(model, (X, Y)) - print(f"step: {step}, loss: {logs['loss']}") if step >= total_steps - 1: break - else: + + total_time = time() - t0 + time_per_step = total_time / total_steps + time_per_layer = time_per_step / depth + + print(f"step: {step}, loss: {logs['loss']}") + print('total time:', total_time) + print(f'time per step: {time_per_step * 1e6:.2f} µs') + print(f'time per layer: {time_per_layer * 1e6:.2f} µs') + + if mode == 'jax' or mode == 'all': + model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer = nnx.Optimizer(model, tx) + t0 = time() @jax.jit def train_step_jax(graphdef, state, batch): @@ -146,22 +165,26 @@ def test_step_jax(graphdef, state, batch): graphdef, state = nnx.split((model, optimizer)) + print('### JAX ###') for step, batch in enumerate(dataset(X, Y, batch_size)): state = train_step_jax(graphdef, state, batch) if step % 1000 == 0: state, logs = test_step_jax(graphdef, state, (X, Y)) - print(f"step: {step}, loss: {logs['loss']}") if step >= total_steps - 1: break model, optimizer = nnx.merge(graphdef, state) - total_time = time() - t0 - print('total time:', total_time) - print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') - print('times called:', model.count.value) + total_time = time() - t0 + time_per_step = total_time / total_steps + time_per_layer = time_per_step / depth + + print(f"step: {step}, loss: {logs['loss']}") + print('total time:', total_time) + print(f'time per step: {time_per_step * 1e6:.2f} µs') + print(f'time per layer: {time_per_layer * 1e6:.2f} µs') if __name__ == '__main__': diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index a29999d34f..0b97743f75 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -19,6 +19,7 @@ import functools import threading import typing as tp +from weakref import WeakKeyDictionary import jax import numpy as np @@ -50,7 +51,7 @@ Leaf = tp.TypeVar('Leaf') AuxData = tp.TypeVar('AuxData') -StateLeaf = VariableState[tp.Any] +StateLeaf = tp.Union[VariableState[tp.Any], Variable[tp.Any]] NodeLeaf = Variable[tp.Any] GraphState = State[Key, StateLeaf] @@ -72,6 +73,9 @@ def __init__( self._mapping: dict[int, tuple[A, B]] = {} self.update(mapping) + def copy(self) -> RefMap[A, B]: + return RefMap(self) + def __getitem__(self, key: A) -> B: return self._mapping[id(key)][1] @@ -388,7 +392,11 @@ def _apply( def flatten( - node: Node, /, ref_index: RefMap[tp.Any, Index] | None = None + node: Node, + /, + *, + ref_index: RefMap[tp.Any, Index] | None = None, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> tuple[GraphDef[Node], GraphState]: """Flattens a graph node into a (graphdef, state) pair. @@ -400,15 +408,76 @@ def flatten( """ if ref_index is None: ref_index = RefMap() - flat_state: list[tuple[PathParts, StateLeaf]] = [] - graphdef = _graph_flatten((), ref_index, flat_state, node) - return graphdef, GraphState.from_flat_path(flat_state) + + if node in ref_index: + return NodeRef(type(node), ref_index[node]), State({}) + + # main flatten function + def do_flatten(*, return_variables: bool): + assert ref_index is not None + flat_state: list[tuple[PathParts, StateLeaf]] = [] + graphdef = _graph_flatten((), ref_index, flat_state, return_variables, node) + return graphdef, GraphState.from_flat_path(flat_state) + + # cache logic + if cache_context is None: + graphdef, state = do_flatten(return_variables=False) + else: # cache_context is not None + if node in cache_context: + node_cache = cache_context[node] + cache_fp = node_cache.fingerprint + prev_ref_index = ref_index.copy() + node_fp = fingerprint(node, ref_index=ref_index) + if cache_fp == node_fp: + graphdef = node_cache.graphdef + state = jax.tree.map( + lambda v: v.to_state(), + node_cache.variables, + is_leaf=lambda x: isinstance(x, Variable), + ) + else: # cache_fp != current_fp: + index_ref_diff = { + index: ref + for ref, index in ref_index.items() + if ref not in prev_ref_index + } + ref_index = prev_ref_index # reset ref_index before calling do_flatten + graphdef, variables = do_flatten(return_variables=True) + state = jax.tree.map( + lambda v: v.to_state(), + variables, + is_leaf=lambda x: isinstance(x, Variable), + ) + cache_context[node] = CacheContext( + node_fp, graphdef, variables, index_ref_diff + ) + else: # node not in cache_context + prev_ref_index = ref_index.copy() + node_fp = fingerprint(node, ref_index=ref_index) + index_ref_diff = { + index: ref + for ref, index in ref_index.items() + if ref not in prev_ref_index + } + ref_index = prev_ref_index # reset ref_index before calling do_flatten + graphdef, variables = do_flatten(return_variables=True) + state = jax.tree.map( + lambda v: v.to_state(), + variables, + is_leaf=lambda x: isinstance(x, Variable), + ) + cache_context[node] = CacheContext( + node_fp, graphdef, variables, index_ref_diff + ) + + return graphdef, state def _graph_flatten( path: PathParts, ref_index: RefMap[tp.Any, Index], flat_state: list[tuple[PathParts, StateLeaf]], + return_variable: bool, node: Node, ) -> NodeDef[Node] | NodeRef: if not is_node(node): @@ -431,7 +500,9 @@ def _graph_flatten( values, metadata = node_impl.flatten(node) for key, value in values: if is_node(value): - nodedef = _graph_flatten((*path, key), ref_index, flat_state, value) + nodedef = _graph_flatten( + (*path, key), ref_index, flat_state, return_variable, value + ) # subgraphs.append((key, nodedef)) attributes.append(SubGraphAttribute(key, nodedef)) elif isinstance(value, Variable): @@ -440,7 +511,8 @@ def _graph_flatten( LeafAttribute(key, NodeRef(type(value), ref_index[value])) ) else: - flat_state.append(((*path, key), value.to_state())) + state_leaf = value if return_variable else value.to_state() + flat_state.append(((*path, key), state_leaf)) variable_index = ref_index[value] = len(ref_index) variabledef = VariableDef( type(value), variable_index, HashableMapping(value._var_metadata) @@ -464,6 +536,70 @@ def _graph_flatten( ) return nodedef +def fingerprint( + node: Node, /, *, ref_index: RefMap[tp.Any, Index] | None = None +) -> tuple[tp.Any, ...]: + """ """ + if ref_index is None: + ref_index = RefMap() + fp = _graph_fingerprint(node, ref_index) + return fp + + +def _graph_fingerprint( + node, + ref_index: RefMap[tp.Any, Index], +) -> tuple[tp.Any, ...]: + if not is_node(node): + raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') + + if node in ref_index: + return (type(node), ref_index[node]) + + node_impl = get_node_impl(node) + + # only cache graph nodes + if isinstance(node_impl, GraphNodeImpl): + index = len(ref_index) + ref_index[node] = index + else: + index = -1 + + attributes: list[tuple[tp.Any, ...]] = [] + + values, metadata = node_impl.flatten(node) + for key, value in values: + if is_node(value): + node_fp = _graph_fingerprint(value, ref_index) + # subgraphs.append((key, nodedef)) + attributes.append((key, node_fp)) + elif isinstance(value, Variable): + if value in ref_index: + attributes.append((key, (type(value), ref_index[value]))) + else: + variable_index = ref_index[value] = len(ref_index) + # the fingerprint must be sensitive to Variable identity + variable_fp = ( + id(value), + type(value), + variable_index, + tuple(value._var_metadata.items()), + ) + attributes.append((key, variable_fp)) + else: + if isinstance(value, (jax.Array, np.ndarray)): + raise ValueError(f'Arrays leaves are not supported: {value}') + # static_fields.append((key, value)) + attributes.append((key, value)) + + node_fp = ( + node_impl.type, + index, + tuple(attributes), + metadata, + ) + return node_fp + def unflatten( graphdef: GraphDef[Node], @@ -472,6 +608,7 @@ def unflatten( *, index_ref: dict[Index, tp.Any] | None = None, index_ref_cache: dict[Index, tp.Any] | None = None, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> Node: """Unflattens a graphdef into a node with the given state. @@ -488,14 +625,60 @@ def unflatten( existing graph nodes are mutated to have the new content/topology specified by the graphdef. """ - if isinstance(state, State): - state = state.raw_mapping # type: ignore + if index_ref is None: index_ref = {} - assert isinstance(graphdef, (NodeDef, NodeRef)) - node = _graph_unflatten(graphdef, state, index_ref, index_ref_cache) + + if isinstance(graphdef, NodeRef): + return index_ref[graphdef.index] + + assert isinstance(graphdef, NodeDef) + + def do_unflatten(): + _state: tp.Mapping[KeyT, tp.Any] = state + if isinstance(_state, State): + _state = _state.raw_mapping # type: ignore + node = _graph_unflatten(graphdef, _state, index_ref, index_ref_cache) + return node + + if cache_context is None: + node = do_unflatten() + else: # cache_context is not None + if index_ref_cache is None: + raise ValueError( + 'index_ref_cache must be provided when cache_context is used.' + ) + if graphdef.index in index_ref_cache: + node = index_ref_cache[graphdef.index] + if node in cache_context: + # node is in cache_context, retrieve its cache + cache = cache_context[node] + assert graphdef.index_mapping is not None + + # check if the graphdef is the same + graphdef_fp = dataclasses.replace(graphdef, index_mapping=None) + if cache.graphdef == graphdef_fp and all( + a == b for a, b in graphdef.index_mapping.items() + ): + # graphdefs match, update variables from state + def _update_variables(variable: Variable, state: VariableState): + variable.raw_value = state.value + + jax.tree.map(_update_variables, cache.variables, state) + index_ref.update(cache.index_ref_diff) + else: # cache.graphdef != graphdef_fp + # graphdef changed, re-create the node + node = do_unflatten() + else: # node not in cache_context + # all nodes in index_ref_cache must be in cache_context + raise RuntimeError(f'Node not found in cache_context, node: {node}') + else: # graphdef.index not in index_ref_cache + # its a new node, create it + node = do_unflatten() + return node + def _graph_unflatten( nodedef: NodeDef[Node] | NodeRef[Node], state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]], @@ -773,6 +956,13 @@ def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]): # UpdateContext # -------------------------------------------------------- +class CacheContext(tp.NamedTuple): + fingerprint: tuple + graphdef: GraphDef + variables: State + index_ref_diff: dict[Index, tp.Any] + + @dataclasses.dataclass class GraphContext(threading.local): update_context_stacks: dict[str, list[UpdateContext]] = dataclasses.field( @@ -780,6 +970,9 @@ class GraphContext(threading.local): ) ref_index_stack: list[SplitContext] = dataclasses.field(default_factory=list) index_ref_stack: list[MergeContext] = dataclasses.field(default_factory=list) + cache_context: WeakKeyDictionary[ + tp.Callable, WeakKeyDictionary[tp.Any, CacheContext] + ] = dataclasses.field(default_factory=WeakKeyDictionary) GRAPH_CONTEXT = GraphContext() @@ -791,10 +984,19 @@ class SplitContext: ref_index: RefMap[tp.Any, Index] @tp.overload - def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ... + def split( + self, + graph_node: A, + /, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, + ) -> tuple[GraphDef[A], GraphState]: ... @tp.overload def split( - self, graph_node: A, first: filterlib.Filter, / + self, + graph_node: A, + first: filterlib.Filter, + /, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> tuple[GraphDef[A], GraphState]: ... @tp.overload def split( @@ -804,14 +1006,20 @@ def split( second: filterlib.Filter, /, *filters: filterlib.Filter, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ... def split( - self, node: A, *filters: filterlib.Filter + self, + node: A, + *filters: filterlib.Filter, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> tuple[GraphDef[A], tpe.Unpack[tuple[GraphState, ...]]]: ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None ) - graphdef, state = flatten(node, self.ref_index) + graphdef, state = flatten( + node, ref_index=self.ref_index, cache_context=cache_context + ) states = _split_state(state, filters) if ctx is not None: if ctx.index_ref is not None and isinstance(graphdef, NodeDef): @@ -846,7 +1054,12 @@ class MergeContext: index_ref: dict[Index, tp.Any] def merge( - self, graphdef: GraphDef[A], state: GraphState, /, *states: GraphState + self, + graphdef: GraphDef[A], + state: GraphState, + /, + *states: GraphState, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> A: ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None @@ -871,6 +1084,7 @@ def merge( state, index_ref=self.index_ref, index_ref_cache=index_ref_cache, + cache_context=cache_context, ) return node @@ -1001,7 +1215,7 @@ def split( filters are passed, a single :class:`State` is returned. """ ref_index: RefMap[tp.Any, Index] = RefMap() - graphdef, state = flatten(node, ref_index) + graphdef, state = flatten(node, ref_index=ref_index) states = _split_state(state, filters) if self.index_ref is not None and isinstance(graphdef, NodeDef): diff --git a/flax/nnx/nn/stochastic.py b/flax/nnx/nn/stochastic.py index 2a495826a4..737c6e3102 100644 --- a/flax/nnx/nn/stochastic.py +++ b/flax/nnx/nn/stochastic.py @@ -125,3 +125,6 @@ def __call__( mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) mask = jnp.broadcast_to(mask, inputs.shape) return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) + + def __hash__(self): + return id(self) diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py index 17bbaf37c8..bc4b551972 100644 --- a/flax/nnx/rnglib.py +++ b/flax/nnx/rnglib.py @@ -80,7 +80,7 @@ def __call__(self) -> jax.Array: ] -class Rngs(Object, tp.Mapping[str, tp.Callable[[], jax.Array]]): +class Rngs(Object): """NNX rng container class. To instantiate the ``Rngs``, pass in an integer, specifying the starting seed. ``Rngs`` can have different "streams", allowing the user to generate different @@ -237,6 +237,10 @@ def __getstate__(self): def __setstate__(self, state): vars(self).update(state) + def items(self): + for name in self: + yield name, self[name] + class ForkStates(tp.NamedTuple): split_keys: State diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index e5ce20f8e3..aa5e7773af 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -16,6 +16,7 @@ import dataclasses import functools import typing as tp +from weakref import WeakKeyDictionary from flax.nnx import ( extract, @@ -88,7 +89,7 @@ def __hash__(self): return hash((self.filters, self.shardings)) -def _jit_split_fn(ctx: graph.SplitContext, path, prefix, x): +def _inner_jit_split_fn(ctx: graph.SplitContext, path, prefix, x): if isinstance(prefix, StateSharding): return extract.NodeStates.from_split( *ctx.split(x, *prefix.filters), metadata=prefix @@ -116,7 +117,7 @@ def __call__(self, *pure_args, **pure_kwargs): (args_out, kwargs_out, out), prefix=(self.in_shardings, self.kwarg_shardings, self.out_shardings), ctxtag='jit', - split_fn=_jit_split_fn, + split_fn=_inner_jit_split_fn, ) return pure_args_out, pure_kwargs_out, pure_out @@ -335,10 +336,32 @@ def jit( @functools.wraps(fun) @graph.update_context('jit') def jit_wrapper(*args, **kwargs): + if jit_wrapper not in graph.GRAPH_CONTEXT.cache_context: + graph.GRAPH_CONTEXT.cache_context[jit_wrapper] = WeakKeyDictionary() + jit_cache = graph.GRAPH_CONTEXT.cache_context[jit_wrapper] + + def _outer_jit_split_fn(ctx: graph.SplitContext, path, prefix, x): + if isinstance(prefix, StateSharding): + return extract.NodeStates.from_split( + *ctx.split(x, *prefix.filters, cache_context=jit_cache), + metadata=prefix, + ) + return extract.NodeStates.from_split( + *ctx.split(x, cache_context=jit_cache) + ) + + def _outer_jit_merge_fn( + ctx: graph.MergeContext, path, prefix, leaf + ) -> tp.Any: + if not isinstance(leaf, extract.NodeStates): + raise ValueError( + f'Expected NodeStates, got {type(leaf)} at path {path}' + ) + return ctx.merge(leaf.graphdef, *leaf.states, cache_context=jit_cache) pure_args, pure_kwargs = extract.to_tree( (args, kwargs), prefix=(in_shardings, kwarg_shardings), - split_fn=_jit_split_fn, + split_fn=_outer_jit_split_fn, check_aliasing=in_shardings is not None, ctxtag='jit', ) @@ -346,7 +369,9 @@ def jit_wrapper(*args, **kwargs): *pure_args, **pure_kwargs ) _args_out, _kwargs_out, out = extract.from_tree( - (pure_args_out, pure_kwargs_out, pure_out), ctxtag='jit' + (pure_args_out, pure_kwargs_out, pure_out), + merge_fn=_outer_jit_merge_fn, + ctxtag='jit', ) return out diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index a7bbf178cb..c566494662 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -793,6 +793,45 @@ def f(*pure_args): self.assertIs(m1, args_out[2]['b']) self.assertIs(m2, args_out[1]) + def test_fingerprint_basic(self): + m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + fp1 = nnx.graph.fingerprint(m) + m1_hash = hash(fp1) + self.assertIsInstance(m1_hash, int) + + fp2 = nnx.graph.fingerprint(m) + m2_hash = hash(fp2) + + self.assertEqual(fp1, fp2) + self.assertEqual(m1_hash, m2_hash) + + def test_fingerprint_variable_id_sensitive(self): + m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + fp1 = nnx.graph.fingerprint(m1) + m1_hash = hash(fp1) + + m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + fp2 = nnx.graph.fingerprint(m2) + m2_hash = hash(fp2) + + self.assertNotEqual(fp1, fp2) + self.assertNotEqual(m1_hash, m2_hash) + + def test_fingerprint_module_id_insensitive(self): + m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + + m1.kernel = m2.kernel + m1.bias = m2.bias + + fp1 = nnx.graph.fingerprint(m1) + m1_hash = hash(fp1) + fp2 = nnx.graph.fingerprint(m2) + m2_hash = hash(fp2) + + self.assertEqual(fp1, fp2) + self.assertEqual(m1_hash, m2_hash) + class SimpleModule(nnx.Module): pass diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 736da9acf0..d1aed76546 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -50,7 +50,7 @@ def __setitem__(self, key, value): if tp.TYPE_CHECKING: - def __getattr__(self, key): ... + def __getattr__(self, key) -> tp.Any: ... class TestJIT(absltest.TestCase): @@ -59,12 +59,16 @@ def test_jit(self): @nnx.jit def g(m: Dict): - m.a = 2 + m.a.value += 1 return 1.0 out = g(m) - assert m.a == 2 + assert m.a.value == 2 + assert out == 1.0 + + out = g(m) + assert m.a.value == 3 assert out == 1.0 def test_jit_on_init(self): @@ -715,6 +719,9 @@ class Foo(nnx.Module): y: nnx.Param[jax.Array] z: int + def __hash__(self) -> int: + return id(self) + @nnx.custom_vjp @nnx.remat def f(m: Foo): @@ -3081,6 +3088,9 @@ def test_basic(self): class Foo(nnx.Module): a: nnx.Param + def __hash__(self): + return id(self) + @nnx.jit def f(m): y = jnp.sin(m.a.value) # error diff --git a/uv.lock b/uv.lock index e08e2dbf53..29b00dd3b3 100644 --- a/uv.lock +++ b/uv.lock @@ -3,13 +3,13 @@ requires-python = ">=3.10" resolution-markers = [ "python_full_version < '3.11' and platform_system == 'Darwin'", "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version == '3.11.*' and platform_system == 'Darwin'", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version >= '3.12' and platform_system == 'Darwin'", "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", ] [[package]] @@ -641,7 +641,7 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version < '3.11' and platform_system == 'Darwin'", "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/99/bc/cfb52b9e8531526604afe8666185d207e4f0cb9c6d90bc76f62fb8746804/etils-1.7.0.tar.gz", hash = "sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350", size = 95695 } wheels = [ @@ -676,10 +676,10 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version == '3.11.*' and platform_system == 'Darwin'", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version >= '3.12' and platform_system == 'Darwin'", "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/ba/49/d480aeb4fc441d933acce97261bea002234a45fb847599c9a93c31e51b2e/etils-1.9.2.tar.gz", hash = "sha256:15dcd35ac0c0cc2404b46ac0846af3cc4e876fd3d80f36f57951e27e8b9d6379", size = 101506 } wheels = [ @@ -1202,7 +1202,7 @@ name = "ipython" version = "8.26.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "decorator" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "jedi" }, @@ -1246,7 +1246,7 @@ wheels = [ [[package]] name = "jax" -version = "0.4.37" +version = "0.4.38" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxlib" }, @@ -1255,14 +1255,14 @@ dependencies = [ { name = "opt-einsum" }, { name = "scipy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/50/30/ad7617a960c86782587540a179cef676962322d1e5411415b1aa24f02ce0/jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b", size = 1915966 } +sdist = { url = "https://files.pythonhosted.org/packages/fb/e5/c4aa9644bb96b7f6747bd7c9f8cda7665ca5e194fa2542b2dea3ff730701/jax-0.4.38.tar.gz", hash = "sha256:43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8", size = 1930034 } wheels = [ - { url = "https://files.pythonhosted.org/packages/5f/3f/6c5553baaa7faa3fa8bae8279b1e46cb54c7ce52360139eae53498786ea5/jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e", size = 2221192 }, + { url = "https://files.pythonhosted.org/packages/22/49/b4418a7a892c0dd64442bbbeef54e1cdfe722dfc5a7bf0d611d3f5f90e99/jax-0.4.38-py3-none-any.whl", hash = "sha256:78987306f7041ea8500d99df1a17c33ed92620c2268c4c3677fb24e06712be64", size = 2236864 }, ] [[package]] name = "jaxlib" -version = "0.4.36" +version = "0.4.38" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "ml-dtypes" }, @@ -1270,26 +1270,26 @@ dependencies = [ { name = "scipy" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/23/8d/8a44618f3493f29d769b2b40778d24075689cc8697b98e2c43bafbe50edf/jaxlib-0.4.36-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:d69f991833b6dca794767049843462805936c89553b136a8ebb8485334204457", size = 98648230 }, - { url = "https://files.pythonhosted.org/packages/78/b8/207485eab566dcfbc29bb833714ac1ca47a1665ca605b1ff7d3d5dd2afbe/jaxlib-0.4.36-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:807814c1ba3ec69cffaa93d3f90651c694a9b8a750b43832cc167ed590c821dd", size = 78553787 }, - { url = "https://files.pythonhosted.org/packages/26/42/3c2b0dc86a17aafd8f46ba0e4388f39f55706ee25f6c463c3dadea7a71e2/jaxlib-0.4.36-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:1bc27d9ae09549d7652eafe1fdb10c21546cd2fd02bb24a49a7e6208b69163b0", size = 84008742 }, - { url = "https://files.pythonhosted.org/packages/b9/b2/29be712098342df10075fe085c0b39d783a579bd3325fb0d69c22712cf27/jaxlib-0.4.36-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:3379f03a794d6a30b75765d2786f6e31052f364196fcd49aaae292a3c16f12ec", size = 100263041 }, - { url = "https://files.pythonhosted.org/packages/63/a9/93404a2f1d59647749d4d6dbab7bee9f5a7bfaeb9ade25b7e66c0ca0949a/jaxlib-0.4.36-cp310-cp310-win_amd64.whl", hash = "sha256:63e575ac8a515dee8171dd4a88c460d538bbcc9d959cabc9781e961763678f84", size = 63270658 }, - { url = "https://files.pythonhosted.org/packages/e4/7d/9394ff39af5c23bb98a241c33742a328df5a43c21d569855ea7e096aaf5e/jaxlib-0.4.36-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:213792db3b876206b45f6a9fbea15e4dd22a9e80be25b03136f20c94784fecfa", size = 98669744 }, - { url = "https://files.pythonhosted.org/packages/34/5a/9f3c9e5cec23e60f78bb3c3da108a5ef664601862dbc4e84fc4be3654f5d/jaxlib-0.4.36-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6d7a89adf4c9d3cddd20482931dedc7a9e2669e904196a9599d9a605b3d9e552", size = 78574312 }, - { url = "https://files.pythonhosted.org/packages/ff/5c/bf78ed9b8d0f174a562f6496049a4872e14a3bb3a80de09c4292d04be5f0/jaxlib-0.4.36-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:c395fe8cc5bd6558dd2fbce78e24172b6f27762e17628720ae03d693001283f3", size = 84038323 }, - { url = "https://files.pythonhosted.org/packages/67/af/6a9dd26e8a6bedd4c9fe702059767256b0d9ed18c29a180a4598d5795bb4/jaxlib-0.4.36-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:bc324c6b1c64fe68400934c653e4e622f12576120dcdb451c3b4ea4dcaba2ae9", size = 100285487 }, - { url = "https://files.pythonhosted.org/packages/b7/46/31c3a519a94e84c672ca264c4151998e3e3fd11c481d8fa5af5885b91a1e/jaxlib-0.4.36-cp311-cp311-win_amd64.whl", hash = "sha256:c9e0c45a79e63aea65447f82bd0fa21c17b9afe884aa18dd5362b9965abe9d72", size = 63308064 }, - { url = "https://files.pythonhosted.org/packages/e3/0e/3b4a99c09431ee5820624d4dcf4efa7becd3c83b56ff0f09a078f4c421a2/jaxlib-0.4.36-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:5972aa85f6d771ecc8cc72148c1fa64250ca33cbdf2bf24407cdee8a5299d25d", size = 98718357 }, - { url = "https://files.pythonhosted.org/packages/d3/46/05e70a1236ec3782333b3e9469f971c9d45af2aa0aebf602acd9d76292eb/jaxlib-0.4.36-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5597908cd10418c0b42e9af807fc8112036703533cf501a5255a8fbf4011867e", size = 78596060 }, - { url = "https://files.pythonhosted.org/packages/8e/76/6b969cbf197b8c53c84c2642069722e84a3a260af084a8acbbf90ca444ea/jaxlib-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:fbbabaa287378a78a3cf9cbe4de30a1f6f19a99116feb4bd687ff256415cd442", size = 84053202 }, - { url = "https://files.pythonhosted.org/packages/fe/f2/7624a304426daa7b135b85caf1b8eccf879e7cb10bc074656ce628309cb0/jaxlib-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:be295abc209c980817db0488f21f1fbc0644f87326522895e2b9b64729106357", size = 100325610 }, - { url = "https://files.pythonhosted.org/packages/bb/8b/ded8420cd9198eb677869ffd557d9880af5833c7bf39e604e80b56550e09/jaxlib-0.4.36-cp312-cp312-win_amd64.whl", hash = "sha256:d4bbb5d2970628dcd3dabc28a5b97a1125ad3e06a1be822d340fd9f06f7449b3", size = 63338518 }, - { url = "https://files.pythonhosted.org/packages/5d/22/b72811c61e8b594951d3ee03245cb0932c723ac35e75569005c3c976eec2/jaxlib-0.4.36-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:02df9c0e1323dde01e966c22eb12432905d2d4de8aac7b603cad2083101b0e6b", size = 98719384 }, - { url = "https://files.pythonhosted.org/packages/f1/66/3f4a97097983914899100db9e5312493fe1d6adc924e47a0e47e15c553f5/jaxlib-0.4.36-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:16ec980e85983f41999c4dc84137dec70507d958e23d7eefa104da93053d135f", size = 78596150 }, - { url = "https://files.pythonhosted.org/packages/3a/6f/cf02f56d1532962d8ca77a6548acab8204294b96b5a153ca4a2caf4971fc/jaxlib-0.4.36-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7ce9368515348d869d6c59d9904c3cb3c81f22ff3e9e969eae0e3563fe472080", size = 84055851 }, - { url = "https://files.pythonhosted.org/packages/28/10/4fc4e9719c065c6455491730011e87fe4b5120a9a008161cc32663feb9ce/jaxlib-0.4.36-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:93f1c502d08e517f842fe7b18428bb086cfd077db0ea9a2418fb21e5b4e06d3d", size = 100325986 }, - { url = "https://files.pythonhosted.org/packages/ba/28/fece5385e736ef2f1b5bed133f8001f0fc66dd0104707381343e047b341a/jaxlib-0.4.36-cp313-cp313-win_amd64.whl", hash = "sha256:bddf436a243e83ec6bc16bcbb74d15b1960a69318c9ea796fb2109492bc52575", size = 63338694 }, + { url = "https://files.pythonhosted.org/packages/ee/d4/e6a0881a88b8f17491c2ee271fd77c348b0221d9e2ec92dad23a2c9e41bc/jaxlib-0.4.38-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:55c19b9d3f33a6fc59f644aa5a21fba02639ccdd776cb4a9b5526625f57839ff", size = 99663603 }, + { url = "https://files.pythonhosted.org/packages/b6/6d/11569ce873f04c82ec22e58d822f4187dccae1d400c0d6dd05ed314d5328/jaxlib-0.4.38-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:30b2f52cb50d74734af2f477c2533a7a583e3bb7b2c8acdeb361ee77d940577a", size = 79475708 }, + { url = "https://files.pythonhosted.org/packages/72/61/1de2405d13089c83b1ad87ec0266479c9d00080659dae2474892ae356306/jaxlib-0.4.38-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ee19c163a8fdf0839d4c18b88a5fbfb4e731ba7c437416d3e5483e570bb764e4", size = 93219045 }, + { url = "https://files.pythonhosted.org/packages/9c/24/0829decf233c6af9efe7c53888ae8ac72395e0979869cd9cee487e35dac3/jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:61aeccb9a27c67fdb8450f6357240019cd4511cb9d62a44e4764756d384853ad", size = 101732107 }, + { url = "https://files.pythonhosted.org/packages/0d/04/120c4caac6151f7297fedf9dd776362aa2d417d3f87bda826050b4da45e8/jaxlib-0.4.38-cp310-cp310-win_amd64.whl", hash = "sha256:d6ab745a89d0fb737a36fe1d8b86659e3fffe6ee8303b20651b26193d5edc0ef", size = 64223924 }, + { url = "https://files.pythonhosted.org/packages/b0/6a/b9fba73eb5e758e40a514919e096a039d27dc0ab4776a6cc977f5153a55f/jaxlib-0.4.38-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:b67fdeabd6dfed08b7768f3bdffb521160085f8305669bd197beef61d08de08b", size = 99679916 }, + { url = "https://files.pythonhosted.org/packages/44/2a/3458130d44d44038fd6974e7c43948f68408f685063203b82229b9b72c1a/jaxlib-0.4.38-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fb0eaae7369157afecbead50aaf29e73ffddfa77a2335d721bd9794f3c510e4", size = 79488377 }, + { url = "https://files.pythonhosted.org/packages/94/96/7d9a0b9f35af4727df44b68ade4c6f15163840727d1cb47251b1ea515e30/jaxlib-0.4.38-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:43db58c4c427627296366a56c10318e1f00f503690e17f94bb4344293e1995e0", size = 93241543 }, + { url = "https://files.pythonhosted.org/packages/a3/2d/68f85037e60c981b37b18b23ace458c677199dea4722ddce541b48ddfc63/jaxlib-0.4.38-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:2751ff7037d6a997d0be0e77cc4be381c5a9f9bb8b314edb755c13a6fd969f45", size = 101751923 }, + { url = "https://files.pythonhosted.org/packages/cc/24/a9c571c8a189f58e0b54b14d53fc7f5a0a06e4f1d7ab9edcf8d1d91d07e7/jaxlib-0.4.38-cp311-cp311-win_amd64.whl", hash = "sha256:35226968fc9de6873d1571670eac4117f5ed80e955f7a1775204d1044abe16c6", size = 64255189 }, + { url = "https://files.pythonhosted.org/packages/49/df/08b94c593c0867c7eaa334592807ba74495de4be90580f360db8b96221dc/jaxlib-0.4.38-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:3fefea985f0415816f3bbafd3f03a437050275ef9bac9a72c1314e1644ac57c1", size = 99737849 }, + { url = "https://files.pythonhosted.org/packages/ab/b1/c9d2a7ba9ebeabb7ac37082f4c466364f475dc7550a79358c0f0aa89fdf2/jaxlib-0.4.38-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f33bcafe32c97a562ecf6894d7c41674c80c0acdedfa5423d49af51147149874", size = 79509242 }, + { url = "https://files.pythonhosted.org/packages/53/25/dd670d8bdf3799ece76d12cfe6a6a250ea256057aa4b0fcace4753a99d2d/jaxlib-0.4.38-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:496f45b0e001a2341309cd0c74af0b670537dced79c168cb230cfcc773f0aa86", size = 93251503 }, + { url = "https://files.pythonhosted.org/packages/f9/cc/37fce5162f6b9070203fd76cc0f298d9b3bfdf01939a78935a6078d63621/jaxlib-0.4.38-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:dad6c0a96567c06d083c0469fec40f201210b099365bd698be31a6d2ec88fd59", size = 101792792 }, + { url = "https://files.pythonhosted.org/packages/6f/7a/8515950a60a4ea5b13cc98fc0a42e36553b2db5a6eedc00d3bd7836f77b5/jaxlib-0.4.38-cp312-cp312-win_amd64.whl", hash = "sha256:966cdec36cfa978f5b4582bcb4147fe511725b94c1a752dac3a5f52ce46b6fa3", size = 64288223 }, + { url = "https://files.pythonhosted.org/packages/91/03/aee503c7077c6dbbd568842303426c6ec1cef9bff330c418c9e71906cccd/jaxlib-0.4.38-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:41e55ae5818a882e5789e848f6f16687ac132bcfbb5a5fa114a5d18b78d05f2d", size = 99739026 }, + { url = "https://files.pythonhosted.org/packages/cb/bf/fbbf61da319611d88e11c691d5a2077039208ded05e1731dea940f824a59/jaxlib-0.4.38-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6fe326b8af366387dd47ccf312583b2b17fed12712c9b74a648b18a13cbdbabf", size = 79508735 }, + { url = "https://files.pythonhosted.org/packages/e4/0b/8cbff0b6d62a4694351c49baf53b7ed8deb8a6854d129408c38158e11676/jaxlib-0.4.38-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:248cca3771ebf24b070f49701364ceada33e6139445b06c782cca5ac5ad92bf4", size = 93251882 }, + { url = "https://files.pythonhosted.org/packages/15/57/7f0283273b69c417071bcd2f4c2ed076479ec5ffc22a647f13c21da8d071/jaxlib-0.4.38-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:2ce77ba8cda9259a4bca97afc1c722e4291a6c463a63f8d372c6edc85117d625", size = 101791137 }, + { url = "https://files.pythonhosted.org/packages/de/de/d6c4d234cd426b97459cb070af90792b48643967a0d28641379ee9e10fc9/jaxlib-0.4.38-cp313-cp313-win_amd64.whl", hash = "sha256:4103db0b3a38a5dc132741237453c24d8547290a22079ba1b577d6c88c95300a", size = 64288459 }, ] [[package]] @@ -1431,7 +1431,7 @@ version = "5.7.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "platformdirs" }, - { name = "pywin32", marker = "platform_python_implementation != 'PyPy' and sys_platform == 'win32'" }, + { name = "pywin32", marker = "(platform_machine != 'aarch64' and platform_python_implementation != 'PyPy' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_python_implementation != 'PyPy' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "traitlets" }, ] sdist = { url = "https://files.pythonhosted.org/packages/00/11/b56381fa6c3f4cc5d2cf54a7dbf98ad9aa0b339ef7a601d6053538b079a7/jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9", size = 87629 } @@ -2095,7 +2095,7 @@ name = "nvidia-cudnn-cu12" version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, @@ -2122,9 +2122,9 @@ name = "nvidia-cusolver-cu12" version = "11.4.5.107" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, - { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, @@ -2135,7 +2135,7 @@ name = "nvidia-cusparse-cu12" version = "12.1.0.106" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, @@ -2436,7 +2436,7 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version < '3.11' and platform_system == 'Darwin'", "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/55/5b/e3d951e34f8356e5feecacd12a8e3b258a1da6d9a03ad1770f28925f29bc/protobuf-3.20.3.tar.gz", hash = "sha256:2e3427429c9cffebf259491be0af70189607f365c2f41c7c3764af6f337105f2", size = 216768 } wheels = [ @@ -2454,10 +2454,10 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version == '3.11.*' and platform_system == 'Darwin'", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version >= '3.12' and platform_system == 'Darwin'", "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/e8/ab/cb61a4b87b2e7e6c312dce33602bd5884797fd054e0e53205f1c27cf0f66/protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d", size = 380283 } wheels = [ @@ -2606,7 +2606,7 @@ name = "pytest" version = "8.3.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "iniconfig" }, { name = "packaging" }, @@ -3195,7 +3195,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alabaster" }, { name = "babel" }, - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "docutils" }, { name = "imagesize" }, { name = "jinja2" }, @@ -3684,7 +3684,7 @@ name = "triton" version = "3.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 },