Skip to content

Commit

Permalink
Unify way to copy values from the root cfg
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 553160113
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Aug 2, 2023
1 parent ce08fc1 commit 3507407
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 60 deletions.
1 change: 0 additions & 1 deletion kauldron/projects/examples/mnist_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def get_config():

# Checkpointer
cfg.checkpointer = kd.train.Checkpointer(
workdir=cfg.ref.workdir,
save_interval_steps=500,
)

Expand Down
1 change: 0 additions & 1 deletion kauldron/projects/examples/tiny_vit_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def get_config():

# Checkpointer
cfg.checkpointer = kd.train.Checkpointer(
workdir=cfg.ref.workdir,
save_interval_steps=1000,
max_to_keep=1,
)
Expand Down
1 change: 0 additions & 1 deletion kauldron/projects/examples/vit_b_i1k.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def get_config():

# Checkpointer
cfg.checkpointer = kd.train.Checkpointer(
workdir=cfg.ref.workdir,
save_interval_steps=1000,
max_to_keep=1,
)
Expand Down
16 changes: 9 additions & 7 deletions kauldron/train/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from __future__ import annotations

import abc
# import abc
import dataclasses
import datetime
import functools
Expand All @@ -25,15 +25,17 @@
from etils import epath
from flax.training import orbax_utils
import jax
from kauldron.utils import config_util
import orbax.checkpoint as orbax

_T = TypeVar("_T")


class BaseCheckpointer(abc.ABC):
# TODO(epot): Why does `abc` fail?
class BaseCheckpointer(config_util.UpdateFromRootCfg): # , abc.ABC):
"""Basic checkpointing interface."""

