Skip to content

Commit

Permalink
rename pmap axis_name='device'
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 553132626
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Aug 2, 2023
1 parent cdf91ee commit 8d46173
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions kauldron/train/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def flatten(self) -> list[EvaluatorBase]:

@functools.partial(
jax.pmap,
axis_name='batch',
axis_name='device',
static_broadcasted_argnums=(0, 1),
)
def _pstep(
Expand All @@ -221,7 +221,7 @@ def _pstep(
params=state.params,
batch=batch,
rngs=rng_streams.eval_rngs(
eval_step, device_id=jax.lax.axis_index('batch')
eval_step, device_id=jax.lax.axis_index('device')
),
step=state.step, # Step is train step, NOT eval
is_training=False,
Expand Down
6 changes: 3 additions & 3 deletions kauldron/train/rngs_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def make(
rng: The root rng, common to all processes
step: Current model step
device_id: Indicate be the device / axis id inside `pmap` (e.g.
`jax.lax.axis_index('batch')`)
`jax.lax.axis_index('device')`)
key: Additional string (e.g. `train`, `init`,...) to fold in
Returns:
Expand Down Expand Up @@ -175,7 +175,7 @@ def train_rngs(self, step: int, *, device_id: int) -> Rngs:
Args:
step: Current train/eval step
device_id: Indicate be the device / axis id inside `pmap` (e.g.
`jax.lax.axis_index('batch')`)
`jax.lax.axis_index('device')`)
Returns:
rngs: The `dict[<stream name>, kd.random.PRNGKey]`
Expand All @@ -198,7 +198,7 @@ def eval_rngs(self, step: int, *, device_id: int) -> Rngs:
Args:
step: Current train/eval step
device_id: Indicate be the device / axis id inside `pmap` (e.g.
`jax.lax.axis_index('batch')`)
`jax.lax.axis_index('device')`)
Returns:
rngs: The `dict[<stream name>, kd.random.PRNGKey]`
Expand Down
6 changes: 3 additions & 3 deletions kauldron/train/train_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def step(

@functools.partial(
jax.pmap,
axis_name="batch",
axis_name="device",
static_broadcasted_argnums=(0, 3, 4, 5),
donate_argnums=(1,),
)
Expand All @@ -296,12 +296,12 @@ def _step(
state.params,
batch=batch,
rngs=self.rng_streams.train_rngs(
state.step, device_id=jax.lax.axis_index("batch")
state.step, device_id=jax.lax.axis_index("device")
),
step=state.step,
is_training=True,
)
grads = jax.lax.pmean(grads, axis_name="batch")
grads = jax.lax.pmean(grads, axis_name="device")
updates, new_opt_state = self.optimizer.update(
grads, state.opt_state, state.params
)
Expand Down

0 comments on commit 8d46173

Please sign in to comment.