From 3507407e634d8aeeb87c785f8e305f0a68c52a1a Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Wed, 2 Aug 2023 08:56:32 -0700 Subject: [PATCH] Unify way to copy values from the root cfg PiperOrigin-RevId: 553160113 --- .../projects/examples/mnist_autoencoder.py | 1 - .../projects/examples/tiny_vit_imagenet.py | 1 - kauldron/projects/examples/vit_b_i1k.py | 1 - kauldron/train/checkpointer.py | 16 +-- kauldron/train/config_lib.py | 16 ++- kauldron/train/evaluators.py | 61 ++++------- kauldron/train/rngs_lib.py | 5 +- kauldron/utils/config_util.py | 103 +++++++++++++++++- 8 files changed, 144 insertions(+), 60 deletions(-) diff --git a/kauldron/projects/examples/mnist_autoencoder.py b/kauldron/projects/examples/mnist_autoencoder.py index a81f3908..59dd900f 100644 --- a/kauldron/projects/examples/mnist_autoencoder.py +++ b/kauldron/projects/examples/mnist_autoencoder.py @@ -86,7 +86,6 @@ def get_config(): # Checkpointer cfg.checkpointer = kd.train.Checkpointer( - workdir=cfg.ref.workdir, save_interval_steps=500, ) diff --git a/kauldron/projects/examples/tiny_vit_imagenet.py b/kauldron/projects/examples/tiny_vit_imagenet.py index 0a7ed944..a6a5419f 100644 --- a/kauldron/projects/examples/tiny_vit_imagenet.py +++ b/kauldron/projects/examples/tiny_vit_imagenet.py @@ -89,7 +89,6 @@ def get_config(): # Checkpointer cfg.checkpointer = kd.train.Checkpointer( - workdir=cfg.ref.workdir, save_interval_steps=1000, max_to_keep=1, ) diff --git a/kauldron/projects/examples/vit_b_i1k.py b/kauldron/projects/examples/vit_b_i1k.py index 17b8efdc..a28fc34f 100644 --- a/kauldron/projects/examples/vit_b_i1k.py +++ b/kauldron/projects/examples/vit_b_i1k.py @@ -111,7 +111,6 @@ def get_config(): # Checkpointer cfg.checkpointer = kd.train.Checkpointer( - workdir=cfg.ref.workdir, save_interval_steps=1000, max_to_keep=1, ) diff --git a/kauldron/train/checkpointer.py b/kauldron/train/checkpointer.py index cc6b6603..61232261 100644 --- a/kauldron/train/checkpointer.py +++ b/kauldron/train/checkpointer.py @@ -16,7 +16,7 @@ from __future__ import annotations -import abc +# import abc import dataclasses import datetime import functools @@ -25,15 +25,17 @@ from etils import epath from flax.training import orbax_utils import jax +from kauldron.utils import config_util import orbax.checkpoint as orbax _T = TypeVar("_T") -class BaseCheckpointer(abc.ABC): +# TODO(epot): Why `abc` fail ? +class BaseCheckpointer(config_util.UpdateFromRootCfg): # , abc.ABC): """Basic checkpointing interface.""" - @abc.abstractmethod + # @abc.abstractmethod def restore( self, initial_state: _T, @@ -43,11 +45,11 @@ def restore( ) -> _T: raise NotImplementedError() - @abc.abstractmethod + # @abc.abstractmethod def should_save(self, step: int) -> bool: raise NotImplementedError() - @abc.abstractmethod + # @abc.abstractmethod def save_state( self, state, @@ -57,7 +59,7 @@ def save_state( ) -> bool: raise NotImplementedError() - @abc.abstractmethod + # @abc.abstractmethod def maybe_save_state( self, state, @@ -79,7 +81,7 @@ def all_steps(self) -> Sequence[int]: @dataclasses.dataclass(frozen=True, eq=True, kw_only=True) class Checkpointer(BaseCheckpointer): """Basic Orbax Checkpointmanager.""" - workdir: str | epath.Path + workdir: str | epath.Path = config_util.ROOT_CFG_REF.workdir save_interval_steps: int max_to_keep: Optional[int] = 3 diff --git a/kauldron/train/config_lib.py b/kauldron/train/config_lib.py index b2a4e3ab..617e5159 100644 --- a/kauldron/train/config_lib.py +++ b/kauldron/train/config_lib.py @@ -105,14 +105,12 @@ class Config(config_util.BaseConfig): ) def __post_init__(self): - # Eventually propagate the seed from the root config - # Set rngs before eval, as the rngs is used in eval. - if self.rng_streams.seed is None: - object.__setattr__( - self, - 'rng_streams', - dataclasses.replace(self.rng_streams, seed=self.seed), - ) + # Some config object values are lazy-initialized from the root config. + # See `UpdateFromRootCfg` for details object.__setattr__( - self, 'eval', dataclasses.replace(self.eval, base_cfg=self) + self, 'rng_streams', self.rng_streams.update_from_root_cfg(self) + ) + object.__setattr__(self, 'eval', self.eval.update_from_root_cfg(self)) + object.__setattr__( + self, 'checkpointer', self.checkpointer.update_from_root_cfg(self) ) diff --git a/kauldron/train/evaluators.py b/kauldron/train/evaluators.py index a67fca88..6ee7ac1b 100644 --- a/kauldron/train/evaluators.py +++ b/kauldron/train/evaluators.py @@ -16,8 +16,7 @@ from __future__ import annotations -import abc -import dataclasses +# import abc import functools import itertools from typing import Any, Optional, TypeVar @@ -42,7 +41,9 @@ _REUSE_TRAIN: Any = object() -class EvaluatorBase(config_util.BaseConfig, abc.ABC): +class EvaluatorBase( + config_util.BaseConfig, config_util.UpdateFromRootCfg # , abc.ABC +): """Evaluator interface. Usage: @@ -60,14 +61,12 @@ class EvaluatorBase(config_util.BaseConfig, abc.ABC): base_cfg: Train config from which model, checkpoint,... are reused """ - base_cfg: config_lib.Config - - @abc.abstractmethod + # @abc.abstractmethod def maybe_eval(self, *, step: int, state: train_step.TrainState): """Eventually evaluate the train state.""" raise NotImplementedError - @abc.abstractmethod + # @abc.abstractmethod def flatten(self) -> list[EvaluatorBase]: """Iterate over the evaluator nodes.""" raise NotImplementedError @@ -105,29 +104,17 @@ class SingleEvaluator(EvaluatorBase): run_every: int num_batches: Optional[int] ds: data.TFDataPipeline - losses: dict[str, losses_lib.Loss] = _REUSE_TRAIN - metrics: dict[str, metrics_lib.Metric] = _REUSE_TRAIN - summaries: dict[str, summaries_lib.Summary] = _REUSE_TRAIN - rng_streams: rngs_lib.RngStreams = _REUSE_TRAIN - - def __post_init__(self): - # Eventually copy the metrics from the train config - if hasattr(self, 'base_cfg'): - for name, base_name in [ - ('losses', 'train_losses'), - ('metrics', 'train_metrics'), - ('summaries', 'train_summaries'), - ('rng_streams', 'rng_streams'), - ]: - # TODO(klausg): filter out metrics / summaries that access grads/updates - if getattr(self, name) is _REUSE_TRAIN: - object.__setattr__(self, name, getattr(self.base_cfg, base_name)) - if self.base_cfg.rng_streams is not self.rng_streams: - raise ValueError( - 'RngStreams should be the same in eval / train. To use a stream in' - ' eval, set `eval=True` in the `RngStream`.\n' - f'Got: {self.base_cfg.rng_streams} != {self.rng_streams}' - ) + losses: dict[str, losses_lib.Loss] = config_util.ROOT_CFG_REF.train_losses + metrics: dict[str, metrics_lib.Metric] = ( + config_util.ROOT_CFG_REF.train_metrics + ) + summaries: dict[str, summaries_lib.Summary] = ( + config_util.ROOT_CFG_REF.train_summaries + ) + + base_cfg: config_lib.Config = config_util.ROOT_CFG_REF + + # TODO(klausg): filter out metrics / summaries that access grads/updates def maybe_eval( self, *, step: int, state: train_step.TrainState @@ -154,7 +141,7 @@ def eval_step( eval_step = flax.jax_utils.replicate(eval_step) aux = _pstep( self.model_with_aux, - self.rng_streams, + self.base_cfg.rng_streams, eval_step, state, batch, @@ -253,13 +240,11 @@ class MultiEvaluator(EvaluatorBase): children: list[EvaluatorBase] - def __post_init__(self): - # Propagate the cfg to the children - if hasattr(self, 'base_cfg'): - children = [ - dataclasses.replace(c, base_cfg=self.base_cfg) for c in self.children - ] - object.__setattr__(self, 'children', children) + def update_from_root_cfg(self: _SelfT, root_cfg: config_lib.Config) -> _SelfT: + """See base class.""" + return self.replace( + children=[c.update_from_root_cfg(root_cfg) for c in self.children] + ) def maybe_eval(self, *, step: int, state: train_step.TrainState): for evaluator in self.children: diff --git a/kauldron/train/rngs_lib.py b/kauldron/train/rngs_lib.py index 99c48054..cf7ff066 100644 --- a/kauldron/train/rngs_lib.py +++ b/kauldron/train/rngs_lib.py @@ -22,6 +22,7 @@ import jax from kauldron import random as kd_random +from kauldron.utils import config_util Rngs = dict[str, kd_random.PRNGKey] @@ -110,7 +111,7 @@ def _assert_is_not_none(self, val, name: str) -> None: @dataclasses.dataclass(frozen=True, eq=True) -class RngStreams: +class RngStreams(config_util.UpdateFromRootCfg): """Manager of rng streams. Generate the `rngs` dict to pass to `model.init` / `model.apply`. @@ -130,7 +131,7 @@ class RngStreams: ) _: dataclasses.KW_ONLY - seed: int | None = None + seed: int = config_util.ROOT_CFG_REF.seed @functools.cached_property def streams(self) -> dict[str, RngStream]: diff --git a/kauldron/utils/config_util.py b/kauldron/utils/config_util.py index 8ee3cf67..42982f12 100644 --- a/kauldron/utils/config_util.py +++ b/kauldron/utils/config_util.py @@ -14,6 +14,8 @@ """Utils for dataclasses.""" +from __future__ import annotations + import dataclasses import typing from typing import Any, TypeVar @@ -21,6 +23,9 @@ from etils import edc from kauldron import konfig +if typing.TYPE_CHECKING: + from kauldron.train import config_lib + _SelfT = TypeVar('_SelfT') @@ -80,7 +85,7 @@ def _repr_html_(self) -> str: return konfig.ConfigDict(self._field_values)._repr_html_() # pylint: disable=protected-access def replace(self: _SelfT, **changes: Any) -> _SelfT: - return type(self)(**self._field_values | changes) + return type(self)(**self._field_values | changes) # pylint: disable=protected-access if typing.TYPE_CHECKING: @@ -94,3 +99,99 @@ def _field_values(self) -> dict[str, Any]: if hasattr(self, f.name): new_values[f.name] = getattr(self, f.name) return new_values + + +@dataclasses.dataclass(frozen=True) +class _FakeRootCfg: + """Fake root config reference object. + + See `UpdateFromRootCfg` for usage. + + If the field is not set, the value will be copied from the root + `kd.train.Config` object, after it is created. + """ + + parent: _FakeRootCfg | None = None + name: str = 'cfg' + + def __getattr__(self, name: str) -> Any: + return _FakeRootCfg(parent=self, name=name) + + @classmethod + def make_fake_cfg(cls) -> config_lib.Config: + return cls() # pytype: disable=bad-return-type + + @property + def names(self) -> tuple[str, ...]: + names = [] + curr = self + while curr is not None: + names.append(curr.name) + curr = curr.parent + return tuple(reversed(names)) + + def __repr__(self) -> str: + qualname = '.'.join(self.names) + return f'{type(self).__name__}({qualname!r})' + + def __set_name__(self, owner, name): + if not issubclass(owner, UpdateFromRootCfg): + raise TypeError( + '`ROOT_CFG_REF` can only be assigned on subclasses of' + f' `UpdateFromRootCfg`.\nFor: {owner.__name__}.{name} = {self}' + ) + + +ROOT_CFG_REF: config_lib.Config = _FakeRootCfg.make_fake_cfg() + + +@dataclasses.dataclass(frozen=True, eq=True, kw_only=True) +class UpdateFromRootCfg: + """Allow child object to be updated with values from the base config. + + For example: + + * `Checkpointer` reuse the `workdir` from the base config. + * `Evaluator`, `RgnStreams` reuse the `seed` from the base config. + + To use, either: + + * Set your dataclass fields to `ROOT_CFG_REF.xxx` to specify the fields should + be copied from the base config. + * Overwrite the `update_from_root_cfg` method, for a custom initialization. + + When using, make sure to also update the `kd.train.Config.__post_init__` to + call + `update_from_root_cfg`. Currently this not done automatically. + + Example: + + ```python + @dataclasses.dataclass + class MyConfig: + workdir: epath.Path = ROOT_CFG_REF.workdir + ``` + + Attributes: + _REUSE_FROM_ROOT_CFG: Mapping to + """ + + def update_from_root_cfg(self: _SelfT, root_cfg: config_lib.Config) -> _SelfT: + """Returns a copy of `self`, potentially with updated values.""" + fields_to_replace = {} + for f in dataclasses.fields(self): + default = f.default + if not isinstance(default, _FakeRootCfg): + continue + value = getattr(self, f.name) + if not isinstance(value, _FakeRootCfg): + continue + # value is a fake cfg, should be update + new_value = root_cfg + for attr in value.names[1:]: + new_value = getattr(root_cfg, attr) + fields_to_replace[f.name] = new_value + if not fields_to_replace: + return self + else: + return dataclasses.replace(self, **fields_to_replace)