From ce08fc1e47e08d846245d5117197d28b0cf5732e Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Wed, 2 Aug 2023 07:06:20 -0700 Subject: [PATCH] rename pmap `axis_name='device'` PiperOrigin-RevId: 553134090 --- kauldron/train/evaluators.py | 4 ++-- kauldron/train/rngs_lib.py | 6 +++--- kauldron/train/train_step.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/kauldron/train/evaluators.py b/kauldron/train/evaluators.py index e68173bb..a67fca88 100644 --- a/kauldron/train/evaluators.py +++ b/kauldron/train/evaluators.py @@ -206,7 +206,7 @@ def flatten(self) -> list[EvaluatorBase]: @functools.partial( jax.pmap, - axis_name='batch', + axis_name='device', static_broadcasted_argnums=(0, 1), ) def _pstep( @@ -221,7 +221,7 @@ def _pstep( params=state.params, batch=batch, rngs=rng_streams.eval_rngs( - eval_step, device_id=jax.lax.axis_index('batch') + eval_step, device_id=jax.lax.axis_index('device') ), step=state.step, # Step is train step, NOT eval is_training=False, diff --git a/kauldron/train/rngs_lib.py b/kauldron/train/rngs_lib.py index efb85f57..99c48054 100644 --- a/kauldron/train/rngs_lib.py +++ b/kauldron/train/rngs_lib.py @@ -68,7 +68,7 @@ def make( rng: The root rng, common to all processes step: Current model step device_id: Indicate be the device / axis id inside `pmap` (e.g. - `jax.lax.axis_index('batch')`) + `jax.lax.axis_index('device')`) key: Additional string (e.g. `train`, `init`,...) to fold in Returns: @@ -175,7 +175,7 @@ def train_rngs(self, step: int, *, device_id: int) -> Rngs: Args: step: Current train/eval step device_id: Indicate be the device / axis id inside `pmap` (e.g. - `jax.lax.axis_index('batch')`) + `jax.lax.axis_index('device')`) Returns: rngs: The `dict[, kd.random.PRNGKey]` @@ -198,7 +198,7 @@ def eval_rngs(self, step: int, *, device_id: int) -> Rngs: Args: step: Current train/eval step device_id: Indicate be the device / axis id inside `pmap` (e.g. - `jax.lax.axis_index('batch')`) + `jax.lax.axis_index('device')`) Returns: rngs: The `dict[, kd.random.PRNGKey]` diff --git a/kauldron/train/train_step.py b/kauldron/train/train_step.py index c58d8a45..71335f57 100644 --- a/kauldron/train/train_step.py +++ b/kauldron/train/train_step.py @@ -276,7 +276,7 @@ def step( @functools.partial( jax.pmap, - axis_name="batch", + axis_name="device", static_broadcasted_argnums=(0, 3, 4, 5), donate_argnums=(1,), ) @@ -296,12 +296,12 @@ def _step( state.params, batch=batch, rngs=self.rng_streams.train_rngs( - state.step, device_id=jax.lax.axis_index("batch") + state.step, device_id=jax.lax.axis_index("device") ), step=state.step, is_training=True, ) - grads = jax.lax.pmean(grads, axis_name="batch") + grads = jax.lax.pmean(grads, axis_name="device") updates, new_opt_state = self.optimizer.update( grads, state.opt_state, state.params )