@abc.abstractmethod
# @abc.abstractmethod
def restore(
self,
initial_state: _T,
Expand All @@ -43,11 +45,11 @@ def restore(
) -> _T:
raise NotImplementedError()

@abc.abstractmethod
# @abc.abstractmethod
def should_save(self, step: int) -> bool:
raise NotImplementedError()

@abc.abstractmethod
# @abc.abstractmethod
def save_state(
self,
state,
Expand All @@ -57,7 +59,7 @@ def save_state(
) -> bool:
raise NotImplementedError()

@abc.abstractmethod
# @abc.abstractmethod
def maybe_save_state(
self,
state,
Expand All @@ -79,7 +81,7 @@ def all_steps(self) -> Sequence[int]:
@dataclasses.dataclass(frozen=True, eq=True, kw_only=True)
class Checkpointer(BaseCheckpointer):
"""Basic Orbax Checkpointmanager."""
workdir: str | epath.Path
workdir: str | epath.Path = config_util.ROOT_CFG_REF.workdir

save_interval_steps: int
max_to_keep: Optional[int] = 3
Expand Down
16 changes: 7 additions & 9 deletions kauldron/train/config_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,12 @@ class Config(config_util.BaseConfig):
)

def __post_init__(self):
# Eventually propagate the seed from the root config
# Set rngs before eval, as the rngs is used in eval.
if self.rng_streams.seed is None:
object.__setattr__(
self,
'rng_streams',
dataclasses.replace(self.rng_streams, seed=self.seed),
)
# Some config object values are lazy-initialized from the root config.
# See `UpdateFromRootCfg` for details
object.__setattr__(
self, 'eval', dataclasses.replace(self.eval, base_cfg=self)
self, 'rng_streams', self.rng_streams.update_from_root_cfg(self)
)
object.__setattr__(self, 'eval', self.eval.update_from_root_cfg(self))
object.__setattr__(
self, 'checkpointer', self.checkpointer.update_from_root_cfg(self)
)
61 changes: 23 additions & 38 deletions kauldron/train/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

from __future__ import annotations

import abc
import dataclasses
# import abc
import functools
import itertools
from typing import Any, Optional, TypeVar
Expand All @@ -42,7 +41,9 @@
_REUSE_TRAIN: Any = object()


class EvaluatorBase(config_util.BaseConfig, abc.ABC):
class EvaluatorBase(
config_util.BaseConfig, config_util.UpdateFromRootCfg # , abc.ABC
):
"""Evaluator interface.
Usage:
Expand All @@ -60,14 +61,12 @@ class EvaluatorBase(config_util.BaseConfig, abc.ABC):
base_cfg: Train config from which model, checkpoint,... are reused
"""

base_cfg: config_lib.Config

@abc.abstractmethod
# @abc.abstractmethod
def maybe_eval(self, *, step: int, state: train_step.TrainState):
"""Eventually evaluate the train state."""
raise NotImplementedError

@abc.abstractmethod
# @abc.abstractmethod
def flatten(self) -> list[EvaluatorBase]:
"""Iterate over the evaluator nodes."""
raise NotImplementedError
Expand Down Expand Up @@ -105,29 +104,17 @@ class SingleEvaluator(EvaluatorBase):
run_every: int
num_batches: Optional[int]
ds: data.TFDataPipeline
losses: dict[str, losses_lib.Loss] = _REUSE_TRAIN
metrics: dict[str, metrics_lib.Metric] = _REUSE_TRAIN
summaries: dict[str, summaries_lib.Summary] = _REUSE_TRAIN
rng_streams: rngs_lib.RngStreams = _REUSE_TRAIN

def __post_init__(self):
# Eventually copy the metrics from the train config
if hasattr(self, 'base_cfg'):
for name, base_name in [
('losses', 'train_losses'),
('metrics', 'train_metrics'),
('summaries', 'train_summaries'),
('rng_streams', 'rng_streams'),
]:
# TODO(klausg): filter out metrics / summaries that access grads/updates
if getattr(self, name) is _REUSE_TRAIN:
object.__setattr__(self, name, getattr(self.base_cfg, base_name))
if self.base_cfg.rng_streams is not self.rng_streams:
raise ValueError(
'RngStreams should be the same in eval / train. To use a stream in'
' eval, set `eval=True` in the `RngStream`.\n'
f'Got: {self.base_cfg.rng_streams} != {self.rng_streams}'
)
losses: dict[str, losses_lib.Loss] = config_util.ROOT_CFG_REF.train_losses
metrics: dict[str, metrics_lib.Metric] = (
config_util.ROOT_CFG_REF.train_metrics
)
summaries: dict[str, summaries_lib.Summary] = (
config_util.ROOT_CFG_REF.train_summaries
)

base_cfg: config_lib.Config = config_util.ROOT_CFG_REF

# TODO(klausg): filter out metrics / summaries that access grads/updates

def maybe_eval(
self, *, step: int, state: train_step.TrainState
Expand All @@ -154,7 +141,7 @@ def eval_step(
eval_step = flax.jax_utils.replicate(eval_step)
aux = _pstep(
self.model_with_aux,
self.rng_streams,
self.base_cfg.rng_streams,
eval_step,
state,
batch,
Expand Down Expand Up @@ -253,13 +240,11 @@ class MultiEvaluator(EvaluatorBase):

children: list[EvaluatorBase]

def __post_init__(self):
# Propagate the cfg to the children
if hasattr(self, 'base_cfg'):
children = [
dataclasses.replace(c, base_cfg=self.base_cfg) for c in self.children
]
object.__setattr__(self, 'children', children)
def update_from_root_cfg(self: _SelfT, root_cfg: config_lib.Config) -> _SelfT:
"""See base class."""
return self.replace(
children=[c.update_from_root_cfg(root_cfg) for c in self.children]
)

def maybe_eval(self, *, step: int, state: train_step.TrainState):
for evaluator in self.children:
Expand Down
5 changes: 3 additions & 2 deletions kauldron/train/rngs_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import jax
from kauldron import random as kd_random
from kauldron.utils import config_util

Rngs = dict[str, kd_random.PRNGKey]

Expand Down Expand Up @@ -110,7 +111,7 @@ def _assert_is_not_none(self, val, name: str) -> None:


@dataclasses.dataclass(frozen=True, eq=True)
class RngStreams:
class RngStreams(config_util.UpdateFromRootCfg):
"""Manager of rng streams.
Generate the `rngs` dict to pass to `model.init` / `model.apply`.
Expand All @@ -130,7 +131,7 @@ class RngStreams:
)

_: dataclasses.KW_ONLY
seed: int | None = None
seed: int = config_util.ROOT_CFG_REF.seed

@functools.cached_property
def streams(self) -> dict[str, RngStream]:
Expand Down
103 changes: 102 additions & 1 deletion kauldron/utils/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@

"""Utils for dataclasses."""

from __future__ import annotations

import dataclasses
import typing
from typing import Any, TypeVar

from etils import edc
from kauldron import konfig

if typing.TYPE_CHECKING:
from kauldron.train import config_lib

_SelfT = TypeVar('_SelfT')


Expand Down Expand Up @@ -80,7 +85,7 @@ def _repr_html_(self) -> str:
return konfig.ConfigDict(self._field_values)._repr_html_() # pylint: disable=protected-access

def replace(self: _SelfT, **changes: Any) -> _SelfT:
return type(self)(**self._field_values | changes)
return type(self)(**self._field_values | changes) # pylint: disable=protected-access

if typing.TYPE_CHECKING:

Expand All @@ -94,3 +99,99 @@ def _field_values(self) -> dict[str, Any]:
if hasattr(self, f.name):
new_values[f.name] = getattr(self, f.name)
return new_values


@dataclasses.dataclass(frozen=True)
class _FakeRootCfg:
"""Fake root config reference object.
See `UpdateFromRootCfg` for usage.
If the field is not set, the value will be copied from the root
`kd.train.Config` object, after it is created.
"""

parent: _FakeRootCfg | None = None
name: str = 'cfg'

def __getattr__(self, name: str) -> Any:
return _FakeRootCfg(parent=self, name=name)

@classmethod
def make_fake_cfg(cls) -> config_lib.Config:
return cls() # pytype: disable=bad-return-type

@property
def names(self) -> tuple[str, ...]:
names = []
curr = self
while curr is not None:
names.append(curr.name)
curr = curr.parent
return tuple(reversed(names))

def __repr__(self) -> str:
qualname = '.'.join(self.names)
return f'{type(self).__name__}({qualname!r})'

def __set_name__(self, owner, name):
if not issubclass(owner, UpdateFromRootCfg):
raise TypeError(
'`ROOT_CFG_REF` can only be assigned on subclasses of'
f' `UpdateFromRootCfg`.\nFor: {owner.__name__}.{name} = {self}'
)


ROOT_CFG_REF: config_lib.Config = _FakeRootCfg.make_fake_cfg()


@dataclasses.dataclass(frozen=True, eq=True, kw_only=True)
class UpdateFromRootCfg:
"""Allow child object to be updated with values from the base config.
For example:
* `Checkpointer` reuse the `workdir` from the base config.
* `Evaluator`, `RgnStreams` reuse the `seed` from the base config.
To use, either:
* Set your dataclass fields to `ROOT_CFG_REF.xxx` to specify the fields should
be copied from the base config.
* Overwrite the `update_from_root_cfg` method, for a custom initialization.
When using, make sure to also update the `kd.train.Config.__post_init__` to
call
`update_from_root_cfg`. Currently this not done automatically.
Example:
```python
@dataclasses.dataclass
class MyConfig:
workdir: epath.Path = ROOT_CFG_REF.workdir
```
Attributes:
_REUSE_FROM_ROOT_CFG: Mapping <root_cfg attribute> to <self attribute>
"""

def update_from_root_cfg(self: _SelfT, root_cfg: config_lib.Config) -> _SelfT:
"""Returns a copy of `self`, potentially with updated values."""
fields_to_replace = {}
for f in dataclasses.fields(self):
default = f.default
if not isinstance(default, _FakeRootCfg):
continue
value = getattr(self, f.name)
if not isinstance(value, _FakeRootCfg):
continue
# value is a fake cfg, should be update
new_value = root_cfg
for attr in value.names[1:]:
new_value = getattr(root_cfg, attr)
fields_to_replace[f.name] = new_value
if not fields_to_replace:
return self
else:
return dataclasses.replace(self, **fields_to_replace)

0 comments on commit 3507407

Please sign in to comment.