From a343ad6d19d8122768837fc640bf80f8aeeb2d44 Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Fri, 22 Nov 2024 05:02:42 -0800 Subject: [PATCH] Remove deprecated arguments PiperOrigin-RevId: 699133548 --- kauldron/train/train_step.py | 40 +----------------------------------- 1 file changed, 1 insertion(+), 39 deletions(-) diff --git a/kauldron/train/train_step.py b/kauldron/train/train_step.py index d42debf4..21b61dd2 100644 --- a/kauldron/train/train_step.py +++ b/kauldron/train/train_step.py @@ -251,15 +251,10 @@ def init( # pylint:disable=missing-function-docstring @jax.named_call def forward( self, - context: context_lib.Context | None = None, + context: context_lib.Context, *, rngs: rngs_lib.Rngs, is_training: bool, - # DEPRECATED variables: Should be passed through `context` instead. - params=None, - batch=None, - step: int | None = None, - collections: _Collections | None = None, ) -> tuple[float, context_lib.Context]: """Forward pass of the model including losses. @@ -268,45 +263,12 @@ def forward( `batch`, `step`, and `collections` (and optionally `opt_state`). rngs: Random numbers to use for the forward pass. is_training: Whether to run the model in training or eval mode. - params: DEPRECATED: Should be passed through `context` instead. - batch: DEPRECATED: Should be passed through `context` instead. - step: DEPRECATED: Should be passed through `context` instead. - collections: DEPRECATED: Should be passed through `context` instead. Returns: loss_total: Total loss. context: Context with the updated `loss_total`, `loss_states`, `interms`, and `collections`. """ - # New API: pass everything through `context` - if isinstance(context, context_lib.Context): - if any(v is not None for v in (params, batch, step, collections)): - raise ValueError( - "When calling `model_with_aux.forward(context)`, you should not" - " pass `params`, `batch`,... through kwargs, but rather through the" - " context." - ) - # Should check that params, batch,... are correctly set in the context ? - else: # Legacy API (deprecated) - status.log( - "Warning: Calling `model_with_aux.forward(params)` is deprecated and" - " will be removed soon. Instead, all inputs should be passed" - " through context directly: `model_with_aux.forward(context)`." - ) - # Params can be passed either as positional or keyword arguments: - if context is None: # `forward(params=params)` - assert params is not None, "Cannot pass both `params` and `context`" - else: # `forward(params)` - assert params is None, "Cannot pass both `params` and `context`" - params = context - - context = context_lib.Context( - step=step, - batch=batch, - params=params, - collections=collections, - ) - del params, batch, step, collections args, kwargs = data_utils.get_model_inputs(self.model, context) preds, collections = self.model.apply( {"params": context.params} | context.collections,