From 69101d708a385979968885493c0288af98c3c1b1 Mon Sep 17 00:00:00 2001 From: PedroPareja Date: Tue, 10 Sep 2024 22:32:00 +0200 Subject: [PATCH 1/4] Fix to quantize merge models Fixes inf and nan values in the tensors that happen in the models produced by the merge of others. In those cases the inf values are replaced by MAX_VALUE. --- exllamav2/attn.py | 19 ++++++++++--------- exllamav2/embedding.py | 6 ++++-- exllamav2/headnorm.py | 10 ++++++---- exllamav2/layernorm.py | 10 ++++++---- exllamav2/linear.py | 21 ++++++++++----------- exllamav2/mlp.py | 11 ++++++----- exllamav2/model.py | 9 +++++---- exllamav2/module.py | 3 ++- exllamav2/moe_mlp.py | 7 ++++--- exllamav2/parallel_decoder.py | 3 ++- exllamav2/pos_embedding.py | 5 +++-- exllamav2/rmsnorm.py | 13 +++++++------ exllamav2/util.py | 11 ++++++++++- 13 files changed, 75 insertions(+), 53 deletions(-) diff --git a/exllamav2/attn.py b/exllamav2/attn.py index 4370dceb..7e2d2909 100644 --- a/exllamav2/attn.py +++ b/exllamav2/attn.py @@ -16,6 +16,7 @@ import torch.nn.functional as F import inspect import os +from exllamav2.util import substitute_inf_with_max from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -964,17 +965,17 @@ def forward(self, use_flash_attn = has_flash_attn and not cfg.no_flash_attn if isinstance(attn_params, ExLlamaV2Attention.PagedParams): - return self.forward_paged( + return substitute_inf_with_max(self.forward_paged( hidden_states, cache, attn_params, loras = loras, **kwargs - ) + )) if self.is_tp: if cache is not None and use_flash_attn: - return self.forward_tp( + return substitute_inf_with_max(self.forward_tp( hidden_states, cache, attn_params, @@ -982,11 +983,11 @@ def forward(self, intermediates, loras, **kwargs, - ) + )) else: # TODO: Can't use the optimized forward function because it writes directly to a fixed output # tensor, and flash-attn currently has a bug that prevents that from working when q_len == 1 - return self.forward_tp_old( + return substitute_inf_with_max(self.forward_tp_old( hidden_states, cache, attn_params, @@ -994,7 +995,7 @@ def forward(self, intermediates, loras, **kwargs, - ) + )) if self.q_handle is None or intermediates: return self.forward_torch( @@ -1113,7 +1114,7 @@ def forward(self, if cfg.arch.clamp_hidden_states: hidden_states.clamp_(-65504, 65504) - return hidden_states + return substitute_inf_with_max(hidden_states) def forward_tp( self, @@ -1428,9 +1429,9 @@ def forward_torch( if intermediates: return {"post_norm": post_norm, "attn_output": attn_output, - "hidden_states": hidden_states} + "hidden_states": substitute_inf_with_max(hidden_states)} else: - return hidden_states + return substitute_inf_with_max(hidden_states) def update_loras(self): diff --git a/exllamav2/embedding.py b/exllamav2/embedding.py index 48168b2d..b67ee649 100644 --- a/exllamav2/embedding.py +++ b/exllamav2/embedding.py @@ -9,6 +9,8 @@ if TYPE_CHECKING: from exllamav2.model import ExLlamaV2 +from exllamav2.util import substitute_inf_with_max + EMBEDDING_INDEX: int = 1000000 class ExLlamaV2Embedding(ExLlamaV2Module): @@ -185,6 +187,6 @@ def forward( hidden_states = ctx.copy_pinned(0, hidden_states) if intermediates: - return {"hidden_states": hidden_states} + return {"hidden_states": substitute_inf_with_max(hidden_states)} else: - return hidden_states + return substitute_inf_with_max(hidden_states) diff --git a/exllamav2/headnorm.py b/exllamav2/headnorm.py index b890ba11..dbf6c9ec 100644 --- a/exllamav2/headnorm.py +++ b/exllamav2/headnorm.py @@ -8,6 +8,8 @@ if TYPE_CHECKING: from exllamav2.model import ExLlamaV2 +from exllamav2.util import substitute_inf_with_max + class ExLlamaV2HeadNorm(ExLlamaV2Module): name: str = "LayerNorm" @@ -122,9 +124,9 @@ def forward( self.variance_epsilon) if intermediates: - return {"hidden_states": hidden_states} + return {"hidden_states": substitute_inf_with_max(hidden_states)} else: - return hidden_states + return substitute_inf_with_max(hidden_states) def forward_torch( self, @@ -146,8 +148,8 @@ def forward_torch( hidden_states = hidden_states.to(input_dtype) if intermediates: - return {"hidden_states": hidden_states} + return {"hidden_states": substitute_inf_with_max(hidden_states)} else: - return hidden_states + return substitute_inf_with_max(hidden_states) diff --git a/exllamav2/layernorm.py b/exllamav2/layernorm.py index 7b8f6c5b..0a2092cd 100644 --- a/exllamav2/layernorm.py +++ b/exllamav2/layernorm.py @@ -8,6 +8,8 @@ if TYPE_CHECKING: from exllamav2.model import ExLlamaV2 +from exllamav2.util import substitute_inf_with_max + class ExLlamaV2LayerNorm(ExLlamaV2Module): name: str = "LayerNorm" @@ -119,9 +121,9 @@ def forward( hidden_states = norm.view(output_shape) if intermediates: - return {"hidden_states": hidden_states} + return {"hidden_states": substitute_inf_with_max(hidden_states)} else: - return hidden_states + return substitute_inf_with_max(hidden_states) def forward_torch( @@ -139,8 +141,8 @@ def forward_torch( hidden_states = self.layernorm(hidden_states) if intermediates: - return {"hidden_states": hidden_states} + return {"hidden_states": substitute_inf_with_max(hidden_states)} else: - return hidden_states + return substitute_inf_with_max(hidden_states) diff --git a/exllamav2/linear.py b/exllamav2/linear.py index 5d6855dd..61504814 100644 --- a/exllamav2/linear.py +++ b/exllamav2/linear.py @@ -7,7 +7,7 @@ from exllamav2.module import ExLlamaV2Module from exllamav2.compat import safe_move_tensor from exllamav2.tensor_p import BROADCAST_VC -from exllamav2.util import unpack_4bit, pack_4bit +from exllamav2.util import unpack_4bit, pack_4bit, substitute_inf_with_max import gc from typing import TYPE_CHECKING @@ -295,8 +295,7 @@ def temp_fwd_size(self) -> int: max_len = self.model.config.max_input_len if self.max_out_len is None else \ min(self.max_out_len, self.model.config.max_input_len) return self.out_features * max_len * self.model.config.max_batch_size * 4 + 128 - - + def forward( self, hidden_states: torch.Tensor, @@ -312,7 +311,7 @@ def forward( if self.is_tp: if self.out_features_tp: - return self.forward_tp( + return substitute_inf_with_max(self.forward_tp( hidden_states, cache, attn_params, @@ -322,9 +321,9 @@ def forward( force_recons, force_cuda, **kwargs - ) + )) elif self.in_features_tp: - return self.forward_tp_row( + return substitute_inf_with_max(self.forward_tp_row( hidden_states, cache, attn_params, @@ -334,7 +333,7 @@ def forward( force_recons, force_cuda, **kwargs - ) + )) else: assert False, "Unitialized TP linear layer" @@ -344,9 +343,9 @@ def forward( hidden_states_out = loras[0].lm_head(hidden_states) if intermediates: - return {"hidden_states": hidden_states_out} + return {"hidden_states": substitute_inf_with_max(hidden_states_out)} else: - return hidden_states_out + return substitute_inf_with_max(hidden_states_out) if self.q_handle is not None and not force_recons: @@ -380,9 +379,9 @@ def forward( hidden_states_out += torch.matmul(temp, lora_b) if intermediates: - return {"hidden_states": hidden_states_out} + return {"hidden_states": substitute_inf_with_max(hidden_states_out)} else: - return hidden_states_out + return substitute_inf_with_max(hidden_states_out) def forward_tp( diff --git a/exllamav2/mlp.py b/exllamav2/mlp.py index 2d8282d5..27525932 100644 --- a/exllamav2/mlp.py +++ b/exllamav2/mlp.py @@ -9,6 +9,7 @@ from exllamav2.ext import exllamav2_ext as ext_c, none_tensor from exllamav2.lora import ExLlamaV2Lora from exllamav2.tensor_p import BROADCAST_ID, BROADCAST_RS +from exllamav2.util import substitute_inf_with_max # from line_profiler import profile from typing import TYPE_CHECKING @@ -288,7 +289,7 @@ def forward( ) -> torch.Tensor | dict[str: torch.Tensor]: if self.is_tp: - return self.forward_tp( + return substitute_inf_with_max(self.forward_tp( hidden_states, cache, attn_params, @@ -296,7 +297,7 @@ def forward( intermediates, loras, **kwargs - ) + )) cfg = self.model.config @@ -319,7 +320,7 @@ def forward( if cfg.arch.clamp_hidden_states: hidden_states.clamp_(-65504, 65504) - return hidden_states + return substitute_inf_with_max(hidden_states) # @profile @@ -457,9 +458,9 @@ def forward_torch( if intermediates: return {"post_norm": post_norm, "pre_down": y, - "hidden_states": hidden_states} + "hidden_states": substitute_inf_with_max(hidden_states)} else: - return hidden_states + return substitute_inf_with_max(hidden_states) def update_loras(self): diff --git a/exllamav2/model.py b/exllamav2/model.py index 3fb4f5be..eaf57d3f 100644 --- a/exllamav2/model.py +++ b/exllamav2/model.py @@ -2,6 +2,7 @@ import os, sys from exllamav2.architecture import RopeStyle +from exllamav2.util import substitute_inf_with_max min_version = (3, 8) if sys.version_info < min_version: @@ -820,9 +821,9 @@ def forward( if abort_event and abort_event.is_set(): return if "last_state" in result: - return result.get("logits"), result["last_state"] + return substitute_inf_with_max(result.get("logits")), substitute_inf_with_max(result["last_state"]) else: - return result.get("logits") + return substitute_inf_with_max(result.get("logits")) # Confirm that the input fits within the allocated cache space @@ -893,9 +894,9 @@ def forward( last_state = r.get("last_state") if last_state is None: - return result + return substitute_inf_with_max(result) else: - return result, last_state + return substitute_inf_with_max(result), substitute_inf_with_max(last_state) @torch.inference_mode() diff --git a/exllamav2/module.py b/exllamav2/module.py index 5bd672a6..c671bbf0 100644 --- a/exllamav2/module.py +++ b/exllamav2/module.py @@ -4,6 +4,7 @@ from exllamav2.config import ExLlamaV2Config from exllamav2.fasttensors import STFile from exllamav2.compat import safe_move_tensor +from exllamav2.util import substitute_inf_with_max from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -282,4 +283,4 @@ def forward(self, hidden_states, *args, **kwargs): hidden_states = self.post_forward(hidden_states, *args, **kwargs) hidden_states = safe_move_tensor(hidden_states, dev) - return hidden_states + return substitute_inf_with_max(hidden_states) diff --git a/exllamav2/moe_mlp.py b/exllamav2/moe_mlp.py index 403c5ca1..d67ef371 100644 --- a/exllamav2/moe_mlp.py +++ b/exllamav2/moe_mlp.py @@ -7,6 +7,7 @@ from exllamav2.linear import ExLlamaV2Linear from exllamav2.lora import ExLlamaV2Lora from exllamav2.ext import exllamav2_ext as ext_c, none_tensor +from exllamav2.util import substitute_inf_with_max from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -244,7 +245,7 @@ def forward( # ext_c.q_moe_mlp_forward_(self.q_handle, hidden_states.view(-1, hidden_states.shape[-1]), pass_loras, pass_lora_temp) ext_c.q_moe_mlp_forward_(self.q_handle, hidden_states.view(-1, hidden_states.shape[-1])) - return hidden_states + return substitute_inf_with_max(hidden_states) def forward_torch( @@ -313,9 +314,9 @@ def forward_torch( if intermediates: result["hidden_states"] = final_hidden_states - return result + return substitute_inf_with_max(result) else: - return final_hidden_states + return substitute_inf_with_max(final_hidden_states) def update_loras(self): diff --git a/exllamav2/parallel_decoder.py b/exllamav2/parallel_decoder.py index be772eca..902dd3f7 100644 --- a/exllamav2/parallel_decoder.py +++ b/exllamav2/parallel_decoder.py @@ -9,6 +9,7 @@ from exllamav2.lora import ExLlamaV2Lora from exllamav2.layernorm import ExLlamaV2LayerNorm from exllamav2.ext import exllamav2_ext as ext_c, none_tensor +from exllamav2.util import substitute_inf_with_max from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -119,7 +120,7 @@ def forward( b = self.mlp.forward(b, cache, attn_params, past_len, intermediates, loras, **kwargs) hidden_states += a hidden_states += b - return hidden_states + return substitute_inf_with_max(hidden_states) def forward_interm( diff --git a/exllamav2/pos_embedding.py b/exllamav2/pos_embedding.py index d7b85625..ac2e9eeb 100644 --- a/exllamav2/pos_embedding.py +++ b/exllamav2/pos_embedding.py @@ -4,6 +4,7 @@ from exllamav2.module import ExLlamaV2Module from exllamav2.attn import ExLlamaV2Attention from exllamav2.compat import safe_move_tensor +from exllamav2.util import substitute_inf_with_max from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -118,6 +119,6 @@ def forward( hidden_states[b, target_a:target_b] += emb_slice if intermediates: - return {"hidden_states": hidden_states} + return {"hidden_states": substitute_inf_with_max(hidden_states)} else: - return hidden_states + return substitute_inf_with_max(hidden_states) diff --git a/exllamav2/rmsnorm.py b/exllamav2/rmsnorm.py index be6f00aa..8f640857 100644 --- a/exllamav2/rmsnorm.py +++ b/exllamav2/rmsnorm.py @@ -4,6 +4,7 @@ from exllamav2.module import ExLlamaV2Module from exllamav2.ext import exllamav2_ext as ext_c from exllamav2.compat import safe_move_tensor +from exllamav2.util import substitute_inf_with_max from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -114,7 +115,7 @@ def forward( ) -> torch.Tensor | dict[str: torch.Tensor]: if self.is_tp: - return self.forward_tp( + return substitute_inf_with_max(self.forward_tp( hidden_states, cache, attn_params, @@ -123,7 +124,7 @@ def forward( loras, output_fp32, **kwargs - ) + )) output_shape = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) @@ -137,9 +138,9 @@ def forward( hidden_states = norm.view(output_shape) if intermediates: - return {"hidden_states": hidden_states} + return {"hidden_states": substitute_inf_with_max(hidden_states)} else: - return hidden_states + return substitute_inf_with_max(hidden_states) def forward_tp( @@ -198,9 +199,9 @@ def forward_torch( hidden_states *= self.weight if intermediates: - return {"hidden_states": hidden_states} + return {"hidden_states": substitute_inf_with_max(hidden_states)} else: - return hidden_states + return substitute_inf_with_max(hidden_states) def tp_split(self, broadcast_type: int): diff --git a/exllamav2/util.py b/exllamav2/util.py index fd44462e..f70673dd 100644 --- a/exllamav2/util.py +++ b/exllamav2/util.py @@ -367,4 +367,13 @@ def pack_4bit(unpacked: torch.Tensor): for i in range(8): packed |= (unpacked[:, i::8].to(torch.int64) << (i * 4)) packed = packed.to(torch.int32) - return packed \ No newline at end of file + return packed + + +# Example function to substitute inf with the maximum value of the type +def substitute_inf_with_max(tensor): + dtype = tensor.dtype + max_value = torch.finfo(dtype).max if dtype.is_floating_point else torch.iinfo(dtype).max + tensor = torch.where(torch.isinf(tensor), max_value, tensor) + tensor = torch.where(torch.isnan(tensor), max_value, tensor) + return tensor From 9ffd4142f6734877be47745371a4fae4ee5d6967 Mon Sep 17 00:00:00 2001 From: PedroPareja Date: Tue, 10 Sep 2024 23:09:59 +0200 Subject: [PATCH 2/4] Update README.md - Fork description --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index cc93cc2e..3cebbcd7 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,7 @@ +## About this Fork + +This fork includes the changes needed to quantize Merge models (models made by merging others instead of finetuning). + # ExLlamaV2 ExLlamaV2 is an inference library for running local LLMs on modern consumer GPUs. From 6d821af2e69d8f7069dc0d0051cd7303bb428ca7 Mon Sep 17 00:00:00 2001 From: PedroPareja Date: Tue, 10 Sep 2024 23:35:38 +0200 Subject: [PATCH 3/4] Update README.md Readme changes reverted --- README.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/README.md b/README.md index 3cebbcd7..cc93cc2e 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,3 @@ -## About this Fork - -This fork includes the changes needed to quantize Merge models (models made by merging others instead of finetuning). - # ExLlamaV2 ExLlamaV2 is an inference library for running local LLMs on modern consumer GPUs. From d9fed068a20b90680fe2396555126548b1cd3c1c Mon Sep 17 00:00:00 2001 From: PedroPareja Date: Tue, 10 Sep 2024 23:38:06 +0200 Subject: [PATCH 4/4] Update util.py --- exllamav2/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exllamav2/util.py b/exllamav2/util.py index f70673dd..2c8ad13b 100644 --- a/exllamav2/util.py +++ b/exllamav2/util.py @@ -370,7 +370,7 @@ def pack_4bit(unpacked: torch.Tensor): return packed -# Example function to substitute inf with the maximum value of the type +# Function to substitute inf and NaN with the maximum value of the type def substitute_inf_with_max(tensor): dtype = tensor.dtype max_value = torch.finfo(dtype).max if dtype.is_floating_point else torch.iinfo(dtype).max