Skip to content

Commit

Permalink
rename pmap axis_name='device'
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 553132626
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Aug 2, 2023
1 parent 6e08503 commit 6cefe31
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 17 deletions.
6 changes: 4 additions & 2 deletions kauldron/train/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def flatten(self) -> list[EvaluatorBase]:

@functools.partial(
jax.pmap,
axis_name='batch',
axis_name='device',
static_broadcasted_argnums=(0, 1),
)
def _pstep(
Expand All @@ -220,7 +220,9 @@ def _pstep(
_, ctx = model_with_aux.forward(
params=state.params,
batch=batch,
rngs=rng_streams.eval_rngs(eval_step),
rngs=rng_streams.eval_rngs(
eval_step, device_id=jax.lax.axis_index('device')
),
step=state.step, # Step is train step, NOT eval
is_training=False,
)
Expand Down
86 changes: 74 additions & 12 deletions kauldron/train/rngs_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class RngStream:
eval: Whether the rng is used in eval (`is_training=False`)
per_step: Whether the rng is different at each step
per_process: Whether the rng is different for each process
per_device: Whether the rng is different for each device (in pmap)
"""

name: str
Expand All @@ -50,21 +51,48 @@ class RngStream:
eval: bool = False

per_step: bool = True
per_device: bool = True
per_process: bool = True

def make(
    self,
    rng: kd_random.PRNGKey,
    *,
    step: int | None = None,
    device_id: int | None = None,
    key: str | None = None,
) -> kd_random.PRNGKey:
  """Create the `rng` from the global root rng.

  Args:
    rng: The root rng, common to all processes.
    step: Current model step (required when `per_step=True`).
    device_id: The device / axis id inside `pmap` (e.g.
      `jax.lax.axis_index('device')`); required when `per_device=True`.
    key: Additional string (e.g. `train`, `init`, ...) to fold in.

  Returns:
    The new rng, deterministically derived from the root rng by folding
    in the stream name and the per-step / per-process / per-device ids.
  """
  # The stream name is always folded in, so distinct streams never share
  # a key even when all other ids match.
  rng = rng.fold_in(self.name)
  if self.per_step:
    self._assert_is_not_none(step, 'step')
    rng = rng.fold_in(step)
  if self.per_process:
    rng = rng.fold_in(jax.process_index())
  if self.per_device:
    self._assert_is_not_none(device_id, 'device_id')
    rng = rng.fold_in(device_id)
  if key is not None:  # Additional key to fold (e.g. `train`, `eval`)
    rng = rng.fold_in(key)
  return rng

def _assert_is_not_none(self, val, name: str) -> None:
if val is None:
raise ValueError(
f'Missing kwargs `{name}` to generate rng stream: {self}'
)


_DEFAULT_STREAMS = [
RngStream(
Expand All @@ -73,6 +101,7 @@ def make(
train=False,
eval=False,
per_step=False,
per_device=False,
per_process=False,
),
RngStream('dropout'),
Expand Down Expand Up @@ -128,26 +157,59 @@ def root_rng(self) -> kd_random.PRNGKey:
@_jit_method
def init_rngs(self) -> Rngs:
  """Rngs for `model.init()`.

  Returns:
    rngs: The `dict[<stream name>, kd.random.PRNGKey]`, one entry per
      stream declared with `init=True`.
  """
  return {  # pylint: disable=g-complex-comprehension
      r.name: r.make(
          self.root_rng,
          step=0,
          # Assumes `model.init` is run outside `pmap`, so a fixed
          # device_id=0 is used — TODO confirm this is safe.
          device_id=0,
          key='init',
      )
      for r in self.streams.values()
      if r.init
  }

@_jit_method
def train_rngs(self, step: int, *, device_id: int) -> Rngs:
  """Rngs for `model.apply(..., is_training_property=True)`.

  Args:
    step: Current train step.
    device_id: The device / axis id inside `pmap` (e.g.
      `jax.lax.axis_index('device')`).

  Returns:
    rngs: The `dict[<stream name>, kd.random.PRNGKey]`, one entry per
      stream declared with `train=True`.
  """
  return {  # pylint: disable=g-complex-comprehension
      r.name: r.make(
          self.root_rng,
          step=step,
          device_id=device_id,
          key='train',
      )
      for r in self.streams.values()
      if r.train
  }

@_jit_method
def eval_rngs(self, step: int, *, device_id: int) -> Rngs:
  """Rngs for `model.apply(..., is_training_property=False)`.

  Args:
    step: Current train step (eval reuses the train step counter).
    device_id: The device / axis id inside `pmap` (e.g.
      `jax.lax.axis_index('device')`).

  Returns:
    rngs: The `dict[<stream name>, kd.random.PRNGKey]`, one entry per
      stream declared with `eval=True`.
  """
  return {  # pylint: disable=g-complex-comprehension
      r.name: r.make(
          self.root_rng,
          step=step,
          device_id=device_id,
          key='eval',
      )
      for r in self.streams.values()
      if r.eval
  }
8 changes: 5 additions & 3 deletions kauldron/train/train_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def step(

@functools.partial(
jax.pmap,
axis_name="batch",
axis_name="device",
static_broadcasted_argnums=(0, 3, 4, 5),
donate_argnums=(1,),
)
Expand All @@ -295,11 +295,13 @@ def _step(
grads, context = grad_fn(
state.params,
batch=batch,
rngs=self.rng_streams.train_rngs(state.step),
rngs=self.rng_streams.train_rngs(
state.step, device_id=jax.lax.axis_index("device")
),
step=state.step,
is_training=True,
)
grads = jax.lax.pmean(grads, axis_name="batch")
grads = jax.lax.pmean(grads, axis_name="device")
updates, new_opt_state = self.optimizer.update(
grads, state.opt_state, state.params
)
Expand Down

0 comments on commit 6cefe31

Please sign in to comment.