Skip to content

Commit

Permalink
Remove ModelWithAux in trainstep
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 699160009
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Nov 23, 2024
1 parent 69e940d commit 3e10889
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 138 deletions.
1 change: 0 additions & 1 deletion kauldron/evals/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ def evaluate(
step=step,
aux=merged_aux,
schedules={},
model_with_aux=self.model_with_aux,
log_summaries=True,
)
return merged_aux
Expand Down
3 changes: 2 additions & 1 deletion kauldron/train/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from kauldron.train.setup_utils import Setup
from kauldron.train.setup_utils import TqdmInfo
from kauldron.train.train_step import Auxiliaries
from kauldron.train.train_step import ModelWithAux
from kauldron.train.train_step import forward
from kauldron.train.train_step import forward_with_loss
from kauldron.train.train_step import TrainState
from kauldron.train.train_step import TrainStep
from kauldron.train.trainer_lib import Trainer
Expand Down
12 changes: 1 addition & 11 deletions kauldron/train/metric_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def write_step_metrics(
*,
step: int,
aux: train_step.Auxiliaries,
model_with_aux: train_step.ModelWithAux,
schedules: Mapping[str, optax.Schedule],
log_summaries: bool,
timer: Optional[chrono_utils.Chrono] = None,
Expand Down Expand Up @@ -201,15 +200,7 @@ def write_step_metrics(

if log_summaries:
with jax.transfer_guard("allow"):
# TODO(klausg): remove once all summaries are migrated to new protocol
# image summaries
image_summaries_old = {
name: summary.get_images(**aux.summary_kwargs[name])
for name, summary in model_with_aux.summaries.items()
if isinstance(summary, summaries.ImageSummary)
}

image_summaries = image_summaries_old | {
image_summaries = {
name: value
for name, value in aux_result.summary_values.items()
if isinstance(value, Float["n h w #3"])
Expand Down Expand Up @@ -586,7 +577,6 @@ def write_step_metrics(
*,
step: int,
aux: train_step.Auxiliaries,
model_with_aux: train_step.ModelWithAux,
schedules: Mapping[str, optax.Schedule],
log_summaries: bool,
timer: Optional[chrono_utils.Chrono] = None,
Expand Down
1 change: 0 additions & 1 deletion kauldron/train/train_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ def train_impl(
step=i,
aux=aux,
schedules=trainer.schedules,
model_with_aux=trainstep.model_with_aux,
timer=chrono,
log_summaries=log_summaries,
)
Expand Down
Loading

0 comments on commit 3e10889

Please sign in to comment.