Skip to content

Commit

Permalink
Remove model_with_aux from write metrics
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 699157175
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Nov 22, 2024
1 parent 69e940d commit a4b0c35
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 52 deletions.
1 change: 0 additions & 1 deletion kauldron/evals/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ def evaluate(
step=step,
aux=merged_aux,
schedules={},
model_with_aux=self.model_with_aux,
log_summaries=True,
)
return merged_aux
Expand Down
12 changes: 1 addition & 11 deletions kauldron/train/metric_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def write_step_metrics(
*,
step: int,
aux: train_step.Auxiliaries,
model_with_aux: train_step.ModelWithAux,
schedules: Mapping[str, optax.Schedule],
log_summaries: bool,
timer: Optional[chrono_utils.Chrono] = None,
Expand Down Expand Up @@ -201,15 +200,7 @@ def write_step_metrics(

if log_summaries:
with jax.transfer_guard("allow"):
# TODO(klausg): remove once all summaries are migrated to new protocol
# image summaries
image_summaries_old = {
name: summary.get_images(**aux.summary_kwargs[name])
for name, summary in model_with_aux.summaries.items()
if isinstance(summary, summaries.ImageSummary)
}

image_summaries = image_summaries_old | {
image_summaries = {
name: value
for name, value in aux_result.summary_values.items()
if isinstance(value, Float["n h w #3"])
Expand Down Expand Up @@ -586,7 +577,6 @@ def write_step_metrics(
*,
step: int,
aux: train_step.Auxiliaries,
model_with_aux: train_step.ModelWithAux,
schedules: Mapping[str, optax.Schedule],
log_summaries: bool,
timer: Optional[chrono_utils.Chrono] = None,
Expand Down
1 change: 0 additions & 1 deletion kauldron/train/train_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ def train_impl(
step=i,
aux=aux,
schedules=trainer.schedules,
model_with_aux=trainstep.model_with_aux,
timer=chrono,
log_summaries=log_summaries,
)
Expand Down
40 changes: 1 addition & 39 deletions kauldron/train/train_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,10 @@ def init( # pylint:disable=missing-function-docstring
@jax.named_call
def forward(
self,
context: context_lib.Context | None = None,
context: context_lib.Context,
*,
rngs: rngs_lib.Rngs,
is_training: bool,
# DEPRECATED variables: Should be passed through `context` instead.
params=None,
batch=None,
step: int | None = None,
collections: _Collections | None = None,
) -> tuple[float, context_lib.Context]:
"""Forward pass of the model including losses.
Expand All @@ -268,45 +263,12 @@ def forward(
`batch`, `step`, and `collections` (and optionally `opt_state`).
rngs: Random numbers to use for the forward pass.
is_training: Whether to run the model in training or eval mode.
params: DEPRECATED: Should be passed through `context` instead.
batch: DEPRECATED: Should be passed through `context` instead.
step: DEPRECATED: Should be passed through `context` instead.
collections: DEPRECATED: Should be passed through `context` instead.
Returns:
loss_total: Total loss.
context: Context with the updated `loss_total`, `loss_states`,
`interms`, and `collections`.
"""
# New API: pass everything through `context`
if isinstance(context, context_lib.Context):
if any(v is not None for v in (params, batch, step, collections)):
raise ValueError(
"When calling `model_with_aux.forward(context)`, you should not"
" pass `params`, `batch`,... through kwargs, but rather through the"
" context."
)
# Should check that params, batch,... are correctly set in the context ?
else: # Legacy API (deprecated)
status.log(
"Warning: Calling `model_with_aux.forward(params)` is deprecated and"
" will be removed soon. Instead, all inputs should be passed"
" through context directly: `model_with_aux.forward(context)`."
)
# Params can be passed either as positional or keyword arguments:
if context is None: # `forward(params=params)`
assert params is not None, "Cannot pass both `params` and `context`"
else: # `forward(params)`
assert params is None, "Cannot pass both `params` and `context`"
params = context

context = context_lib.Context(
step=step,
batch=batch,
params=params,
collections=collections,
)
del params, batch, step, collections
args, kwargs = data_utils.get_model_inputs(self.model, context)
preds, collections = self.model.apply(
{"params": context.params} | context.collections,
Expand Down

0 comments on commit a4b0c35

Please sign in to comment.