From 0c65af12a115f7d13916b10d09030b170981134e Mon Sep 17 00:00:00 2001
From: Stephen Horvath
Date: Thu, 15 Jun 2023 16:23:43 +1000
Subject: [PATCH] Fix `--load-8bit` for Intel ARC GPUs (#1697)

---
 fastchat/model/model_adapter.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index aa675dafb..facfbeea8 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -146,6 +146,13 @@ def load_model(
         replace_llama_attn_with_non_inplace_operations()
     elif device == "xpu":
         kwargs = {"torch_dtype": torch.bfloat16}
+        # Try to load ipex; although it looks unused, importing it links XPU support into torch.
+        try:
+            import intel_extension_for_pytorch as ipex
+        except ImportError:
+            warnings.warn(
+                "Intel Extension for PyTorch is not installed, but is required for xpu inference."
+            )
     else:
         raise ValueError(f"Invalid device: {device}")
 
@@ -185,12 +192,6 @@ def load_model(
         model.to(device)
     elif device == "xpu":
-        try:
-            import intel_extension_for_pytorch as ipex
-        except ImportError:
-            warnings.warn(
-                "Intel Extension for PyTorch is not installed, but is required for xpu inference."
-            )
         model.eval()
         model = model.to("xpu")
         model = torch.xpu.optimize(model, dtype=torch.bfloat16, inplace=True)
 
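The functional change here is purely about import order. Importing
intel_extension_for_pytorch has the side effect of registering Intel's XPU
backend with PyTorch (it is what makes the "xpu" device and torch.xpu
usable), so the import has to run inside load_model before the model is
constructed; previously it only ran at the point where the finished model
was moved to the device, after the 8-bit load path had already executed.

Below is a minimal sketch of the load path this patch produces, assuming
ipex and transformers are installed; the AutoModelForCausalLM call and the
model name "facebook/opt-125m" are illustrative stand-ins, not code from
this patch:

    import warnings

    import torch
    from transformers import AutoModelForCausalLM

    try:
        # The import itself is what matters: it links XPU support into
        # torch, even though the name ipex is never referenced again.
        import intel_extension_for_pytorch as ipex  # noqa: F401
    except ImportError:
        warnings.warn(
            "Intel Extension for PyTorch is not installed, but is required "
            "for xpu inference."
        )

    # Build the model only after ipex has registered the XPU backend.
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m", torch_dtype=torch.bfloat16
    )
    model.eval()
    model = model.to("xpu")
    model = torch.xpu.optimize(model, dtype=torch.bfloat16, inplace=True)

To exercise the fixed path end to end on an Intel ARC machine, an
invocation along these lines should work (flags taken from FastChat's CLI;
the model path is an assumed example):

    python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.3 \
        --device xpu --load-8bit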