From 6c89a820ee96a505703e33cf6933697335b2ead4 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Wed, 13 Mar 2024 13:01:37 +0100 Subject: [PATCH] Revert "Fix gptq exllamav2 check (#152)" This reverts commit 5bf349dbbc5ecdbf6ca94ac70f80ac44bd84dcc0. --- optimum_benchmark/backends/pytorch/backend.py | 12 +++++++----- optimum_benchmark/trackers/energy.py | 4 ++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py index 981e8baa..87d53290 100644 --- a/optimum_benchmark/backends/pytorch/backend.py +++ b/optimum_benchmark/backends/pytorch/backend.py @@ -278,13 +278,15 @@ def is_awq_quantized(self) -> bool: def is_exllamav2(self) -> bool: return (self.is_gptq_quantized or self.is_awq_quantized) and ( ( - getattr(self.pretrained_config, "quantization_config", None) is not None - and getattr(self.pretrained_config.quantization_config, "exllama_config", None) is not None - and self.pretrained_config.quantization_config.exllama_config.get("exllama_version", None) == 2 + hasattr(self.pretrained_config, "quantization_config") + and hasattr(self.pretrained_config.quantization_config, "exllama_config") + and "exllama_version" in self.pretrained_config.quantization_config.exllama_config + and self.pretrained_config.quantization_config.exllama_config["exllama_version"] == 2 ) or ( - self.config.quantization_config.get("exllama_config", None) is not None - and self.config.quantization_config.exllama_config.get("exllama_version", None) == 2 + hasattr(self.quantization_config, "exllama_config") + and "exllama_version" in self.quantization_config.exllama_config + and self.quantization_config.exllama_config["exllama_version"] == 2 ) ) diff --git a/optimum_benchmark/trackers/energy.py b/optimum_benchmark/trackers/energy.py index 750aa188..3161946a 100644 --- a/optimum_benchmark/trackers/energy.py +++ b/optimum_benchmark/trackers/energy.py @@ -56,13 +56,13 @@ def __sub__(self, other: "Energy") -> "Energy": """Enables subtraction of two Energy instances using the '-' operator.""" if self.unit != other.unit: raise ValueError("Energy units must match to perform subtraction") - + return Energy( cpu=self.cpu - other.cpu, gpu=self.gpu - other.gpu, ram=self.ram - other.ram, total=self.total - other.total, - unit=self.unit, + unit=self.unit )