From 5c75087aeee7081025370e10d1f571a11600f1ae Mon Sep 17 00:00:00 2001
From: Pavel Iakubovskii
Date: Fri, 27 Dec 2024 16:33:44 +0000
Subject: [PATCH] Fix `model_accepts_loss_kwargs` for timm model (#35257)

* Fix for timm model

* Add comment
---
 .../models/timm_wrapper/modeling_timm_wrapper.py |  3 +++
 src/transformers/trainer.py                      | 10 +++++++++-
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
index dfb14dfccec4c6..47e8944583b4ca 100644
--- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
+++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
@@ -82,6 +82,9 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
     config_class = TimmWrapperConfig
     _no_split_modules = []
 
+    # used in Trainer to avoid passing `loss_kwargs` to model forward
+    accepts_loss_kwargs = False
+
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["vision", "timm"])
         super().__init__(*args, **kwargs)
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index c2327739549e5e..655d5b260c1f36 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -622,7 +622,15 @@ def __init__(
             else unwrapped_model.get_base_model().forward
         )
         forward_params = inspect.signature(model_forward).parameters
-        self.model_accepts_loss_kwargs = any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
+
+        # Check if the model has explicit setup for loss kwargs,
+        # if not, check if `**kwargs` are in model.forward
+        if hasattr(model, "accepts_loss_kwargs"):
+            self.model_accepts_loss_kwargs = model.accepts_loss_kwargs
+        else:
+            self.model_accepts_loss_kwargs = any(
+                k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values()
+            )
 
         self.neftune_noise_alpha = args.neftune_noise_alpha
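
Note: for context, below is a minimal standalone sketch of the detection logic this
patch introduces. The model classes and the helper function are hypothetical stand-ins
for illustration only; in the patch itself, the explicit flag lives on
TimmWrapperPreTrainedModel and the fallback inspection runs in Trainer.__init__ against
the unwrapped model's forward.

    import inspect


    class TimmLikeModel:
        # Explicit opt-out: forward takes **kwargs (forwarded to the timm model),
        # but they are not loss kwargs, so Trainer must not pass `loss_kwargs` through.
        accepts_loss_kwargs = False

        def forward(self, pixel_values, **kwargs):
            ...


    class KwargsModel:
        # No explicit flag; **kwargs in forward is taken to mean loss kwargs are accepted.
        def forward(self, input_ids, **kwargs):
            ...


    def model_accepts_loss_kwargs(model) -> bool:
        # Mirror of the patched Trainer logic: prefer the explicit class attribute,
        # otherwise fall back to inspecting forward for a **kwargs parameter.
        if hasattr(model, "accepts_loss_kwargs"):
            return model.accepts_loss_kwargs
        forward_params = inspect.signature(model.forward).parameters
        return any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())


    print(model_accepts_loss_kwargs(TimmLikeModel()))  # False: explicit flag wins
    print(model_accepts_loss_kwargs(KwargsModel()))    # True: **kwargs detected

The explicit attribute is needed because signature inspection alone cannot tell
pass-through **kwargs (as in the timm wrapper) apart from genuine loss kwargs.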