Skip to content

Commit

Permalink
Remove deprecated arguments
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 699133548
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Nov 22, 2024
1 parent 69e940d commit a343ad6
Showing 1 changed file with 1 addition and 39 deletions.
40 changes: 1 addition & 39 deletions kauldron/train/train_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,10 @@ def init( # pylint:disable=missing-function-docstring
@jax.named_call
def forward(
self,
context: context_lib.Context | None = None,
context: context_lib.Context,
*,
rngs: rngs_lib.Rngs,
is_training: bool,
# DEPRECATED variables: Should be passed through `context` instead.
params=None,
batch=None,
step: int | None = None,
collections: _Collections | None = None,
) -> tuple[float, context_lib.Context]:
"""Forward pass of the model including losses.
Expand All @@ -268,45 +263,12 @@ def forward(
`batch`, `step`, and `collections` (and optionally `opt_state`).
rngs: Random numbers to use for the forward pass.
is_training: Whether to run the model in training or eval mode.
params: DEPRECATED: Should be passed through `context` instead.
batch: DEPRECATED: Should be passed through `context` instead.
step: DEPRECATED: Should be passed through `context` instead.
collections: DEPRECATED: Should be passed through `context` instead.
Returns:
loss_total: Total loss.
context: Context with the updated `loss_total`, `loss_states`,
`interms`, and `collections`.
"""
# New API: pass everything through `context`
if isinstance(context, context_lib.Context):
if any(v is not None for v in (params, batch, step, collections)):
raise ValueError(
"When calling `model_with_aux.forward(context)`, you should not"
" pass `params`, `batch`,... through kwargs, but rather through the"
" context."
)
# Should check that params, batch,... are correctly set in the context ?
else: # Legacy API (deprecated)
status.log(
"Warning: Calling `model_with_aux.forward(params)` is deprecated and"
" will be removed soon. Instead, all inputs should be passed"
" through context directly: `model_with_aux.forward(context)`."
)
# Params can be passed either as positional or keyword arguments:
if context is None: # `forward(params=params)`
assert params is not None, "Cannot pass both `params` and `context`"
else: # `forward(params)`
assert params is None, "Cannot pass both `params` and `context`"
params = context

context = context_lib.Context(
step=step,
batch=batch,
params=params,
collections=collections,
)
del params, batch, step, collections
args, kwargs = data_utils.get_model_inputs(self.model, context)
preds, collections = self.model.apply(
{"params": context.params} | context.collections,
Expand Down

0 comments on commit a343ad6

Please sign in to comment.