diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py index 651e6d12..ea8aa0a1 100644 --- a/optimum_benchmark/backends/pytorch/backend.py +++ b/optimum_benchmark/backends/pytorch/backend.py @@ -436,9 +436,9 @@ def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict @torch.inference_mode() def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict: - assert ( - kwargs.get("max_new_tokens") == kwargs.get("min_new_tokens") == 1 - ), "For prefilling, max_new_tokens and min_new_tokens must be equal to 1" + assert kwargs.get("max_new_tokens") == kwargs.get("min_new_tokens") == 1, ( + "For prefilling, max_new_tokens and min_new_tokens must be equal to 1" + ) return self.pretrained_model.generate(**inputs, **kwargs) @torch.inference_mode()