From ef455a7bb97033dca1180820de2ecde3c29580a5 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Mon, 24 Jun 2024 02:59:49 +0200 Subject: [PATCH 01/40] Respect special tokens in WS server lefttrim_token() --- exllamav2/server/websocket_actions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exllamav2/server/websocket_actions.py b/exllamav2/server/websocket_actions.py index 12a6aa07..c0f095ca 100644 --- a/exllamav2/server/websocket_actions.py +++ b/exllamav2/server/websocket_actions.py @@ -87,11 +87,11 @@ def lefttrim_token(request, ws, server, response): text = request["text"] length = int(request["trimmed_length"]) - ids = server.tokenizer.cached_encode_str(text) + ids = server.tokenizer.cached_encode_str(text, encode_special_tokens = True) if ids.shape[-1] <= length: response["trimmed_text"] = text else: - response["trimmed_text"] = server.tokenizer.decode(ids[:, -length:])[0] + response["trimmed_text"] = server.tokenizer.decode(ids[:, -length:], decode_special_token = True)[0] async def infer(request, ws, server, response): From c387587e048a19289bd45f92294f52571c5648d7 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Mon, 24 Jun 2024 03:14:54 +0200 Subject: [PATCH 02/40] Don't cache encoding in lefttrim_token --- exllamav2/server/websocket_actions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exllamav2/server/websocket_actions.py b/exllamav2/server/websocket_actions.py index c0f095ca..f84f2853 100644 --- a/exllamav2/server/websocket_actions.py +++ b/exllamav2/server/websocket_actions.py @@ -87,11 +87,11 @@ def lefttrim_token(request, ws, server, response): text = request["text"] length = int(request["trimmed_length"]) - ids = server.tokenizer.cached_encode_str(text, encode_special_tokens = True) + ids = server.tokenizer.encode(text, encode_special_tokens = True) if ids.shape[-1] <= length: response["trimmed_text"] = text else: - response["trimmed_text"] = server.tokenizer.decode(ids[:, -length:], decode_special_token = True)[0] + response["trimmed_text"] = server.tokenizer.decode(ids[:, -length:], decode_special_tokens = True)[0] async def infer(request, ws, server, response): From 198bbdb117ce7fe4dc02bfcae31e9d3293a7ae5f Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Fri, 28 Jun 2024 00:22:59 +0200 Subject: [PATCH 03/40] Fix type hint --- exllamav2/model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/exllamav2/model.py b/exllamav2/model.py index 7296b6e2..4aec7e4e 100644 --- a/exllamav2/model.py +++ b/exllamav2/model.py @@ -417,7 +417,7 @@ def load_gen( def load_autosplit( self, cache: ExLlamaV2CacheBase, - reserve_vram: int | None = None, + reserve_vram: int | list[int] | None = None, last_id_only: bool = False, callback: Callable[[int, int], None] | None = None, callback_gen: Callable[[int, int], None] | None = None, @@ -443,7 +443,7 @@ def callback_pb(a, b): def load_autosplit_gen( self, cache: ExLlamaV2CacheBase, - reserve_vram: int | None = None, + reserve_vram: int | list[int] | None = None, last_id_only: bool = False, callback: Callable[[int, int], None] | None = None, callback_gen: Callable[[int, int], None] | None = None @@ -466,6 +466,8 @@ def load_autosplit_gen( if reserve_vram is None: reserve_vram = [192 * 1024**2] + [64 * 1024**2] * (num_devices - 1) + elif isinstance(reserve_vram, int): + reserve_vram = [reserve_vram] * num_devices reserved_vram_tensors = [] 
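        # Illustrative usage sketch (not part of the diff hunk above), assuming the
        # load_autosplit() signature shown in this patch: reserve_vram can now be given
        # either as a single int, broadcast to every device, or as a per-device list, in bytes:
        #
        #     model.load_autosplit(cache, reserve_vram = 256 * 1024**2)                  # same reserve on all GPUs
        #     model.load_autosplit(cache, reserve_vram = [512 * 1024**2, 64 * 1024**2])  # one entry per GPU
        #
        # Left as None, the default stays at 192 MiB for the first device and 64 MiB for the rest.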
minimum_reserve_tensor = None From e737e23e3060a80f81d91351ad59249f86319544 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Mon, 1 Jul 2024 02:15:34 +0200 Subject: [PATCH 04/40] Util function to sync only active devices --- exllamav2/generator/base.py | 3 ++- exllamav2/generator/dynamic.py | 5 +++-- exllamav2/util.py | 11 +++++++++++ 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/exllamav2/generator/base.py b/exllamav2/generator/base.py index c70cea29..236735e0 100644 --- a/exllamav2/generator/base.py +++ b/exllamav2/generator/base.py @@ -15,6 +15,7 @@ import threading from exllamav2.generator.hooks import ExLlamaV2PostSamplingHook, ExLlamaV2PostSamplingResult from exllamav2.embedding import EMBEDDING_INDEX +from exllamav2.util import cuda_sync_active class ExLlamaV2BaseGenerator: @@ -46,7 +47,7 @@ def warmup(self): input_ids = torch.zeros((1, 2), dtype = torch.long) self.model.forward(input_ids, cache = None, input_mask = None, preprocess_only = True) - torch.cuda.synchronize() + cuda_sync_active() def full(self): diff --git a/exllamav2/generator/dynamic.py b/exllamav2/generator/dynamic.py index ad01c140..b86c65bb 100644 --- a/exllamav2/generator/dynamic.py +++ b/exllamav2/generator/dynamic.py @@ -6,6 +6,7 @@ from exllamav2.cache import ExLlamaV2CacheBase, ExLlamaV2Cache_8bit from exllamav2.attn import ExLlamaV2Attention, assert_paged_attn from exllamav2.ext import exllamav2_ext as ext_c, none_tensor +from exllamav2.util import cuda_sync_active from concurrent.futures import ThreadPoolExecutor from exllamav2.compat import pairwise @@ -1013,7 +1014,7 @@ def iterate_draftmodel_gen(self, results: list): for job in self.active_jobs: if not job.is_prefill_done(): continue if job.time_first_token is None: - torch.cuda.synchronize() + cuda_sync_active() job.time_first_token = time.time() job_ids = job.get_input_ids_list() input_ids_list += job_ids @@ -1091,7 +1092,7 @@ def iterate_gen(self, results: list, draft_tokens: torch.Tensor | None = None): logit_mapping.append(len(input_ids_list)) if not job.is_prefill_done(): continue if job.time_first_token is None: - torch.cuda.synchronize() + cuda_sync_active() job.time_first_token = time.time() if draft_tokens is None: job_ids = job.get_input_ids_list(add_to_cache = True) diff --git a/exllamav2/util.py b/exllamav2/util.py index d0387229..170241de 100644 --- a/exllamav2/util.py +++ b/exllamav2/util.py @@ -110,6 +110,17 @@ def torch_slice(self, a: int | None, b: int | None): return s +def cuda_sync_active(): + """ + Calling torch.cuda.synchronize() will create a CUDA context on CUDA:0 even if that device is not being used. + This function synchronizes only devices actively used by Torch in the current process. 
+ """ + for device_id in range(torch.cuda.device_count()): + device = torch.device(f'cuda:{device_id}') + if torch.cuda.memory_allocated(device) > 0: + torch.cuda.synchronize(device) + + def get_basic_progress(): progress = Progress( TextColumn("[progress.description]{task.description}"), From 8c2132453c7ac5ce9dabd379be52d4c52077ad6a Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Wed, 3 Jul 2024 22:04:22 +0200 Subject: [PATCH 05/40] More debug output --- examples/chat.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/chat.py b/examples/chat.py index 16b04057..b68a8c46 100644 --- a/examples/chat.py +++ b/examples/chat.py @@ -61,7 +61,7 @@ parser.add_argument("-ngram", "--ngram_decoding", action = "store_true", help = "Use n-gram speculative decoding") -parser.add_argument("-pt", "--print_timings", action = "store_true", help = "Output timings after each prompt") +parser.add_argument("-pt", "--print_timings", action = "store_true", help = "Output timings/stats after each prompt") parser.add_argument("-amnesia", "--amnesia", action = "store_true", help = "Forget context after every response") # Arrrgs @@ -393,8 +393,9 @@ def get_tokenized_context(max_len): else: sd_stats = "" + ctx_tokens = active_context.shape[-1] print() - print(col_sysprompt + f"(Response: {response_tokens} tokens, {speed:.2f} tokens/second{sd_stats})" + col_default) + print(col_sysprompt + f"(Context: {ctx_tokens} tokens, response: {response_tokens} tokens, {speed:.2f} tokens/second{sd_stats})" + col_default) # Optionally forget context after each response From 95e093a2b22725eb76f8a2c5a42e97b7dcfc1a5d Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Wed, 3 Jul 2024 22:34:34 +0200 Subject: [PATCH 06/40] Chatbot: Ignore undefined special tokens --- examples/chat.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/chat.py b/examples/chat.py index b68a8c46..70963a9f 100644 --- a/examples/chat.py +++ b/examples/chat.py @@ -235,7 +235,9 @@ def get_tokenized_context(max_len): # Stop conditions -generator.set_stop_conditions(prompt_format.stop_conditions(tokenizer)) +sc = prompt_format.stop_conditions(tokenizer) +sc = [x for x in sc if x] +generator.set_stop_conditions(sc) # ANSI color codes From e56cfe2219e6e3aded1e115e6edfb6fe617701d0 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Wed, 3 Jul 2024 22:34:46 +0200 Subject: [PATCH 07/40] Chatbot: fix chatml template --- examples/chat_prompts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/chat_prompts.py b/examples/chat_prompts.py index 9ea0042c..00474eed 100644 --- a/examples/chat_prompts.py +++ b/examples/chat_prompts.py @@ -229,6 +229,7 @@ def subs_prompt(self): def stop_conditions(self, tokenizer): return \ [tokenizer.eos_token_id, + tokenizer.single_id("<|im_end|>"), """<|im_end|>"""] def encoding_options(self): From 38f44096ba799b55cce78fa76b571280a0c5dacc Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Wed, 3 Jul 2024 22:37:18 +0200 Subject: [PATCH 08/40] Support InternLM2 --- exllamav2/architecture.py | 50 +++++++++++++++++++++++++++++++++++++++ exllamav2/attn.py | 6 ++--- exllamav2/config.py | 3 +-- exllamav2/linear.py | 6 +++-- exllamav2/module.py | 24 +++++++++++++++---- 5 files changed, 78 insertions(+), 11 deletions(-) diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py index df8b7de0..350849f3 100644 
--- a/exllamav2/architecture.py +++ b/exllamav2/architecture.py @@ -9,6 +9,8 @@ ["ln_2"]] layer_keys_yi_norms = [["ln1", "input_layernorm"], ["ln2", "post_attention_layernorm"]] +layer_keys_internlm2_norms = [["attention_norm"], + ["ffn_norm"]] layer_keys_llama_attn = [["self_attn.q_proj"], ["self_attn.k_proj"], ["self_attn.v_proj"], @@ -17,6 +19,10 @@ ["self_attn.c_attn", "self_attn.k_proj"], ["self_attn.c_attn", "self_attn.v_proj"], ["self_attn.o_proj"]] +layer_keys_internlm2_attn = [["self_attn.wqkv", "self_attn.q_proj"], + ["self_attn.wqkv", "self_attn.k_proj"], + ["self_attn.wqkv", "self_attn.v_proj"], + ["self_attn.o_proj"]] layer_keys_dbrx_attn = [["self_attn.Wqkv", "self_attn.q_proj"], ["self_attn.Wqkv", "self_attn.k_proj"], ["self_attn.Wqkv", "self_attn.v_proj"], @@ -28,6 +34,9 @@ layer_keys_llama_mlp = [["mlp.down_proj"], ["mlp.gate_proj"], ["mlp.up_proj"]] +layer_keys_internlm2_mlp = [["feed_forward.w1"], + ["feed_forward.w2"], + ["feed_forward.w3"]] layer_keys_phi3_mlp = [["mlp.down_proj"], ["mlp.gate_up_proj", "mlp.gate_proj"], ["mlp.gate_up_proj", "mlp.up_proj"]] @@ -76,6 +85,10 @@ ("$h.", "model.layers."), ("$wte.", "model.embed_tokens."), ("$wpe.", "model.wpe.")] +internlm2_keymap = [("$output.", "lm_head."), + ("$model.tok_embeddings.", "model.embed_tokens."), + (".attention.", ".self_attn."), + (".wo.", ".o_proj.")] class RopeStyle(Enum): NONE = 0 @@ -100,6 +113,8 @@ def __init__(self, arch_string, read_config): self.orig_weights_transposed = False self.logit_scale_basedim = False + self.fused_qkv_altpack = False + # Mistral if arch_string == "MistralForCausalLM": @@ -586,6 +601,41 @@ def __init__(self, arch_string, read_config): self.scale_attn_weights = False self.logit_scale_basedim = True + # InternLM2 + + if arch_string == "InternLM2ForCausalLM": + arch_recognized = True + self.layer_keys += \ + layer_keys_internlm2_norms + \ + layer_keys_internlm2_attn + \ + layer_keys_internlm2_mlp + self.expect_keys += \ + expect_keys_llama + self.norm_eps_key = "rms_norm_eps" + self.attention_bias_qkv = False + self.attention_bias_o = False + self.mlp_bias = False + self.mlp_gate = True + self.mlp_key_gate = ".feed_forward.w1" + self.mlp_key_up = ".feed_forward.w3" + self.mlp_key_down = ".feed_forward.w2" + self.mlp_act_func = "silu" + self.is_moe = False + self.norm = "rmsnorm" + self.lm_head_key = "lm_head" + self.normalize_embeddings = False + self.norm_key_1 = ".attention_norm" + self.norm_key_2 = ".ffn_norm" + self.norm_constant_bias = 0 + self.parallel_decoder_blocks = False + self.requires_bos = False + self.rope_style = RopeStyle.NEOX + self.keymap = internlm2_keymap + self.fused_qkv_key = "wqkv" + self.fused_qkv_altpack = True + self.mqa = False + self.scale_attn_weights = False + # Llama (default + fallback) if arch_string != "LlamaForCausalLM" and not arch_recognized: diff --git a/exllamav2/attn.py b/exllamav2/attn.py index 8157760d..609b8751 100644 --- a/exllamav2/attn.py +++ b/exllamav2/attn.py @@ -303,9 +303,9 @@ def __init__(self, f_d = f_c + cfg.num_key_value_heads * cfg.head_dim f_key = (key + ".self_attn." 
+ cfg.arch.fused_qkv_key) if cfg.arch.fused_qkv_key else None - self.q_proj = ExLlamaV2Linear(model, key + ".self_attn.q_proj", hidden_size, cfg.num_attention_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_a, f_end = f_b) - self.k_proj = ExLlamaV2Linear(model, key + ".self_attn.k_proj", hidden_size, cfg.num_key_value_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_b, f_end = f_c) - self.v_proj = ExLlamaV2Linear(model, key + ".self_attn.v_proj", hidden_size, cfg.num_key_value_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_c, f_end = f_d) + self.q_proj = ExLlamaV2Linear(model, key + ".self_attn.q_proj", hidden_size, cfg.num_attention_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_a, f_end = f_b, altpack_qkv = cfg.arch.fused_qkv_altpack) + self.k_proj = ExLlamaV2Linear(model, key + ".self_attn.k_proj", hidden_size, cfg.num_key_value_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_b, f_end = f_c, altpack_qkv = cfg.arch.fused_qkv_altpack) + self.v_proj = ExLlamaV2Linear(model, key + ".self_attn.v_proj", hidden_size, cfg.num_key_value_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_c, f_end = f_d, altpack_qkv = cfg.arch.fused_qkv_altpack) self.o_proj = ExLlamaV2Linear(model, key + ".self_attn.o_proj", cfg.num_attention_heads * cfg.head_dim, hidden_size, cfg.arch.attention_bias_o, prescale = cfg.scale_depth) if cfg.use_qk_norm: diff --git a/exllamav2/config.py b/exllamav2/config.py index 305b797c..cc283e59 100644 --- a/exllamav2/config.py +++ b/exllamav2/config.py @@ -175,8 +175,7 @@ def prepare(self, no_tensors: bool = False): self.generation_config['eos_token_id'] = [eos_token_id_as_int] else: self.generation_config['eos_token_id'] = None - - + # Model architecture assert len(read_config["architectures"]) == 1, "Multiple architectures defined in config.json" diff --git a/exllamav2/linear.py b/exllamav2/linear.py index be505eee..e0e455e6 100644 --- a/exllamav2/linear.py +++ b/exllamav2/linear.py @@ -53,7 +53,8 @@ def __init__(self, f_key: str = None, f_beg: int = None, f_end: int = None, - is_sub_module: bool = True): + is_sub_module: bool = True, + altpack_qkv: bool = False): super().__init__(model, key) self.is_sub_module = is_sub_module @@ -85,6 +86,7 @@ def __init__(self, self.f_key = f_key self.f_beg = f_beg self.f_end = f_end + self.altpack_qkv = altpack_qkv self.assumed_footprint = in_features * (out_features + self.padding) * 2 + 128 @@ -94,7 +96,7 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device_tensors: bool = True): - if self.f_key: w = self.load_weight_fused(self.f_key, self.f_beg, self.f_end, self.in_features, self.out_features) + if self.f_key: w = self.load_weight_fused(self.f_key, self.f_beg, self.f_end, self.in_features, self.out_features, self.altpack_qkv) if w is None: w = self.load_weight() # Load quantized linear layer from dictionary diff --git a/exllamav2/module.py b/exllamav2/module.py index 1c5863dd..b356a993 100644 --- a/exllamav2/module.py +++ b/exllamav2/module.py @@ -144,25 +144,41 @@ def load_weight_fused(self, f_beg: int, f_end: int, in_feat: int, - out_feat: int): + out_feat: int, + altpack_qkv: bool): res = [] for key in [f_key, f_key + ".weight", f_key + ".bias"]: - filename = self.model.config.tensor_file_map.get(key) + cfg = self.model.config + filename = cfg.tensor_file_map.get(key) if not filename: continue - stfile = STFile.open(filename, fast = 
self.model.config.fasttensors, keymap = self.model.config.arch.keymap) + stfile = STFile.open(filename, fast = cfg.fasttensors, keymap = cfg.arch.keymap) # tensor = stfile.get_tensor(key, device = self.device()).half() tensor = stfile.get_tensor(key, device = "cpu", cached = True, out_dtype = torch.half) - if self.model.config.arch.orig_weights_transposed and len(tensor.shape) == 2: + + if cfg.arch.orig_weights_transposed and len(tensor.shape) == 2: tensor = tensor.T + + if altpack_qkv: + ts = tensor.shape + h, gs, d = cfg.num_key_value_heads, cfg.num_key_value_groups + 2, cfg.head_dim + tensor = tensor.view(h, gs, d, -1).transpose(0, 1).reshape(ts) + tensor = tensor[f_beg:f_end] + + if altpack_qkv: + ts = tensor.shape + h, gs, d = cfg.num_key_value_heads, (f_end - f_beg) // cfg.num_key_value_heads // cfg.head_dim, cfg.head_dim + tensor = tensor.view(gs, h, d, -1).transpose(0, 1).reshape(ts) + if not key.endswith(".bias"): if in_feat != out_feat and \ tensor.shape[1] == out_feat and \ tensor.shape[0] == in_feat: tensor = tensor.T + tensor = tensor.contiguous().to(self.device()) res.append(nn.Parameter(tensor)) From f1179ff200ff54b6dca676c2b0f83b4a905dfab3 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Wed, 3 Jul 2024 22:37:51 +0200 Subject: [PATCH 09/40] Add ppl-over-seqlen test --- test_inference.py | 71 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 54 insertions(+), 17 deletions(-) diff --git a/test_inference.py b/test_inference.py index cd884d5b..376b3f98 100644 --- a/test_inference.py +++ b/test_inference.py @@ -50,6 +50,7 @@ parser.add_argument("-eq4", "--eval_token_q4", action = "store_true", help = "Evaluate perplexity on token-by-token inference using Q4 cache") parser.add_argument("-eq6", "--eval_token_q6", action = "store_true", help = "Evaluate perplexity on token-by-token inference using Q6 cache") parser.add_argument("-eq8", "--eval_token_q8", action = "store_true", help = "Evaluate perplexity on token-by-token inference using Q8 cache") +parser.add_argument("-ecl", "--eval_context_lens", action = "store_true", help = "Evaluate perplexity at range of context lengths") # parser.add_argument("-eb", "--eval_bos", action = "store_true", help = "Add BOS token to every row in perplexity test (required by Gemma and maybe other models.)") parser.add_argument("-p", "--prompt", type = str, help = "Generate from prompt (basic sampling settings)") parser.add_argument("-pnb", "--prompt_no_bos", action = "store_true", help = "Don't add BOS token to prompt") @@ -83,6 +84,9 @@ if args.gpu_split: print(" ## Can only use one GPU when streaming layers") sys.exit() + if args.eval_context_lens and args.stream_layers: + print(" ## eval_context_lens not compatible with stream_layers") + sys.exit() if args.eval_dataset: if args.length and args.eval_length != args.length: print(" !! 
Overriding model context length to match eval row length") @@ -279,13 +283,22 @@ boss = torch.full((eval_tokens.shape[0], 1), tokenizer.bos_token_id, dtype = torch.long) eval_tokens = torch.cat((boss, eval_tokens[:, :-1]), dim = 1) - logprob_sum = 0.0 - logprob_count = 0 + if args.eval_context_lens: + logprob_sum = [] + logprob_count = [] + else: + logprob_sum = 0.0 + logprob_count = 0 - def ppl(input_ids__, logits__, lengths__): + def ppl(input_ids__, logits__, lengths__, bins = False): - logprob_sum_ = 0.0 - logprob_count_ = 0 + if bins: + num_bins = (max(lengths__) + 255) // 256 + logprob_sum_ = [0.0] * num_bins + logprob_count_ = [0] * num_bins + else: + logprob_sum_ = 0.0 + logprob_count_ = 0 assert logits__.shape[0] == input_ids__.shape[0] ll = logits__.shape[1] @@ -295,7 +308,10 @@ def ppl(input_ids__, logits__, lengths__): logits_ = logits__[bi:bi+1, cl:, :] input_ids_ = input_ids__[bi:bi+1, cl:] - chunksize = logits_.shape[1] * 4000 // logits_.shape[2] + 1 + if bins: + chunksize = 256 + else: + chunksize = logits_.shape[1] * 4000 // logits_.shape[2] + 1 b_ = 0 while b_ < logits_.shape[1]: a_ = b_ @@ -306,8 +322,14 @@ def ppl(input_ids__, logits__, lengths__): log_probs = F.log_softmax(logits_f, dim=-1) token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1) - logprob_sum_ += token_log_probs.sum().item() - logprob_count_ += target_ids.numel() + if bins: + # for cbin in range(a_ // 256 + 1): + cbin = a_ // 256 + logprob_sum_[cbin] += token_log_probs.sum().item() + logprob_count_[cbin] += target_ids.numel() + else: + logprob_sum_ += token_log_probs.sum().item() + logprob_count_ += target_ids.numel() return logprob_sum_, logprob_count_ @@ -379,15 +401,30 @@ def ppl(input_ids__, logits__, lengths__): logits = model.forward(input_ids, cache) logits = logits[:, :-1, :] - logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[i:i+1]) - logprob_sum += logprob_sum__ - logprob_count += logprob_count__ - - print() - - mean_log_prob = logprob_sum / logprob_count - perplexity = math.exp(-mean_log_prob) - print(f" -- Evaluation perplexity: {perplexity:.4f}") + logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[i:i+1], args.eval_context_lens) + if args.eval_context_lens: + while len(logprob_sum) < len(logprob_sum__): + logprob_sum.append(0.0) + logprob_count.append(0) + for j in range(len(logprob_sum__)): + logprob_sum[j] += logprob_sum__[j] + logprob_count[j] += logprob_count__[j] + else: + logprob_sum += logprob_sum__ + logprob_count += logprob_count__ + + if not args.eval_context_lens: + print() + mean_log_prob = logprob_sum / logprob_count + perplexity = math.exp(-mean_log_prob) + print(f" -- Evaluation perplexity: {perplexity:.4f}") + else: + print() + for j in range(len(logprob_sum__)): + mean_log_prob = logprob_sum[j] / logprob_count[j] + perplexity = math.exp(-mean_log_prob) + dl = min((j + 1) * 256, eval_length) + print(f" -- Evaluation perplexity: {dl} {perplexity:.4f}") def test_ppl_token(): global logprob_sum, logprob_count, i, input_ids From a1aceaae203872777bbc42886333d8a0bd92a3ee Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Thu, 4 Jul 2024 05:19:24 +0200 Subject: [PATCH 10/40] Add final logit softcapping --- exllamav2/config.py | 3 + exllamav2/exllamav2_ext/cuda/softcap.cu | 78 ++++++++++++++++++++++++ exllamav2/exllamav2_ext/cuda/softcap.cuh | 24 ++++++++ exllamav2/exllamav2_ext/ext_bindings.cpp | 5 ++ exllamav2/exllamav2_ext/ext_element.cpp | 48 +++++++++++++++ 
exllamav2/exllamav2_ext/ext_element.h | 6 ++ exllamav2/ext.py | 2 + exllamav2/model.py | 4 ++ 8 files changed, 170 insertions(+) create mode 100644 exllamav2/exllamav2_ext/cuda/softcap.cu create mode 100644 exllamav2/exllamav2_ext/cuda/softcap.cuh create mode 100644 exllamav2/exllamav2_ext/ext_element.cpp create mode 100644 exllamav2/exllamav2_ext/ext_element.h diff --git a/exllamav2/config.py b/exllamav2/config.py index cc283e59..e2e182c5 100644 --- a/exllamav2/config.py +++ b/exllamav2/config.py @@ -100,6 +100,7 @@ class ExLlamaV2Config: scale_depth: float scale_emb: float use_qk_norm: bool + final_logit_softcapping: float | None checkpoint_fused_mlp: bool @@ -242,6 +243,8 @@ def prepare(self, no_tensors: bool = False): else: self.scale_depth = scale_depth / math.sqrt(self.num_hidden_layers) + self.final_logit_softcapping = read(read_config, float, "final_logit_softcapping", None) + # Positional embeddings self.rotary_embedding_base = read(read_config, float, ["rope_theta", "attn_config->rope_theta"], 10000.0) diff --git a/exllamav2/exllamav2_ext/cuda/softcap.cu b/exllamav2/exllamav2_ext/cuda/softcap.cu new file mode 100644 index 00000000..e2bd6aaa --- /dev/null +++ b/exllamav2/exllamav2_ext/cuda/softcap.cu @@ -0,0 +1,78 @@ +#include "softcap.cuh" +#include "util.cuh" +#include "../config.h" +#include "matrix_view.cuh" + +#define NUM_THREADS 256 + +__global__ void cuda_softcap_kernel +( + float* __restrict__ x, + const uint64_t numel, + const float scale +) +{ + uint64_t idx = (uint64_t)blockIdx.x * NUM_THREADS + (uint64_t)threadIdx.x; + if (idx >= numel) return; + + float v = x[idx]; + v /= scale; + v = tanhf(v); + v *= scale; + x[idx] = v; +} + +void softcap_cuda_ +( + float* x, + const uint64_t numel, + const float scale +) +{ + dim3 blockDim, gridDim; + blockDim.x = NUM_THREADS; + gridDim.x = DIVIDE(numel, NUM_THREADS); + + cuda_softcap_kernel<<>>(x, numel, scale); +} + +// TODO: Profile + +__global__ void h_cuda_softcap_kernel +( + half* __restrict__ x, + const uint64_t numel, + const float scale +) +{ + uint64_t idx = (uint64_t)blockIdx.x * NUM_THREADS + (uint64_t)threadIdx.x; + idx *= 2; + if (idx >= numel) return; + half2* x2 = (half2*)(x + idx); + half2 v01 = *x2; + float v0 = __low2float(v01); + float v1 = __high2float(v01); + v0 /= scale; + v1 /= scale; + v0 = tanhf(v0); + v1 = tanhf(v1); + v0 *= scale; + v1 *= scale; + v01 = __floats2half2_rn(v0, v1); + *x2 = v01; +} + +void h_softcap_cuda_ +( + half* x, + const uint64_t numel, + const float scale +) +{ + dim3 blockDim, gridDim; + blockDim.x = NUM_THREADS; + gridDim.x = DIVIDE(numel / 2, NUM_THREADS); + + h_cuda_softcap_kernel<<>>(x, numel, scale); +} + diff --git a/exllamav2/exllamav2_ext/cuda/softcap.cuh b/exllamav2/exllamav2_ext/cuda/softcap.cuh new file mode 100644 index 00000000..4ea2e6f8 --- /dev/null +++ b/exllamav2/exllamav2_ext/cuda/softcap.cuh @@ -0,0 +1,24 @@ +#ifndef _softcap_cuh +#define _softcap_cuh + +#include +#include +#include +#include +#include + +void softcap_cuda_ +( + float* x, + const uint64_t numel, + const float scale +); + +void h_softcap_cuda_ +( + half* x, + const uint64_t numel, + const float scale +); + +#endif \ No newline at end of file diff --git a/exllamav2/exllamav2_ext/ext_bindings.cpp b/exllamav2/exllamav2_ext/ext_bindings.cpp index 0b61aea3..0fbb5a43 100644 --- a/exllamav2/exllamav2_ext/ext_bindings.cpp +++ b/exllamav2/exllamav2_ext/ext_bindings.cpp @@ -19,6 +19,7 @@ #include "ext_gemm.h" #include "ext_norm.h" #include "ext_rope.h" +#include "ext_element.h" 
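// Illustrative note, assuming the kernels and bindings shown in this patch: the "softcap_"
// op registered below is called from Python as ext_c.softcap_(logits, config.final_logit_softcapping)
// and rewrites the tensor in place as x = scale * tanh(x / scale) (float and half paths in
// softcap.cu above), squashing the logits into the open interval (-scale, +scale).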
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -113,4 +114,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) // rope m.def("rope_", &rope_, "rope_"); + + // element + + m.def("softcap_", &softcap_, "softcap_"); } \ No newline at end of file diff --git a/exllamav2/exllamav2_ext/ext_element.cpp b/exllamav2/exllamav2_ext/ext_element.cpp new file mode 100644 index 00000000..968c88af --- /dev/null +++ b/exllamav2/exllamav2_ext/ext_element.cpp @@ -0,0 +1,48 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "config.h" +#include "cuda/softcap.cuh" +#include "cpp/util.h" + +// Apply softcapping inplace: x = scale * tanh(x/scale) + +void softcap_ +( + torch::Tensor x, + float scale +) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + uint64_t numel = x.numel(); + + if (x.dtype() == torch::kFloat) + { + softcap_cuda_ + ( + (float*) x.data_ptr(), + numel, + scale + ); + } + else if (x.dtype() == torch::kHalf) + { + h_softcap_cuda_ + ( + (half*) x.data_ptr(), + numel, + scale + ); + } + else + { + TORCH_CHECK(false, "softcap_ wrong dtype"); + } +} diff --git a/exllamav2/exllamav2_ext/ext_element.h b/exllamav2/exllamav2_ext/ext_element.h new file mode 100644 index 00000000..97e6c706 --- /dev/null +++ b/exllamav2/exllamav2_ext/ext_element.h @@ -0,0 +1,6 @@ + +void softcap_ +( + torch::Tensor x, + float scale +); diff --git a/exllamav2/ext.py b/exllamav2/ext.py index 2fc40401..53681510 100644 --- a/exllamav2/ext.py +++ b/exllamav2/ext.py @@ -213,6 +213,7 @@ def find_msvc(): "ext_rope.cpp", "ext_safetensors.cpp", "ext_sampling.cpp", + "ext_element.cpp", "cuda/h_add.cu", "cuda/h_gemm.cu", "cuda/lora.cu", @@ -228,6 +229,7 @@ def find_msvc(): "cuda/rope.cu", "cuda/cache.cu", "cuda/util.cu", + "cuda/softcap.cu", "cuda/comp_units/kernel_select.cu", "cuda/comp_units/unit_gptq_1.cu", "cuda/comp_units/unit_gptq_2.cu", diff --git a/exllamav2/model.py b/exllamav2/model.py index 4aec7e4e..a1f8575b 100644 --- a/exllamav2/model.py +++ b/exllamav2/model.py @@ -53,6 +53,7 @@ # from exllamav2.util import list_live_tensors, print_vram_usage, set_snapshot, diff_snapshot, print_vram_usage_peak from exllamav2.util import get_basic_progress # from line_profiler import profile +from exllamav2.ext import exllamav2_ext as ext_c, none_tensor def _torch_device(idx): @@ -917,6 +918,9 @@ def forward_chunk(self, # if x is not None and self.config.logit_scale != 1: # x.mul_(self.config.logit_scale) + if x is not None and self.config.final_logit_softcapping: + ext_c.softcap_(x, self.config.final_logit_softcapping) + # Set padding logits to -inf if x is not None: From d2b17db5deab9464fb417ae17fec49c6a043d1bd Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Thu, 4 Jul 2024 05:19:47 +0200 Subject: [PATCH 11/40] Read special tokens list from tokenizer_config.json --- exllamav2/tokenizer/tokenizer.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/exllamav2/tokenizer/tokenizer.py b/exllamav2/tokenizer/tokenizer.py index fbb458cc..1b9f4396 100644 --- a/exllamav2/tokenizer/tokenizer.py +++ b/exllamav2/tokenizer/tokenizer.py @@ -137,6 +137,17 @@ def __init__(self, config, lazy_init = False, force_json = False): with open(added_tokens_path, encoding = "utf8") as f: self.extended_piece_to_id.update(json.load(f)) + # Add special tokens from tokenizer_config.json + + if self.tokenizer_config_dict and "added_tokens_decoder" in self.tokenizer_config_dict: + atd = self.tokenizer_config_dict["added_tokens_decoder"] + for (k, 
v) in atd.items(): + if not v["special"]: + continue + token_id = int(k) + token_str = v["content"] + self.extended_piece_to_id[token_str] = token_id + # Remove unspecial added tokens that exist in the base tokenizer already, but only if they decode correctly # see https://github.com/huggingface/tokenizers/issues/1392 From c8e5cedfb36b8f670270ba264ca8422c16d0557a Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Thu, 4 Jul 2024 05:20:31 +0200 Subject: [PATCH 12/40] Example Gemma template --- examples/util.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/util.py b/examples/util.py index 8ddc6b7f..45c8b1ee 100644 --- a/examples/util.py +++ b/examples/util.py @@ -29,6 +29,12 @@ def format_prompt(prompt_format, sp, p): f"{p}<|im_end|>\n" f"<|im_start|>assistant\n" ) + elif prompt_format == "gemma": + return ( + f"user\n" + f"{p}\n" + f"model\n" + ) def get_stop_conditions(prompt_format, tokenizer): if prompt_format == "llama": @@ -37,7 +43,8 @@ def get_stop_conditions(prompt_format, tokenizer): return [tokenizer.single_id("<|eot_id|>")] elif prompt_format == "granite": return [tokenizer.eos_token_id, "\n\nQuestion:"] - + elif prompt_format == "gemma": + return [tokenizer.eos_token_id, ""] # Cached dataset loader From c1a572bf891edb3d916f1f03fefa243c3d4fd00f Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Thu, 4 Jul 2024 05:24:28 +0200 Subject: [PATCH 13/40] Support pre and post layernorms --- exllamav2/architecture.py | 3 + exllamav2/attn.py | 69 +++++++++++++------- exllamav2/conversion/compile.py | 12 +++- exllamav2/exllamav2_ext/cuda/layer_norm.cu | 16 +++-- exllamav2/exllamav2_ext/cuda/layer_norm.cuh | 3 +- exllamav2/exllamav2_ext/cuda/q_attn.cu | 21 ++++++- exllamav2/exllamav2_ext/cuda/q_attn.cuh | 8 ++- exllamav2/exllamav2_ext/cuda/q_mlp.cu | 26 ++++++-- exllamav2/exllamav2_ext/cuda/q_mlp.cuh | 10 ++- exllamav2/exllamav2_ext/cuda/rms_norm.cu | 16 +++-- exllamav2/exllamav2_ext/cuda/rms_norm.cuh | 3 +- exllamav2/exllamav2_ext/ext_qattn.cpp | 9 ++- exllamav2/exllamav2_ext/ext_qattn.h | 4 +- exllamav2/exllamav2_ext/ext_qmlp.cpp | 9 ++- exllamav2/exllamav2_ext/ext_qmlp.h | 4 +- exllamav2/mlp.py | 70 ++++++++++++++------- 16 files changed, 208 insertions(+), 75 deletions(-) diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py index 350849f3..694eba2f 100644 --- a/exllamav2/architecture.py +++ b/exllamav2/architecture.py @@ -113,6 +113,9 @@ def __init__(self, arch_string, read_config): self.orig_weights_transposed = False self.logit_scale_basedim = False + self.norm_key_1_post = None + self.norm_key_2_post = None + self.fused_qkv_altpack = False # Mistral diff --git a/exllamav2/attn.py b/exllamav2/attn.py index 609b8751..038a278d 100644 --- a/exllamav2/attn.py +++ b/exllamav2/attn.py @@ -75,7 +75,8 @@ class ExLlamaV2Attention(ExLlamaV2Module): name: str = "Attention" layer_idx: int - input_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None + pre_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None + post_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None q_proj: ExLlamaV2Linear | None k_proj: ExLlamaV2Linear | None v_proj: ExLlamaV2Linear | None @@ -291,11 +292,14 @@ def __init__(self, if self.has_norm: if cfg.arch.norm == "layernorm": - self.input_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_1) + self.pre_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_1) + self.post_layernorm = ExLlamaV2LayerNorm(model, key + 
cfg.arch.norm_key_1_post) if cfg.arch.norm_key_1_post else None elif cfg.arch.norm == "rmsnorm": - self.input_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_1) + self.pre_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_1) + self.post_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_1_post) if cfg.arch.norm_key_1_post else None else: - self.input_layernorm = None + self.pre_layernorm = None + self.post_layernorm = None f_a = 0 f_b = cfg.num_attention_heads * cfg.head_dim @@ -319,11 +323,12 @@ def __init__(self, self.k_proj, self.v_proj, self.o_proj] - if self.has_norm: - self.submodules += [self.input_layernorm] + if self.pre_layernorm: + self.submodules += [self.pre_layernorm] + if self.post_layernorm: + self.submodules += [self.post_layernorm] if cfg.use_qk_norm: - self.submodules += [self.q_norm, - self.k_norm] + self.submodules += [self.q_norm, self.k_norm] # if cfg.arch.scale_attn_weights: # self.unscale_factor = self.layer_idx + 1 @@ -340,7 +345,8 @@ def numel(self) -> int: self.v_proj.numel() + \ self.o_proj.numel() - if self.input_layernorm is not None: numel += self.input_layernorm.numel() + if self.pre_layernorm is not None: numel += self.pre_layernorm.numel() + if self.post_layernorm is not None: numel += self.post_layernorm.numel() if self.q_norm is not None: numel += self.q_norm.numel() if self.k_norm is not None: numel += self.k_norm.numel() @@ -352,7 +358,8 @@ def load(self): cfg = self.model.config - if self.input_layernorm is not None: self.input_layernorm.load() + if self.pre_layernorm is not None: self.pre_layernorm.load() + if self.post_layernorm is not None: self.post_layernorm.load() self.q_proj.load() self.k_proj.load() self.v_proj.load() @@ -374,16 +381,23 @@ def load(self): # self.temp_kv = device_tensors.get_scratch_slice(self.temp_kv_size()) if cfg.num_attention_heads != cfg.num_key_value_heads else None if self.has_norm: - norm_weight = self.input_layernorm.weight if self.input_layernorm.weight is not None else none_tensor - norm_bias = self.input_layernorm.bias if self.input_layernorm.bias is not None else none_tensor - is_rms = isinstance(self.input_layernorm, ExLlamaV2RMSNorm) - eps = self.input_layernorm.variance_epsilon + norm_weight = self.pre_layernorm.weight if self.pre_layernorm.weight is not None else none_tensor + norm_bias = self.pre_layernorm.bias if self.pre_layernorm.bias is not None else none_tensor + is_rms = isinstance(self.pre_layernorm, ExLlamaV2RMSNorm) + eps = self.pre_layernorm.variance_epsilon else: norm_weight = none_tensor norm_bias = none_tensor is_rms = False eps = 0 + if self.post_layernorm is not None: + post_norm_weight = self.post_layernorm.weight if self.post_layernorm.weight is not None else none_tensor + post_norm_bias = self.post_layernorm.bias if self.post_layernorm.bias is not None else none_tensor + else: + post_norm_weight = none_tensor + post_norm_bias = none_tensor + if self.q_norm is None: q_norm = none_tensor else: @@ -417,7 +431,9 @@ def load(self): self.has_residual, cfg.arch.rope_style.value, q_norm, - k_norm + k_norm, + post_norm_weight, + post_norm_bias, ) @@ -426,7 +442,8 @@ def unload(self): ext_c.free_q_attn(self.q_handle) self.q_handle = None - if self.input_layernorm is not None: self.input_layernorm.unload() + if self.pre_layernorm is not None: self.pre_layernorm.unload() + if self.post_layernorm is not None: self.post_layernorm.unload() if self.q_proj is not None: self.q_proj.unload() if self.k_proj is not None: self.k_proj.unload() if self.v_proj is not None: 
self.v_proj.unload() @@ -445,8 +462,10 @@ def weight_footprint(self): self.k_proj.weight_footprint() + \ self.v_proj.weight_footprint() + \ self.o_proj.weight_footprint() - if self.input_layernorm is not None: - fp += self.input_layernorm.weight_footprint() + if self.pre_layernorm is not None: + fp += self.pre_layernorm.weight_footprint() + if self.post_layernorm is not None: + fp += self.post_layernorm.weight_footprint() if self.q_norm is not None: fp += self.q_norm.weight_footprint() if self.k_norm is not None: @@ -530,7 +549,8 @@ def temp_attn_size(self): def set_device_idx(self, idx): super().set_device_idx(idx) - if self.input_layernorm is not None: self.input_layernorm.set_device_idx(idx) + if self.pre_layernorm is not None: self.pre_layernorm.set_device_idx(idx) + if self.post_layernorm is not None: self.post_layernorm.set_device_idx(idx) self.q_proj.set_device_idx(idx) self.k_proj.set_device_idx(idx) self.v_proj.set_device_idx(idx) @@ -614,7 +634,7 @@ def forward_paged(self, ) else: residual = hidden_states - hidden_states = self.input_layernorm.forward(hidden_states) if self.has_norm else hidden_states + hidden_states = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states q = self.q_proj.forward(hidden_states, loras = loras) k = self.k_proj.forward(hidden_states, loras = loras) v = self.v_proj.forward(hidden_states, loras = loras) @@ -694,6 +714,8 @@ def forward_paged(self, ) else: hidden_states = self.o_proj.forward(attn_output, loras = loras) + if self.post_layernorm: + hidden_states = self.post_layernorm.forward(hidden_states) if self.has_residual: hidden_states += residual @@ -942,7 +964,7 @@ def forward_torch(self, # Project q, k, v residual = hidden_states - post_norm = self.input_layernorm.forward(hidden_states) if self.has_norm else hidden_states + post_norm = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states query_states = self.q_proj.forward(post_norm, loras = loras) key_states = self.k_proj.forward(post_norm, loras = loras) @@ -1013,6 +1035,11 @@ def forward_torch(self, attn_proj = self.o_proj.forward(attn_output, loras = loras) + # Post layernorm + + if self.post_layernorm: + attn_proj = self.post_layernorm.forward(attn_proj) + # Add residual connection hidden_states = (attn_proj + residual) if self.has_residual else attn_proj diff --git a/exllamav2/conversion/compile.py b/exllamav2/conversion/compile.py index 22733965..4a9320fe 100644 --- a/exllamav2/conversion/compile.py +++ b/exllamav2/conversion/compile.py @@ -32,6 +32,8 @@ def _dsize(d): def get_f_module(job, module): + if module is None: return None + mod_dict = {} module.load() w = module.get_weight() @@ -77,7 +79,10 @@ def compile_model(job, save_fn, model): if isinstance(module, ExLlamaV2Attention): - d = get_f_module(job, module.input_layernorm); out_dict.update(d); current_size += _dsize(d) + d = get_f_module(job, module.pre_layernorm) + if d: out_dict.update(d); current_size += _dsize(d) + d = get_f_module(job, module.post_layernorm) + if d: out_dict.update(d); current_size += _dsize(d) d = get_q_module(job, module.q_proj); out_dict.update(d); current_size += _dsize(d) d = get_q_module(job, module.k_proj); out_dict.update(d); current_size += _dsize(d) d = get_q_module(job, module.v_proj); out_dict.update(d); current_size += _dsize(d) @@ -86,7 +91,10 @@ def compile_model(job, save_fn, model): if isinstance(module, ExLlamaV2MLP): has_gate = model.config.arch.mlp_gate - d = get_f_module(job, module.post_attention_layernorm); out_dict.update(d); current_size 
+= _dsize(d) + d = get_f_module(job, module.pre_layernorm) + if d: out_dict.update(d); current_size += _dsize(d) + d = get_f_module(job, module.post_layernorm) + if d: out_dict.update(d); current_size += _dsize(d) if has_gate: d = get_q_module(job, module.gate_proj); out_dict.update(d); current_size += _dsize(d) d = get_q_module(job, module.up_proj); out_dict.update(d); current_size += _dsize(d) d = get_q_module(job, module.down_proj); out_dict.update(d); current_size += _dsize(d) diff --git a/exllamav2/exllamav2_ext/cuda/layer_norm.cu b/exllamav2/exllamav2_ext/cuda/layer_norm.cu index 9d112338..e65b21fc 100644 --- a/exllamav2/exllamav2_ext/cuda/layer_norm.cu +++ b/exllamav2/exllamav2_ext/cuda/layer_norm.cu @@ -24,7 +24,8 @@ typedef void (*fp_layer_norm_kernel) const float, const float, const int, - const int + const int, + const bool ); template @@ -37,7 +38,8 @@ __global__ void layer_norm_kernel const float epsilon, const float r_dim, const int rows, - const int dim + const int dim, + const bool add_residual ) { int warp_id = threadIdx.x / WARP_SIZE; @@ -149,7 +151,10 @@ __global__ void layer_norm_kernel half2 nh = __halves2half2(__float2half_rn(n0), __float2half_rn(n1)); if (b) nh = __hadd2(nh, b2[column]); // Optional bias - y_row[column] = nh; + if (add_residual) + y_row[column] = __hadd2(nh, y_row[column]); + else + y_row[column] = nh; } } @@ -185,7 +190,8 @@ void layer_norm_cuda half* y, const float epsilon, const int rows, - const int dim + const int dim, + const bool add_residual ) { dim3 blockDim, gridDim; @@ -198,5 +204,5 @@ void layer_norm_cuda int blocks_per_warp = DIVIDE(dim, NUM_THREADS * 2); fp_layer_norm_kernel kernel = pick_layer_norm_kernel(blocks_per_warp); - kernel<<>>(x, w, b, y, epsilon, r_dim, rows, dim); + kernel<<>>(x, w, b, y, epsilon, r_dim, rows, dim, add_residual); } diff --git a/exllamav2/exllamav2_ext/cuda/layer_norm.cuh b/exllamav2/exllamav2_ext/cuda/layer_norm.cuh index 2b10d4d6..f780794c 100644 --- a/exllamav2/exllamav2_ext/cuda/layer_norm.cuh +++ b/exllamav2/exllamav2_ext/cuda/layer_norm.cuh @@ -14,7 +14,8 @@ void layer_norm_cuda half* y, const float epsilon, const int rows, - const int dim + const int dim, + const bool add_residual = false ); #endif \ No newline at end of file diff --git a/exllamav2/exllamav2_ext/cuda/q_attn.cu b/exllamav2/exllamav2_ext/cuda/q_attn.cu index bb8b2d4d..d536ce04 100644 --- a/exllamav2/exllamav2_ext/cuda/q_attn.cu +++ b/exllamav2/exllamav2_ext/cuda/q_attn.cu @@ -93,7 +93,9 @@ QAttn::QAttn bool _has_residual, int _rope_style, half* _q_norm, - half* _k_norm + half* _k_norm, + half* _post_layernorm, + half* _post_layernorm_bias ): layernorm(_layernorm), layernorm_bias(_layernorm_bias), @@ -117,7 +119,9 @@ QAttn::QAttn has_residual(_has_residual), rope_style(_rope_style), q_norm(_q_norm), - k_norm(_k_norm) + k_norm(_k_norm), + post_layernorm(_post_layernorm), + post_layernorm_bias(_post_layernorm_bias) { } @@ -202,7 +206,18 @@ void QAttn::forward_cuda_2 half* lora_temp ) { - gemm_half_q_half_cuda(cublas_handle, attn_output, o_proj, hidden_state, q_len * batch_size, o_proj->width, o_proj->height, !has_residual, temp_dq); + if (!post_layernorm) + { + gemm_half_q_half_cuda(cublas_handle, attn_output, o_proj, hidden_state, q_len * batch_size, o_proj->width, o_proj->height, !has_residual, temp_dq); + } + else + { + gemm_half_q_half_cuda(cublas_handle, attn_output, o_proj, temp_state, q_len * batch_size, o_proj->width, o_proj->height, true, temp_dq); + if (layernorm_is_rms) + rms_norm_cuda(temp_state, post_layernorm, hidden_state, 
norm_epsilon, q_len * batch_size, hidden_size, true); + else + layer_norm_cuda(temp_state, post_layernorm, post_layernorm_bias, hidden_state, norm_epsilon, q_len * batch_size, hidden_size, true); + } apply_loras_cuda(cublas_handle, o_proj_lora, loras, o_proj, attn_output, hidden_state, lora_temp, q_len * batch_size); } diff --git a/exllamav2/exllamav2_ext/cuda/q_attn.cuh b/exllamav2/exllamav2_ext/cuda/q_attn.cuh index 16a47bd5..8a856b76 100644 --- a/exllamav2/exllamav2_ext/cuda/q_attn.cuh +++ b/exllamav2/exllamav2_ext/cuda/q_attn.cuh @@ -19,6 +19,8 @@ public: half* layernorm; half* layernorm_bias; + half* post_layernorm; + half* post_layernorm_bias; bool layernorm_is_rms; float norm_epsilon; @@ -55,7 +57,7 @@ public: QAttn ( half* _layernorm, - half* _layermorm_bias, + half* _layernorm_bias, bool _layernorm_is_rms, float _norm_epsilon, QMatrix* _q_proj, @@ -76,7 +78,9 @@ public: bool _has_residual, int _rope_style, half* _q_norm, - half* _k_norm + half* _k_norm, + half* _post_layernorm, + half* _post_layernorm_bias ); ~QAttn(); diff --git a/exllamav2/exllamav2_ext/cuda/q_mlp.cu b/exllamav2/exllamav2_ext/cuda/q_mlp.cu index 97c263c9..ba8368b7 100644 --- a/exllamav2/exllamav2_ext/cuda/q_mlp.cu +++ b/exllamav2/exllamav2_ext/cuda/q_mlp.cu @@ -32,7 +32,9 @@ QMLP::QMLP half* _temp_dq, int _max_rows, bool _act_gelu, - bool _has_residual + bool _has_residual, + half* _post_layernorm, + half* _post_layernorm_bias ): layernorm(_layernorm), layernorm_bias(_layernorm_bias), @@ -47,7 +49,9 @@ QMLP::QMLP temp_dq(_temp_dq), max_rows(_max_rows), act_gelu(_act_gelu), - has_residual(_has_residual) + has_residual(_has_residual), + post_layernorm(_post_layernorm), + post_layernorm_bias(_post_layernorm_bias) { } @@ -114,9 +118,23 @@ void QMLP::forward_ kernel<<>>(temp_a, rows, intermediate_size, NULL, 0); } - // Down proj + // Down proj without post_layernorm - gemm_half_q_half_cuda(cublas_handle, temp_a, down, x, rows, columns, intermediate_size, !has_residual, temp_dq); + if (!post_layernorm) + { + gemm_half_q_half_cuda(cublas_handle, temp_a, down, x, rows, columns, intermediate_size, !has_residual, temp_dq); + } + + // Down proj with post_layernorm + + else + { + gemm_half_q_half_cuda(cublas_handle, temp_a, down, temp_state, rows, columns, intermediate_size, true, temp_dq); + if (layernorm_is_rms) + rms_norm_cuda(temp_state, post_layernorm, x, norm_epsilon, rows, columns, true); + else + layer_norm_cuda(temp_state, post_layernorm, post_layernorm_bias, x, norm_epsilon, rows, columns, true); + } apply_loras_cuda(cublas_handle, down_proj_lora, loras, down, temp_a, x, lora_temp, rows); } diff --git a/exllamav2/exllamav2_ext/cuda/q_mlp.cuh b/exllamav2/exllamav2_ext/cuda/q_mlp.cuh index 179965b2..2d8b0907 100644 --- a/exllamav2/exllamav2_ext/cuda/q_mlp.cuh +++ b/exllamav2/exllamav2_ext/cuda/q_mlp.cuh @@ -15,6 +15,8 @@ public: half* layernorm; half* layernorm_bias; + half* post_layernorm; + half* post_layernorm_bias; bool layernorm_is_rms; float norm_epsilon; @@ -40,7 +42,7 @@ public: QMLP ( half* _layernorm, - half* _layermorm_bias, + half* _layernorm_bias, bool _layernorm_is_rms, float _norm_epsilon, QMatrix* _gate, @@ -52,7 +54,9 @@ public: half* _temp_dq, int _max_rows, bool _act_gelu, - bool _has_residual + bool _has_residual, + half* _post_layernorm, + half* _post_layernorm_bias ); ~QMLP(); @@ -108,7 +112,7 @@ public: QMoEMLP ( half* _layernorm, - half* _layermorm_bias, + half* _layernorm_bias, bool _layernorm_is_rms, float _norm_epsilon, half* _gate, diff --git a/exllamav2/exllamav2_ext/cuda/rms_norm.cu 
b/exllamav2/exllamav2_ext/cuda/rms_norm.cu index eb0d5711..ae8697c3 100644 --- a/exllamav2/exllamav2_ext/cuda/rms_norm.cu +++ b/exllamav2/exllamav2_ext/cuda/rms_norm.cu @@ -23,7 +23,8 @@ typedef void (*fp_rms_norm_kernel) const float, const float, const int, - const int + const int, + const bool ); template @@ -35,7 +36,8 @@ __global__ void rms_norm_kernel const float epsilon, const float r_dim, const int rows, - const int dim + const int dim, + const bool add_residual ) { int warp_id = threadIdx.x / WARP_SIZE; @@ -103,7 +105,10 @@ __global__ void rms_norm_kernel float w_itemf1 = __half2float(__high2half(w2_)); float n0 = x_itemf0 * w_itemf0 * rmf; float n1 = x_itemf1 * w_itemf1 * rmf; - y_row[column] = __halves2half2(__float2half_rn(n0), __float2half_rn(n1)); + if (add_residual) + y_row[column] = __hadd2(y_row[column], __halves2half2(__float2half_rn(n0), __float2half_rn(n1))); + else + y_row[column] = __halves2half2(__float2half_rn(n0), __float2half_rn(n1)); } } @@ -138,7 +143,8 @@ void rms_norm_cuda half* y, const float epsilon, const int rows, - const int dim + const int dim, + const bool add_residual ) { dim3 blockDim, gridDim; @@ -151,5 +157,5 @@ void rms_norm_cuda int blocks_per_warp = DIVIDE(dim, NUM_THREADS * 2); fp_rms_norm_kernel kernel = pick_rms_norm_kernel(blocks_per_warp); - kernel<<>>(x, w, y, epsilon, r_dim, rows, dim); + kernel<<>>(x, w, y, epsilon, r_dim, rows, dim, add_residual); } diff --git a/exllamav2/exllamav2_ext/cuda/rms_norm.cuh b/exllamav2/exllamav2_ext/cuda/rms_norm.cuh index 4cb0fea9..786b1e55 100644 --- a/exllamav2/exllamav2_ext/cuda/rms_norm.cuh +++ b/exllamav2/exllamav2_ext/cuda/rms_norm.cuh @@ -13,7 +13,8 @@ void rms_norm_cuda half* y, const float epsilon, const int rows, - const int dim + const int dim, + const bool add_residual = false ); #endif \ No newline at end of file diff --git a/exllamav2/exllamav2_ext/ext_qattn.cpp b/exllamav2/exllamav2_ext/ext_qattn.cpp index c4452489..fda3b43d 100644 --- a/exllamav2/exllamav2_ext/ext_qattn.cpp +++ b/exllamav2/exllamav2_ext/ext_qattn.cpp @@ -39,7 +39,9 @@ uintptr_t make_q_attn bool has_residual, int rope_style, torch::Tensor q_norm, - torch::Tensor k_norm + torch::Tensor k_norm, + torch::Tensor post_layernorm, + torch::Tensor post_layernorm_bias ) { QMatrix* qm_q_proj = reinterpret_cast (q_q_proj); @@ -48,6 +50,7 @@ uintptr_t make_q_attn QMatrix* qm_o_proj = reinterpret_cast (q_o_proj); TORCH_CHECK_DTYPE_OPT(layernorm, kHalf); + TORCH_CHECK_DTYPE_OPT(post_layernorm, kHalf); if (qm_q_proj && !layernorm.is_meta()) TORCH_CHECK(qm_q_proj->height == layernorm.size(0), "q_proj is wrong shape") if (qm_k_proj && !layernorm.is_meta()) TORCH_CHECK(qm_k_proj->height == layernorm.size(0), "k_proj is wrong shape") @@ -78,7 +81,9 @@ uintptr_t make_q_attn has_residual, rope_style, (half*) q_norm.is_meta() ? NULL : (half*) q_norm.data_ptr(), - (half*) k_norm.is_meta() ? NULL : (half*) k_norm.data_ptr() + (half*) k_norm.is_meta() ? NULL : (half*) k_norm.data_ptr(), + (half*) post_layernorm.is_meta() ? NULL : (half*) post_layernorm.data_ptr(), + (half*) post_layernorm_bias.is_meta() ? 
NULL : (half*) post_layernorm_bias.data_ptr() ); return reinterpret_cast (attn); diff --git a/exllamav2/exllamav2_ext/ext_qattn.h b/exllamav2/exllamav2_ext/ext_qattn.h index dfb300b6..b8c9dc18 100644 --- a/exllamav2/exllamav2_ext/ext_qattn.h +++ b/exllamav2/exllamav2_ext/ext_qattn.h @@ -23,7 +23,9 @@ uintptr_t make_q_attn bool has_residual, int rope_style, torch::Tensor q_norm, - torch::Tensor k_norm + torch::Tensor k_norm, + torch::Tensor post_layernorm, + torch::Tensor post_layernorm_bias ); void free_q_attn diff --git a/exllamav2/exllamav2_ext/ext_qmlp.cpp b/exllamav2/exllamav2_ext/ext_qmlp.cpp index 827ac789..ad62389c 100644 --- a/exllamav2/exllamav2_ext/ext_qmlp.cpp +++ b/exllamav2/exllamav2_ext/ext_qmlp.cpp @@ -30,7 +30,9 @@ uintptr_t make_q_mlp torch::Tensor temp_dq, int max_rows, bool act_gelu, - bool has_residual + bool has_residual, + torch::Tensor post_layernorm, + torch::Tensor post_layernorm_bias ) { QMatrix* qm_gate = reinterpret_cast (q_gate); @@ -38,6 +40,7 @@ uintptr_t make_q_mlp QMatrix* qm_down = reinterpret_cast (q_down); TORCH_CHECK_DTYPE_OPT(layernorm, kHalf); + TORCH_CHECK_DTYPE_OPT(post_layernorm, kHalf); if (qm_gate && !layernorm.is_meta()) TORCH_CHECK(qm_gate->height == layernorm.size(0), "gate_proj is wrong shape") if (!layernorm.is_meta()) TORCH_CHECK(qm_up->height == layernorm.size(0), "up_proj is wrong shape") @@ -56,7 +59,9 @@ uintptr_t make_q_mlp (half*) temp_dq.data_ptr(), max_rows, act_gelu, - has_residual + has_residual, + (half*) post_layernorm.is_meta() ? NULL : (half*) post_layernorm.data_ptr(), + (half*) post_layernorm_bias.is_meta() ? NULL : (half*) post_layernorm_bias.data_ptr() ); return reinterpret_cast (mlp); diff --git a/exllamav2/exllamav2_ext/ext_qmlp.h b/exllamav2/exllamav2_ext/ext_qmlp.h index 45e7ed1a..e76bd227 100644 --- a/exllamav2/exllamav2_ext/ext_qmlp.h +++ b/exllamav2/exllamav2_ext/ext_qmlp.h @@ -14,7 +14,9 @@ uintptr_t make_q_mlp torch::Tensor temp_dq, int max_rows, bool act_gelu, - bool has_residual + bool has_residual, + torch::Tensor post_layernorm, + torch::Tensor post_layernorm_bias ); void free_q_mlp diff --git a/exllamav2/mlp.py b/exllamav2/mlp.py index 63e4bc2b..18d39409 100644 --- a/exllamav2/mlp.py +++ b/exllamav2/mlp.py @@ -20,7 +20,8 @@ class ExLlamaV2MLP(ExLlamaV2Module): name: str = "MLP" layer_idx: int - post_attention_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None + pre_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None + post_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None gate_proj: ExLlamaV2Linear | None up_proj: ExLlamaV2Linear | None down_proj: ExLlamaV2Linear | None @@ -56,19 +57,24 @@ def __init__(self, if self.has_norm: if cfg.arch.norm == "layernorm": - self.post_attention_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_2) + self.pre_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_2) + self.post_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_2_post) if cfg.arch.norm_key_2_post else None elif cfg.arch.norm == "rmsnorm": - self.post_attention_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_2) + self.pre_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_2) + self.post_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_2_post) if cfg.arch.norm_key_2_post else None else: - self.post_attention_layernorm = None + self.pre_layernorm = None + self.post_layernorm = None self.up_proj = ExLlamaV2Linear(model, key + cfg.arch.mlp_key_up, cfg.hidden_size, cfg.intermediate_size, self.model.config.arch.mlp_bias, f_key = f_key, 
f_beg = f_b, f_end = f_c) self.down_proj = ExLlamaV2Linear(model, key + cfg.arch.mlp_key_down, cfg.intermediate_size, cfg.hidden_size, self.model.config.arch.mlp_bias, prescale = cfg.scale_depth) self.submodules = [self.up_proj, self.down_proj] - if self.has_norm: - self.submodules += [self.post_attention_layernorm] + if self.pre_layernorm: + self.submodules += [self.pre_layernorm] + if self.post_layernorm: + self.submodules += [self.post_layernorm] if cfg.arch.mlp_gate: self.gate_proj = ExLlamaV2Linear(model, key + cfg.arch.mlp_key_gate, cfg.hidden_size, cfg.intermediate_size, self.model.config.arch.mlp_bias, f_key = f_key, f_beg = f_a, f_end = f_b) @@ -85,8 +91,10 @@ def numel(self) -> int: if self.model.config.arch.mlp_gate: numel += self.gate_proj.numel() - if self.post_attention_layernorm is not None: - numel += self.post_attention_layernorm.numel() + if self.pre_layernorm is not None: + numel += self.pre_layernorm.numel() + if self.post_layernorm is not None: + numel += self.pre_layernorm.numel() return numel @@ -96,8 +104,10 @@ def load(self): cfg = self.model.config - if self.post_attention_layernorm is not None: - self.post_attention_layernorm.load() + if self.pre_layernorm is not None: + self.pre_layernorm.load() + if self.post_layernorm is not None: + self.post_layernorm.load() if cfg.checkpoint_fused_mlp: w12 = self.load_weight(self.key + cfg.arch.fused_mlp_key_12) @@ -119,16 +129,23 @@ def load(self): device_tensors.begin_scratch_alloc() if self.has_norm: - norm_weight = self.post_attention_layernorm.weight if self.post_attention_layernorm.weight is not None else none_tensor - norm_bias = self.post_attention_layernorm.bias if self.post_attention_layernorm.bias is not None else none_tensor - is_rms = isinstance(self.post_attention_layernorm, ExLlamaV2RMSNorm) - eps = self.post_attention_layernorm.variance_epsilon + norm_weight = self.pre_layernorm.weight if self.pre_layernorm.weight is not None else none_tensor + norm_bias = self.pre_layernorm.bias if self.pre_layernorm.bias is not None else none_tensor + is_rms = isinstance(self.pre_layernorm, ExLlamaV2RMSNorm) + eps = self.pre_layernorm.variance_epsilon else: norm_weight = none_tensor norm_bias = none_tensor is_rms = False eps = 0 + if self.post_layernorm is not None: + post_norm_weight = self.post_layernorm.weight if self.post_layernorm.weight is not None else none_tensor + post_norm_bias = self.post_layernorm.bias if self.post_layernorm.bias is not None else none_tensor + else: + post_norm_weight = none_tensor + post_norm_bias = none_tensor + self.q_handle = ext_c.make_q_mlp(norm_weight, norm_bias, is_rms, @@ -142,7 +159,9 @@ def load(self): device_tensors.get_scratch_slice(self.temp_dq_size()), cfg.max_input_len * cfg.max_batch_size, cfg.arch.mlp_act_func == "gelu", - self.has_residual) + self.has_residual, + post_norm_weight, + post_norm_bias) def unload(self): @@ -151,7 +170,8 @@ def unload(self): ext_c.free_q_mlp(self.q_handle) self.q_handle = None - if self.post_attention_layernorm is not None: self.post_attention_layernorm.unload() + if self.pre_layernorm is not None: self.pre_layernorm.unload() + if self.post_layernorm is not None: self.post_layernorm.unload() if self.gate_proj is not None: self.gate_proj.unload() self.up_proj.unload() self.down_proj.unload() @@ -167,8 +187,10 @@ def weight_footprint(self) -> int: if self.gate_proj is not None: fp += self.gate_proj.weight_footprint() - if self.post_attention_layernorm is not None: - fp += self.post_attention_layernorm.weight_footprint() + if self.pre_layernorm is not 
None: + fp += self.pre_layernorm.weight_footprint() + if self.post_layernorm is not None: + fp += self.post_layernorm.weight_footprint() return fp @@ -219,8 +241,10 @@ def temp_dq_size(self) -> int: def set_device_idx(self, idx: int): super().set_device_idx(idx) - if self.post_attention_layernorm is not None: - self.post_attention_layernorm.set_device_idx(idx) + if self.pre_layernorm is not None: + self.pre_layernorm.set_device_idx(idx) + if self.post_layernorm is not None: + self.post_layernorm.set_device_idx(idx) if self.gate_proj is not None: self.gate_proj.set_device_idx(idx) self.up_proj.set_device_idx(idx) self.down_proj.set_device_idx(idx) @@ -266,8 +290,8 @@ def forward_torch(self, cfg = self.model.config residual = hidden_states - post_norm = self.post_attention_layernorm.forward(hidden_states) \ - if self.has_norm else hidden_states + post_norm = self.pre_layernorm.forward(hidden_states) \ + if self.pre_layernorm else hidden_states if self.gate_proj is not None: gate = self.gate_proj.forward(post_norm, loras = loras) @@ -286,6 +310,8 @@ def forward_torch(self, y = F.gelu(up) down = self.down_proj.forward(y, loras = loras) + if self.post_layernorm: + down = self.post_layernorm.forward(down) hidden_states = down + residual if self.has_residual else down if intermediates: From 84d00cbbc082beaf7afaf8830358f36c354c1d50 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Thu, 4 Jul 2024 05:29:02 +0200 Subject: [PATCH 14/40] Support pre_attn_scalar config entry --- exllamav2/attn.py | 50 ++++++++++++++++++++++++--------------------- exllamav2/config.py | 3 +++ 2 files changed, 30 insertions(+), 23 deletions(-) diff --git a/exllamav2/attn.py b/exllamav2/attn.py index 038a278d..e4c0f90c 100644 --- a/exllamav2/attn.py +++ b/exllamav2/attn.py @@ -98,6 +98,8 @@ class ExLlamaV2Attention(ExLlamaV2Module): has_norm: bool has_residual: bool + scaling: float + sliding_window: int class Params: @@ -165,9 +167,9 @@ def get_past_lens(self, device) -> torch.Tensor | None: self.past_lens_tensor = safe_move_tensor(self.past_lens_tensor, device) return self.past_lens_tensor - def get_attn_mask(self, device) -> torch.Tensor | None: + def get_attn_mask(self, device, force: bool = False) -> torch.Tensor | None: if self.attn_mask is None: - self.attn_mask = self.build_attn_mask(device) + self.attn_mask = self.build_attn_mask(device, force) elif self.attn_mask.device != device: self.attn_mask = safe_move_tensor(self.attn_mask, device) return self.attn_mask @@ -190,9 +192,9 @@ def build_single_attn_mask(self, batch_size, seq_len, past_len, device, input_ma attn_mask[:, :, :, :min_mask_width] = torch.minimum(attn_mask[:, :, :, :min_mask_width], input_mask_part) return attn_mask - def build_attn_mask(self, device) -> torch.Tensor | None: + def build_attn_mask(self, device, force: bool = False) -> torch.Tensor | None: assert not self.multi_cache, "Building single mask for multiple caches" - if self.input_mask is None and self.seq_len == 1: return None + if self.input_mask is None and self.seq_len == 1 and not force: return None return self.build_single_attn_mask(self.batch_size, self.seq_len, self.past_len, device, self.input_mask) def build_attn_masks(self, device) -> torch.Tensor | None: @@ -330,12 +332,10 @@ def __init__(self, if cfg.use_qk_norm: self.submodules += [self.q_norm, self.k_norm] - # if cfg.arch.scale_attn_weights: - # self.unscale_factor = self.layer_idx + 1 - # self.scale_factor = 1 / self.unscale_factor - # else: - self.unscale_factor = 1 - 
self.scale_factor = 1 + if cfg.query_pre_attn_scalar: + self.scaling = cfg.query_pre_attn_scalar ** (-0.5) + else: + self.scaling = 1 / math.sqrt(cfg.head_dim) def numel(self) -> int: @@ -690,7 +690,7 @@ def forward_paged(self, block_table, None, None, - 1 / math.sqrt(cfg.head_dim), + self.scaling, True, -1, -1, True, @@ -724,31 +724,32 @@ def forward_paged(self, def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg): - if has_lower_right_sdpa and attn_params.is_causal() and not cfg.no_sdpa: + q_states = q_states.transpose(1, 2) + k_states = k_states.transpose(1, 2) + v_states = v_states.transpose(1, 2) - q_states = q_states.transpose(1, 2) - k_states = k_states.transpose(1, 2) - v_states = v_states.transpose(1, 2) + if has_lower_right_sdpa and attn_params.is_causal() and not cfg.no_sdpa: k_states = self.repeat_kv(k_states, cfg.num_key_value_groups) v_states = self.repeat_kv(v_states, cfg.num_key_value_groups) attn_mask_lr = causal_lower_right(q_len, k_states.shape[2]) - attn_output = F.scaled_dot_product_attention(q_states, k_states, v_states, attn_mask_lr) + attn_output = F.scaled_dot_product_attention( + q_states, + k_states, + v_states, + attn_mask_lr, + scale = self.scaling + ) else: - q_states = q_states.transpose(1, 2) - k_states = k_states.transpose(1, 2) - v_states = v_states.transpose(1, 2) - k_states = self.repeat_kv(k_states, cfg.num_key_value_groups) k_states = k_states.transpose(-1, -2) attn_weights = torch.matmul(q_states, k_states) - attn_weights *= 1 / math.sqrt(cfg.head_dim) - attn_mask = attn_params.get_attn_mask(attn_weights.device) + attn_weights *= self.scaling if attn_mask is not None: attn_weights = attn_weights + attn_mask attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16) @@ -767,6 +768,8 @@ def _attn_flash(self, batch_size, q_len, q_states, k_states, v_states, attn_para k_states, v_states, causal = True + causal = True, + softmax_scale = self.scaling, ) attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim)) return attn_output @@ -792,7 +795,8 @@ def _attn_xformers(self, batch_size, q_len, q_states, k_states, v_states, attn_p q_states, k_states, v_states, - attn_bias = LowerTriangularFromBottomRightMask() + attn_bias = LowerTriangularFromBottomRightMask(), + scale = self.scaling ) attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim)) diff --git a/exllamav2/config.py b/exllamav2/config.py index e2e182c5..2f1901df 100644 --- a/exllamav2/config.py +++ b/exllamav2/config.py @@ -100,6 +100,7 @@ class ExLlamaV2Config: scale_depth: float scale_emb: float use_qk_norm: bool + query_pre_attn_scalar: float | None final_logit_softcapping: float | None checkpoint_fused_mlp: bool @@ -218,6 +219,8 @@ def prepare(self, no_tensors: bool = False): self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads self.use_qk_norm = read(read_config, bool, ["use_qk_norm"], False) + self.query_pre_attn_scalar = read(read_config, float, "query_pre_attn_scalar", None) + # MLP params if self.arch.default_inner_dim_mult is not None: From 66c4a9c8498ce5d292aaff91fd885eaa10c7e3f5 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Thu, 4 Jul 2024 05:36:07 +0200 Subject: [PATCH 15/40] Support (alternating) SWA --- exllamav2/architecture.py | 3 +++ exllamav2/attn.py | 36 +++++++++++++++++++++++++++++++++--- exllamav2/config.py | 3 +++ exllamav2/model.py | 8 +++++++- 4 files changed, 46 
insertions(+), 4 deletions(-) diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py index 694eba2f..46e87d05 100644 --- a/exllamav2/architecture.py +++ b/exllamav2/architecture.py @@ -116,6 +116,9 @@ def __init__(self, arch_string, read_config): self.norm_key_1_post = None self.norm_key_2_post = None + self.swa = False + self.alternating_swa = False + self.fused_qkv_altpack = False # Mistral diff --git a/exllamav2/attn.py b/exllamav2/attn.py index e4c0f90c..fea744da 100644 --- a/exllamav2/attn.py +++ b/exllamav2/attn.py @@ -15,6 +15,7 @@ import math # from exllamav2.util import list_live_tensors, set_snapshot, diff_snapshot, print_vram_usage_peak import torch.nn.functional as F +import inspect # from line_profiler import profile from typing import TYPE_CHECKING @@ -25,6 +26,7 @@ has_flash_attn = False has_flash_attn_with_paged = False +has_flash_attn_with_window = False try: import flash_attn @@ -45,6 +47,9 @@ has_flash_attn = True has_flash_attn_with_paged = True + has_flash_attn_with_window = "window_size" in list(inspect.signature(flash_attn_func).parameters) + + except ModuleNotFoundError: pass @@ -277,7 +282,8 @@ def __init__(self, key: str, layer_idx: int, has_norm: bool = True, - has_residual: bool = True): + has_residual: bool = True, + sliding_window: int = 0): super().__init__(model, key) @@ -337,6 +343,8 @@ def __init__(self, else: self.scaling = 1 / math.sqrt(cfg.head_dim) + self.sliding_window = sliding_window + def numel(self) -> int: @@ -682,6 +690,8 @@ def forward_paged(self, # block_table = block_table, # causal = True # ) + window_size = -1 if not self.sliding_window else self.sliding_window + attn_output, _ = flash_attn_cuda.fwd_kvcache( q, k_cache, v_cache, k, v, cache_seqlens_a, @@ -692,7 +702,7 @@ def forward_paged(self, None, self.scaling, True, - -1, -1, + window_size, window_size, True, 0, ) @@ -733,6 +743,10 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para k_states = self.repeat_kv(k_states, cfg.num_key_value_groups) v_states = self.repeat_kv(v_states, cfg.num_key_value_groups) + if self.sliding_window and k_states.shape[2] >= self.sliding_window: + k_states = k_states[:, :, -self.sliding_window:, :] + v_states = v_states[:, :, -self.sliding_window:, :] + attn_mask_lr = causal_lower_right(q_len, k_states.shape[2]) attn_output = F.scaled_dot_product_attention( q_states, @@ -750,7 +764,13 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para attn_weights = torch.matmul(q_states, k_states) attn_weights *= self.scaling + attn_mask = attn_params.get_attn_mask(attn_weights.device) + if attn_mask is not None: attn_weights = attn_weights + attn_mask + if self.sliding_window and k_states.shape[-1] >= self.sliding_window: + attn_weights = attn_weights[:, :, :, -self.sliding_window:] + v_states = v_states[:, :, -self.sliding_window:, :] + attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16) v_states = self.repeat_kv(v_states, cfg.num_key_value_groups) @@ -763,13 +783,20 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para def _attn_flash(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg): + assert has_flash_attn_with_window or not self.sliding_window, \ + "Installed version of flash-attn does not support sliding window" + + flash_kwargs = { + "window_size": (self.sliding_window, self.sliding_window) + } if self.sliding_window else {} + attn_output = flash_attn_func( q_states, k_states, v_states, - causal = True 
causal = True, softmax_scale = self.scaling, + **flash_kwargs ) attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim)) return attn_output @@ -777,6 +804,9 @@ def _attn_flash(self, batch_size, q_len, q_states, k_states, v_states, attn_para def _attn_xformers(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg): + assert not self.sliding_window, \ + "Sliding window not currently supported for xformers" + # xformers memory_efficient_attention, could be beneficial if your device's architecture is less than sm_80 are almost the same. But the martix operation diff --git a/exllamav2/config.py b/exllamav2/config.py index 2f1901df..c311ac32 100644 --- a/exllamav2/config.py +++ b/exllamav2/config.py @@ -102,6 +102,7 @@ class ExLlamaV2Config: use_qk_norm: bool query_pre_attn_scalar: float | None final_logit_softcapping: float | None + sliding_window: int checkpoint_fused_mlp: bool @@ -259,6 +260,8 @@ def prepare(self, no_tensors: bool = False): "n_positions"], 2048) self.original_max_seq_len = self.max_seq_len + self.sliding_window = read(read_config, int, ["sliding_window", "sliding_window_size"], 0) + rs = read(read_config, dict, "rope_scaling", None) if rs: scaling_type = rs.get("type", None) diff --git a/exllamav2/model.py b/exllamav2/model.py index a1f8575b..624d7ac4 100644 --- a/exllamav2/model.py +++ b/exllamav2/model.py @@ -227,7 +227,13 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False): pd = ExLlamaV2ParallelDecoder(self, layer_key, layer_idx) self.modules += [pd] else: - attn = ExLlamaV2Attention(self, layer_key, layer_idx) + if self.config.arch.alternating_swa: + swa = self.config.sliding_window if not bool(layer_idx % 2) else 0 + elif self.config.arch.swa: + swa = self.config.sliding_window + else: + swa = 0 + attn = ExLlamaV2Attention(self, layer_key, layer_idx, sliding_window = swa) if self.config.arch.is_moe: mlp = ExLlamaV2MoEMLP(self, layer_key, layer_idx) else: mlp = ExLlamaV2MLP(self, layer_key, layer_idx) self.modules += [attn, mlp] From 60eb8347b801107369fe6c914fcca15b74dfb095 Mon Sep 17 00:00:00 2001 From: Brian Dashore Date: Fri, 5 Jul 2024 17:56:38 -0400 Subject: [PATCH 16/40] dynamic_async: Properly close the iterator loop on exit (#538) When the close method is called, the generator's iterator loop never actually exited. This is because the condition is not notified meaning the task is still running even though it's signalled to cancel. Therefore, add an extra pass if the task is cancelled and unlock the loop by forcing a notify on close. From there, normal cancellation handling will work. There might be a better way to do this, but this way minimizes the amount of added code and makes the most(?) sense. 
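
For reference, the wake-up pattern this fix relies on can be reduced to a short standalone sketch (illustrative only, not part of the patch): a task blocked in asyncio.Condition.wait_for() re-evaluates its predicate only when the condition is notified, which is why requesting cancellation alone left the loop parked, and why close() now notifies after calling cancel(). In the sketch below a plain closing flag stands in for the iteration_task.cancelled() check used by the real generator loop; the names (worker, jobs, closing) are made up for illustration.

    import asyncio

    # Minimal sketch of the wait_for()/notify() shutdown pattern described above.
    # A waiter only re-checks its predicate when the condition is notified, so a
    # clean shutdown must both flip the predicate and notify the condition.

    async def main():
        cond = asyncio.Condition()
        jobs = []
        closing = False

        async def worker():
            while True:
                async with cond:
                    # Re-checked only when cond is notified
                    await cond.wait_for(lambda: len(jobs) > 0 or closing)
                    if closing:
                        return
                    print("processing", jobs.pop(0))

        task = asyncio.create_task(worker())
        await asyncio.sleep(0)          # let the worker reach wait_for()

        async with cond:
            jobs.append("job 1")
            cond.notify_all()           # wakes the worker to process the job

        await asyncio.sleep(0.1)

        closing = True                  # request shutdown...
        async with cond:
            cond.notify_all()           # ...and notify, so the wait actually returns
        await task

    asyncio.run(main())

In the patch, self.iteration_task.cancelled() plays the role of the closing flag in the predicate, and _notify_condition() supplies the wake-up right after cancel() is called in close().
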
Signed-off-by: kingbri --- exllamav2/generator/dynamic_async.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/exllamav2/generator/dynamic_async.py b/exllamav2/generator/dynamic_async.py index d9e1fee7..8e006822 100644 --- a/exllamav2/generator/dynamic_async.py +++ b/exllamav2/generator/dynamic_async.py @@ -22,7 +22,9 @@ async def _run_iteration(self): try: while True: async with self.condition: - await self.condition.wait_for(lambda: len(self.jobs) > 0) + # Unlock if there's no jobs or if the parent task is cancelled + await self.condition.wait_for(lambda: len(self.jobs) > 0 or self.iteration_task.cancelled()) + results = self.generator.iterate() for result in results: job = result["job"] @@ -31,6 +33,9 @@ async def _run_iteration(self): if result["eos"]: del self.jobs[job] await asyncio.sleep(0) + except asyncio.CancelledError: + # Silently return on cancel + return except Exception as e: # If the generator throws an exception it won't pertain to any one ongoing job, so push it to all of them for async_job in self.jobs.values(): @@ -48,6 +53,9 @@ async def _notify_condition(self): async def close(self): self.iteration_task.cancel() + + # Force a re-check of the condition to unlock the loop + await self._notify_condition() try: await self.iteration_task except asyncio.CancelledError: From 83f0d19cbd90af535b8e09b8ef9d98b52a677f51 Mon Sep 17 00:00:00 2001 From: Ahmad Fahadh Ilyas <37577369+fahadh4ilyas@users.noreply.github.com> Date: Sat, 6 Jul 2024 08:57:28 +0700 Subject: [PATCH 17/40] make target_modules in lora usable (#534) --- exllamav2/embedding.py | 10 ++++++++-- exllamav2/linear.py | 8 ++++++++ exllamav2/lora.py | 25 +++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/exllamav2/embedding.py b/exllamav2/embedding.py index b52c02ac..8b111c8e 100644 --- a/exllamav2/embedding.py +++ b/exllamav2/embedding.py @@ -124,7 +124,10 @@ def forward(self, standard_mask_ = standard_mask[i] input_ids_ = input_ids[i] standard_ids_ = input_ids_[standard_mask_] - standard_embeddings_ = self.embedding(standard_ids_) + if loras is not None and loras[0].embed_tokens is not None: + standard_embeddings_ = loras[0].embed_tokens(standard_ids_) + else: + standard_embeddings_ = self.embedding(standard_ids_) standard_embeddings_ = safe_move_tensor(standard_embeddings_, indexed_embeddings.device) combined_embeddings[i][standard_mask_] = standard_embeddings_ @@ -144,7 +147,10 @@ def forward(self, # Call embedding module if no indexed embeddings else: - hidden_states = self.embedding.forward(hidden_states) + if loras is not None and loras[0].embed_tokens is not None: + hidden_states = loras[0].embed_tokens(hidden_states) + else: + hidden_states = self.embedding(hidden_states) if self.model.config.arch.normalize_embeddings: hidden_states *= self.model.config.hidden_size ** 0.5 diff --git a/exllamav2/linear.py b/exllamav2/linear.py index e0e455e6..9858b3f1 100644 --- a/exllamav2/linear.py +++ b/exllamav2/linear.py @@ -241,6 +241,14 @@ def forward(self, # Linear forward + if self.key == 'lm_head' and loras is not None and loras[0].lm_head is not None: + hidden_states_out = loras[0].lm_head(hidden_states) + + if intermediates: + return {"hidden_states": hidden_states_out} + else: + return hidden_states_out + if self.q_handle is not None and not force_recons: output_shape = hidden_states.shape[:-1] + (self.out_features,) diff --git a/exllamav2/lora.py b/exllamav2/lora.py index c4b5ea68..133f01c1 100644 --- a/exllamav2/lora.py +++ b/exllamav2/lora.py @@ 
-53,6 +53,8 @@ def __init__(self, self.target_modules = {} self.bias_ignored = False self.lora_scaling = lora_scaling + self.embed_tokens = None + self.lm_head = None # Grab relevant items from LoRA config @@ -77,6 +79,29 @@ def __init__(self, tensor = f[key] # Find target + if key.endswith(f'{self.config.arch.lm_head_key}.weight'): + if tensor.dtype == torch.bfloat16: + tensor = tensor.to(torch.float16) + elif tensor.dtype == torch.float32: + tensor = tensor.to(torch.float16) + target_module = self.model.modules_dict["lm_head"] + tensor = safe_move_tensor(tensor, target_module.device()) + self.lm_head = torch.nn.Linear(target_module.in_features, tensor.shape[0], bias = False, device = "meta") + self.lm_head.weight = torch.nn.Parameter(tensor, requires_grad=False) + continue + elif key.endswith(f'embed_tokens.weight'): + if tensor.dtype == torch.bfloat16: + tensor = tensor.to(torch.float16) + elif tensor.dtype == torch.float32: + tensor = tensor.to(torch.float16) + target_module = self.model.modules_dict["model.embed_tokens"] + tensor = safe_move_tensor(tensor, target_module.device()) + self.embed_tokens = torch.nn.Embedding(tensor.shape[0], self.config.hidden_size, self.config.pad_token_id, device = "meta") + weight = torch.nn.Parameter(tensor, requires_grad=False) + if self.model.config.scale_emb != 1: + weight *= self.model.config.scale_emb + self.embed_tokens.weight = weight + continue i = key.find("model.layers.") if i == -1: raise ValueError(f" ## Error: unsupported layer in {self.lora_path}: {key}") From 01ce7bbb6e2c59d16ed6c00c49504fb90a1a3edb Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sat, 6 Jul 2024 07:13:58 +0200 Subject: [PATCH 18/40] Attn logit softcapping (for eager attn) --- exllamav2/attn.py | 5 ++++- exllamav2/config.py | 8 +++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/exllamav2/attn.py b/exllamav2/attn.py index fea744da..40f11b64 100644 --- a/exllamav2/attn.py +++ b/exllamav2/attn.py @@ -766,7 +766,10 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para attn_weights *= self.scaling attn_mask = attn_params.get_attn_mask(attn_weights.device) - if attn_mask is not None: attn_weights = attn_weights + attn_mask + if cfg.attn_logit_softcapping: + ext_c.softcap_(attn_weights, cfg.attn_logit_softcapping) + if attn_mask is not None: + attn_weights = attn_weights + attn_mask if self.sliding_window and k_states.shape[-1] >= self.sliding_window: attn_weights = attn_weights[:, :, :, -self.sliding_window:] v_states = v_states[:, :, -self.sliding_window:, :] diff --git a/exllamav2/config.py b/exllamav2/config.py index c311ac32..f9b32f4d 100644 --- a/exllamav2/config.py +++ b/exllamav2/config.py @@ -102,6 +102,7 @@ class ExLlamaV2Config: use_qk_norm: bool query_pre_attn_scalar: float | None final_logit_softcapping: float | None + attn_logit_softcapping: float | None sliding_window: int checkpoint_fused_mlp: bool @@ -165,9 +166,9 @@ def prepare(self, no_tensors: bool = False): # Load generation_config.json - self.generation_config_path = os.path.join(self.model_dir, "generation_config.json") - if os.path.exists(self.generation_config_path): - with open(self.generation_config_path, encoding = "utf8") as f: + generation_config_path = os.path.join(self.model_dir, "generation_config.json") + if os.path.exists(generation_config_path): + with open(generation_config_path, encoding = "utf8") as f: gen_config = json.load(f) self.generation_config = {} try: @@ -247,6 +248,7 @@ def prepare(self, 
no_tensors: bool = False): else: self.scale_depth = scale_depth / math.sqrt(self.num_hidden_layers) + self.attn_logit_softcapping = read(read_config, float, "attn_logit_softcapping", None) self.final_logit_softcapping = read(read_config, float, "final_logit_softcapping", None) # Positional embeddings From bfc3cd9cf3eb87f8d738d67e49ef92b48a7f7942 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sat, 6 Jul 2024 07:14:47 +0200 Subject: [PATCH 19/40] Support Gemma2 --- exllamav2/architecture.py | 42 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py index 46e87d05..271c40d4 100644 --- a/exllamav2/architecture.py +++ b/exllamav2/architecture.py @@ -9,6 +9,10 @@ ["ln_2"]] layer_keys_yi_norms = [["ln1", "input_layernorm"], ["ln2", "post_attention_layernorm"]] +layer_keys_gemma2_norms = [["input_layernorm"], + ["post_attention_layernorm"], + ["pre_feedforward_layernorm"], + ["post_feedforward_layernorm"]] layer_keys_internlm2_norms = [["attention_norm"], ["ffn_norm"]] layer_keys_llama_attn = [["self_attn.q_proj"], @@ -326,6 +330,44 @@ def __init__(self, arch_string, read_config): self.mqa = False self.scale_attn_weights = False + # Gemma2 + + if arch_string == "Gemma2ForCausalLM": + arch_recognized = True + self.layer_keys += \ + layer_keys_gemma2_norms + \ + layer_keys_llama_attn + \ + layer_keys_llama_mlp + self.expect_keys += \ + expect_keys_gemma + self.norm_eps_key = "rms_norm_eps" + self.attention_bias_qkv = False + self.attention_bias_o = False + self.mlp_bias = False + self.mlp_gate = True + self.mlp_key_gate = ".mlp.gate_proj" + self.mlp_key_up = ".mlp.up_proj" + self.mlp_key_down = ".mlp.down_proj" + self.mlp_act_func = "gelu" + self.is_moe = False + self.norm = "rmsnorm" + self.lm_head_key = "model.embed_tokens" + self.normalize_embeddings = True + self.norm_key_1 = ".input_layernorm" + self.norm_key_1_post = ".post_attention_layernorm" + self.norm_key_2 = ".pre_feedforward_layernorm" + self.norm_key_2_post = ".post_feedforward_layernorm" + self.norm_constant_bias = 1 + self.parallel_decoder_blocks = False + self.requires_bos = True + self.rope_style = RopeStyle.NEOX + self.keymap = None + self.fused_qkv_key = None + self.mqa = False + self.scale_attn_weights = False + self.pre_post_layernorm = True + self.alternating_swa = True + # StarCoder2 if arch_string == "Starcoder2ForCausalLM": From 0963870252114d92e7f95cb773b32ac86f44bf98 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sat, 6 Jul 2024 07:15:29 +0200 Subject: [PATCH 20/40] Insist on eager attn for Gemma2 (until flash-attn gets support) --- exllamav2/architecture.py | 3 +++ exllamav2/config.py | 21 ++++++++++++++++++++- exllamav2/conversion/convert_exl2.py | 2 +- exllamav2/model_init.py | 4 ++++ 4 files changed, 28 insertions(+), 2 deletions(-) diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py index 271c40d4..b2bf903e 100644 --- a/exllamav2/architecture.py +++ b/exllamav2/architecture.py @@ -123,6 +123,8 @@ def __init__(self, arch_string, read_config): self.swa = False self.alternating_swa = False + self.eager_attn_only = False + self.fused_qkv_altpack = False # Mistral @@ -367,6 +369,7 @@ def __init__(self, arch_string, read_config): self.scale_attn_weights = False self.pre_post_layernorm = True self.alternating_swa = True + self.eager_attn_only = True # StarCoder2 diff --git a/exllamav2/config.py b/exllamav2/config.py index 
f9b32f4d..e0118c11 100644 --- a/exllamav2/config.py +++ b/exllamav2/config.py @@ -342,4 +342,23 @@ def prepare(self, no_tensors: bool = False): if not match: raise ValueError(f" ## Could not find {prefix}.* in model") - x = 0 \ No newline at end of file + x = 0 + + + def arch_compat_overrides(self, quiet: bool = False, warn_only = False): + + warnings = [] + + if self.arch.eager_attn_only: + warnings.append(" !! Warning: Architecture currently supports only eager attention") + if not warn_only: + warnings.append(" !! Warning: flash-attn, xformers and SDPA are disabled") + self.no_flash_attn = True + self.no_xformers = True + self.no_sdpa = True + else: + warnings.append(" !! Warning: flash-attn, xformers and SDPA should be disabled for correct inference") + + if not quiet: + for w in warnings: + print(w) diff --git a/exllamav2/conversion/convert_exl2.py b/exllamav2/conversion/convert_exl2.py index 8a1f194f..2f56cf77 100644 --- a/exllamav2/conversion/convert_exl2.py +++ b/exllamav2/conversion/convert_exl2.py @@ -183,8 +183,8 @@ def save_job(): config = ExLlamaV2Config() config.model_dir = job['in_dir'] -config.qkv_embed = False config.prepare() +config.arch_compat_overrides() # Tokenizer diff --git a/exllamav2/model_init.py b/exllamav2/model_init.py index 905c3bc0..c40f10e7 100644 --- a/exllamav2/model_init.py +++ b/exllamav2/model_init.py @@ -105,6 +105,10 @@ def init(args, if args.low_mem: config.set_low_mem() if args.load_q4: config.load_in_q4 = True + # Compatibility warnings + + config.arch_compat_overrides(warn_only = True) + # Load model # If --gpu_split auto, return unloaded model. Model must be loaded with model.load_autosplit() supplying cache # created in lazy mode From adebcdbd9d89f9f0167510976cc5ab8261f6198a Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sat, 6 Jul 2024 08:06:14 +0200 Subject: [PATCH 21/40] Use gelu_pytorch_tanh() instead of gelu() --- exllamav2/mlp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exllamav2/mlp.py b/exllamav2/mlp.py index 18d39409..51510d2a 100644 --- a/exllamav2/mlp.py +++ b/exllamav2/mlp.py @@ -298,7 +298,7 @@ def forward_torch(self, if cfg.arch.mlp_act_func == "silu": y = F.silu(gate) elif cfg.arch.mlp_act_func == "gelu": - y = F.gelu(gate) + y = F.gelu(gate, approximate = "tanh") up = self.up_proj.forward(post_norm, loras = loras) y *= up y.clamp_(min = -65504.0, max = 65504.0) @@ -307,7 +307,7 @@ def forward_torch(self, if cfg.arch.mlp_act_func == "silu": y = F.silu(up) elif cfg.arch.mlp_act_func == "gelu": - y = F.gelu(up) + y = F.gelu(up, approximate = "tanh") down = self.down_proj.forward(y, loras = loras) if self.post_layernorm: From 8f5680dfca52215d1c428d58377b3144376dc105 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sat, 6 Jul 2024 08:08:33 +0200 Subject: [PATCH 22/40] Add measurement sanity check --- exllamav2/conversion/measure.py | 38 +++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/exllamav2/conversion/measure.py b/exllamav2/conversion/measure.py index f42b6706..9315f916 100644 --- a/exllamav2/conversion/measure.py +++ b/exllamav2/conversion/measure.py @@ -159,6 +159,7 @@ def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_p total_numel += module.v_proj.numel() total_numel += module.o_proj.numel() + max_accuracy = 0.0 (q_, k_, v_, o_) = (-1, -1, -1, -1) for (q, k, v, o) in qmaps: @@ -177,6 +178,8 @@ def measure_attn(module, hidden_states, target_states, 
quantizers, cache, attn_p accuracy = test_error(module, hidden_states, target_states, cache, attn_params) print(f" -- {total_bpw:1.4f} bpw accuracy: {accuracy:1.8f}") + max_accuracy = max(accuracy, max_accuracy) + torch.cuda.empty_cache() r = { "accuracy": accuracy, @@ -187,6 +190,10 @@ def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_p "o_proj": qjobs[3][o].get_dict() } results.append(r) + if max_accuracy < 0.1: + print(" ## Measurement/inference error (1)") + os._exit(1) + for x in ["k_proj", "v_proj", "o_proj"] + (["q_proj"] if not keep_q else []): if x in quantizers: del quantizers[x] @@ -216,6 +223,7 @@ def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_pa total_numel += module.up_proj.numel() total_numel += module.down_proj.numel() + max_accuracy = 0.0 if has_gate: (g_, u_, d_) = (-1, -1, -1) @@ -234,6 +242,8 @@ def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_pa accuracy = test_error(module, hidden_states, target_states, cache, attn_params) print(f" -- {total_bpw:1.4f} bpw accuracy: {accuracy:1.8f}") + max_accuracy = max(accuracy, max_accuracy) + torch.cuda.empty_cache() r = { "accuracy": accuracy, @@ -259,6 +269,8 @@ def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_pa accuracy = test_error(module, hidden_states, target_states, cache, attn_params) print(f" -- {total_bpw:1.4f} bpw accuracy: {accuracy:1.8f}") + max_accuracy = max(accuracy, max_accuracy) + torch.cuda.empty_cache() r = { "accuracy": accuracy, @@ -267,6 +279,10 @@ def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_pa "down_proj": qjobs[2][d].get_dict() } results.append(r) + if max_accuracy < 0.1: + print(" ## Measurement/inference error (1)") + os._exit(1) + for x in ["up_proj", "down_proj", "gate_proj"]: if x in quantizers: del quantizers[x] @@ -311,6 +327,7 @@ def measure_moe_mlp(module, hidden_states, target_states, quantizers, cache, att total_numel += sum(module.w3[i].numel() for i in range(num_experts)) total_numel += sum(module.w2[i].numel() for i in range(num_experts)) + max_accuracy = 0.0 (g_, u_, d_) = (-1, -1, -1) for (g, u, d) in qmaps: @@ -328,6 +345,8 @@ def measure_moe_mlp(module, hidden_states, target_states, quantizers, cache, att accuracy = test_error(module, hidden_states, target_states, cache, attn_mask) print(f" -- {total_bpw:1.4f} bpw accuracy: {accuracy:1.8f}") + max_accuracy = max(accuracy, max_accuracy) + torch.cuda.empty_cache() r = { "accuracy": accuracy, @@ -337,6 +356,10 @@ def measure_moe_mlp(module, hidden_states, target_states, quantizers, cache, att "w2": qjobs[2][d].get_dict() } results.append(r) + if max_accuracy < 0.1: + print(" ## Measurement/inference error (1)") + os._exit(1) + return results @@ -515,9 +538,24 @@ def measure_quant(job, save_fn, model, hidden_state_offload_layers): for i in range(len(hidden_states)): x = hidden_states[i].to("cuda:0") + if torch.isnan(x).any(): + print(" ## Measurement/inference error (2)") + os._exit(1) + if torch.isinf(x).any(): + print(" ## Measurement/inference error (3)") + os._exit(1) + outputs = module.forward(x, cache, attn_params, intermediates = True) target_device = "cuda:0" if i < hidden_state_offload_layers else "cpu" + for k, v in outputs.items(): + if torch.isnan(v).any(): + print(f" ## Measurement/inference error (2): {k}") + os._exit(1) + if torch.isinf(v).any(): + print(f" ## Measurement/inference error (3): {k}") + os._exit(1) + # Hessians if mode == "self_attn": From 
adefba19732c2eb130795ad1f0eeba5ee4d07f62 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sat, 6 Jul 2024 11:55:23 +0200 Subject: [PATCH 23/40] Optionally clamp hidden states (for Gemma2) --- exllamav2/architecture.py | 2 ++ exllamav2/attn.py | 6 ++++++ exllamav2/mlp.py | 8 ++++++++ exllamav2/rmsnorm.py | 3 +-- 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py index b2bf903e..066befea 100644 --- a/exllamav2/architecture.py +++ b/exllamav2/architecture.py @@ -124,6 +124,7 @@ def __init__(self, arch_string, read_config): self.alternating_swa = False self.eager_attn_only = False + self.clamp_hidden_states = False self.fused_qkv_altpack = False @@ -370,6 +371,7 @@ def __init__(self, arch_string, read_config): self.pre_post_layernorm = True self.alternating_swa = True self.eager_attn_only = True + self.clamp_hidden_states = True # StarCoder2 diff --git a/exllamav2/attn.py b/exllamav2/attn.py index 40f11b64..d8c508ef 100644 --- a/exllamav2/attn.py +++ b/exllamav2/attn.py @@ -973,6 +973,9 @@ def forward(self, pass_lora_temp ) + if cfg.arch.clamp_hidden_states: + hidden_states.clamp_(-65504, 65504) + return hidden_states @@ -1081,6 +1084,9 @@ def forward_torch(self, hidden_states = (attn_proj + residual) if self.has_residual else attn_proj + if cfg.arch.clamp_hidden_states: + hidden_states.clamp_(-65504, 65504) + if intermediates: return {"post_norm": post_norm, "attn_output": attn_output, diff --git a/exllamav2/mlp.py b/exllamav2/mlp.py index 51510d2a..f41ea967 100644 --- a/exllamav2/mlp.py +++ b/exllamav2/mlp.py @@ -260,6 +260,8 @@ def forward(self, loras: list[ExLlamaV2Lora] | None = None, **kwargs) -> torch.Tensor | dict[str: torch.Tensor]: + cfg = self.model.config + if self.q_handle is None or intermediates: return self.forward_torch(hidden_states, cache, attn_params, past_len, intermediates, loras = loras, **kwargs) @@ -275,6 +277,9 @@ def forward(self, pass_loras, pass_lora_temp) + if cfg.arch.clamp_hidden_states: + hidden_states.clamp_(-65504, 65504) + return hidden_states @@ -314,6 +319,9 @@ def forward_torch(self, down = self.post_layernorm.forward(down) hidden_states = down + residual if self.has_residual else down + if cfg.arch.clamp_hidden_states: + hidden_states = hidden_states.clamp(-65504, 65504) + if intermediates: return {"post_norm": post_norm, "pre_down": y, diff --git a/exllamav2/rmsnorm.py b/exllamav2/rmsnorm.py index 68518c5d..fa3b7e29 100644 --- a/exllamav2/rmsnorm.py +++ b/exllamav2/rmsnorm.py @@ -120,8 +120,7 @@ def forward_torch(self, loras = None, **kwargs) -> torch.Tensor | dict[str: torch.Tensor]: - hidden_states[hidden_states == -float('inf')] = -65504.0 - hidden_states[hidden_states == float('inf')] = 65504.0 + hidden_states.clamp_(-65504.0, 65504.0) variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim = True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) From c294f3436f982634c84fbbeae5f0995d85d517cd Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sat, 6 Jul 2024 11:55:37 +0200 Subject: [PATCH 24/40] Fix model_diff script --- model_diff.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/model_diff.py b/model_diff.py index 9eeae0c7..5238477e 100644 --- a/model_diff.py +++ b/model_diff.py @@ -14,7 +14,7 @@ import pandas, fastparquet import torch import torch.nn.functional as F -from conversion.tokenize import get_tokens +from 
exllamav2.conversion.tokenize import get_tokens from exllamav2.util import list_live_tensors import gc @@ -47,6 +47,8 @@ config[1].prepare() config[0].max_batch_size = 1 config[1].max_batch_size = 1 +config[0].arch_compat_overrides() +config[1].arch_compat_overrides() model = (ExLlamaV2(config[0]), ExLlamaV2(config[1])) model[0].load(lazy = True) From e7fd8529eac7bae56974a502c4da2f8f3c1002b4 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sat, 6 Jul 2024 11:56:28 +0200 Subject: [PATCH 25/40] Fix scratch space allocation for models where hidden_size > num_attention_heads * head_dim --- exllamav2/attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exllamav2/attn.py b/exllamav2/attn.py index d8c508ef..02e44da8 100644 --- a/exllamav2/attn.py +++ b/exllamav2/attn.py @@ -502,7 +502,7 @@ def scratch_space(self): def temp_state_size(self): cfg = self.model.config - return cfg.max_input_len * cfg.max_batch_size * cfg.num_attention_heads * cfg.head_dim * 2 + 128 + return cfg.max_input_len * cfg.max_batch_size * max(cfg.num_attention_heads * cfg.head_dim, cfg.hidden_size) * 2 + 128 def temp_q_size(self): From 018687b031f1c7a7e9364221c51b37d3ffb9929c Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sat, 6 Jul 2024 16:25:02 +0200 Subject: [PATCH 26/40] Apply final logit softcap when measuring ppl during quant --- exllamav2/conversion/quantize.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/exllamav2/conversion/quantize.py b/exllamav2/conversion/quantize.py index c6dd3abf..16f2ff23 100644 --- a/exllamav2/conversion/quantize.py +++ b/exllamav2/conversion/quantize.py @@ -21,6 +21,7 @@ import torch.nn.functional as F import gc from exllamav2.conversion.bot_status import print_stage +from exllamav2.ext import exllamav2_ext as ext_c, none_tensor def list_live_tensors(): @@ -470,6 +471,10 @@ def quant(job, save_fn, model): output = module.forward(x, cache, attn_params) if module.padding > 0: output = output[:, :, :-module.padding] + if model.config.final_logit_softcapping: + output = output.contiguous() + ext_c.softcap_(output, model.config.final_logit_softcapping) + logits = output[:, :-1, :] logits = logits.float() + 1e-10 target_ids = cal_ids[i:i+1, 1:].to("cuda:0") From d5b82a497cf8d37168fcc13a241fbde78d329d87 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sat, 6 Jul 2024 17:37:16 +0200 Subject: [PATCH 27/40] Explicitly disable gradients when loading float weights --- exllamav2/module.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/exllamav2/module.py b/exllamav2/module.py index b356a993..bbd393cf 100644 --- a/exllamav2/module.py +++ b/exllamav2/module.py @@ -126,13 +126,13 @@ def load_weight(self, bias = tensors["bias"].half() if self.model.config.arch.orig_weights_transposed and len(tensor.shape) == 2: tensor = tensor.T - return nn.Parameter(tensor), nn.Parameter(bias) + return nn.Parameter(tensor, requires_grad = False), nn.Parameter(bias, requires_grad = False) else: tensors = self.load_multi(key, ["weight"]) tensor = tensors["weight"].half() # if self.model.config.arch.orig_weights_transposed: # tensor = tensor.T - return nn.Parameter(tensor) + return nn.Parameter(tensor, requires_grad = False) # No weights found for key @@ -180,7 +180,7 @@ def load_weight_fused(self, tensor = tensor.T tensor = tensor.contiguous().to(self.device()) - res.append(nn.Parameter(tensor)) + res.append(nn.Parameter(tensor, 
requires_grad = False)) if len(res) == 2: return res[0], res[1] if len(res) == 1: return res[0] From cba8f6c0d2887dfaaea77819e1a2a694a70930c1 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sat, 6 Jul 2024 18:15:07 +0200 Subject: [PATCH 28/40] Add sources to setup.py --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index 44985fea..b3d13a42 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,7 @@ "exllamav2/exllamav2_ext/ext_rope.cpp", "exllamav2/exllamav2_ext/ext_safetensors.cpp", "exllamav2/exllamav2_ext/ext_sampling.cpp", + "exllamav2/exllamav2_ext/ext_element.cpp", "exllamav2/exllamav2_ext/cuda/h_add.cu", "exllamav2/exllamav2_ext/cuda/h_gemm.cu", "exllamav2/exllamav2_ext/cuda/lora.cu", @@ -58,6 +59,7 @@ "exllamav2/exllamav2_ext/cuda/rope.cu", "exllamav2/exllamav2_ext/cuda/cache.cu", "exllamav2/exllamav2_ext/cuda/util.cu", + "exllamav2/exllamav2_ext/cuda/softcap.cu", "exllamav2/exllamav2_ext/cuda/comp_units/kernel_select.cu", "exllamav2/exllamav2_ext/cuda/comp_units/unit_gptq_1.cu", "exllamav2/exllamav2_ext/cuda/comp_units/unit_gptq_2.cu", From 112304d9bbd2bb1b84e05b6445fddbeec15b0b11 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sun, 7 Jul 2024 02:13:59 +0200 Subject: [PATCH 29/40] Enforce attn env overrides as overrides rather than defaults --- exllamav2/attn.py | 65 ++++++++++++++++++++++++++--------------------- 1 file changed, 36 insertions(+), 29 deletions(-) diff --git a/exllamav2/attn.py b/exllamav2/attn.py index 02e44da8..8541657d 100644 --- a/exllamav2/attn.py +++ b/exllamav2/attn.py @@ -16,6 +16,7 @@ # from exllamav2.util import list_live_tensors, set_snapshot, diff_snapshot, print_vram_usage_peak import torch.nn.functional as F import inspect +import os # from line_profiler import profile from typing import TYPE_CHECKING @@ -27,47 +28,53 @@ has_flash_attn = False has_flash_attn_with_paged = False has_flash_attn_with_window = False +if 'EXLLAMA_NO_FLASH_ATTN' not in os.environ: -try: - import flash_attn - flash_attn_ver = [int(t) for t in flash_attn.__version__.split(".") if t.isdigit()] - is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count())) + try: + import flash_attn + flash_attn_ver = [int(t) for t in flash_attn.__version__.split(".") if t.isdigit()] + is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count())) - if not is_ampere_or_newer_gpu: - print(" ## Warning: Flash Attention is installed but unsupported GPUs were detected.") + if not is_ampere_or_newer_gpu: + print(" ## Warning: Flash Attention is installed but unsupported GPUs were detected.") - if [2, 2, 1] <= flash_attn_ver < [2, 5, 7]: - from flash_attn import flash_attn_func - has_flash_attn = True + if [2, 2, 1] <= flash_attn_ver < [2, 5, 7]: + from flash_attn import flash_attn_func + has_flash_attn = True - if [2, 5, 7] <= flash_attn_ver: - from flash_attn import flash_attn_func, flash_attn_with_kvcache - import flash_attn_2_cuda as flash_attn_cuda + if [2, 5, 7] <= flash_attn_ver: + from flash_attn import flash_attn_func, flash_attn_with_kvcache + import flash_attn_2_cuda as flash_attn_cuda - has_flash_attn = True - has_flash_attn_with_paged = True + has_flash_attn = True + has_flash_attn_with_paged = True - has_flash_attn_with_window = "window_size" in list(inspect.signature(flash_attn_func).parameters) + has_flash_attn_with_window = "window_size" in 
list(inspect.signature(flash_attn_func).parameters) + except ModuleNotFoundError: + pass -except ModuleNotFoundError: - pass has_xformers = False -try: - import xformers.ops as xops - # LowerTriangularFromBottomRightMask was added in xformers version 2.4 - from xformers.ops.fmha.attn_bias import LowerTriangularFromBottomRightMask - has_xformers = True -except ModuleNotFoundError: - pass +if 'EXLLAMA_NO_XFORMERS' not in os.environ: + + try: + import xformers.ops as xops + # LowerTriangularFromBottomRightMask was added in xformers version 2.4 + from xformers.ops.fmha.attn_bias import LowerTriangularFromBottomRightMask + has_xformers = True + except ModuleNotFoundError: + pass + has_lower_right_sdpa = False -try: - from torch.nn.attention.bias import causal_lower_right - has_lower_right_sdpa = True -except ImportError: - pass +if 'EXLLAMA_NO_SDPA' not in os.environ: + try: + from torch.nn.attention.bias import causal_lower_right + has_lower_right_sdpa = True + except ImportError: + pass + def assert_paged_attn(): global has_flash_attn_with_paged From 91299902ff85e967db8ca6a903476db70dc6e741 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Mon, 8 Jul 2024 12:52:02 +0200 Subject: [PATCH 30/40] Add FP32 residual mode --- exllamav2/architecture.py | 9 +- exllamav2/attn.py | 7 +- exllamav2/embedding.py | 16 ++- exllamav2/exllamav2_ext/cuda/q_attn.cu | 24 ++-- exllamav2/exllamav2_ext/cuda/q_attn.cuh | 8 +- exllamav2/exllamav2_ext/cuda/q_mlp.cu | 22 ++-- exllamav2/exllamav2_ext/cuda/q_mlp.cuh | 6 +- exllamav2/exllamav2_ext/cuda/rms_norm.cu | 145 ++++++++++++++++------ exllamav2/exllamav2_ext/cuda/rms_norm.cuh | 8 +- exllamav2/exllamav2_ext/ext_norm.cpp | 13 +- exllamav2/exllamav2_ext/ext_qattn.cpp | 13 +- exllamav2/exllamav2_ext/ext_qattn.h | 3 +- exllamav2/exllamav2_ext/ext_qmlp.cpp | 11 +- exllamav2/exllamav2_ext/ext_qmlp.h | 3 +- exllamav2/layernorm.py | 2 + exllamav2/lora.py | 5 + exllamav2/mlp.py | 9 +- exllamav2/rmsnorm.py | 16 ++- 18 files changed, 221 insertions(+), 99 deletions(-) diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py index 066befea..39481858 100644 --- a/exllamav2/architecture.py +++ b/exllamav2/architecture.py @@ -125,6 +125,7 @@ def __init__(self, arch_string, read_config): self.eager_attn_only = False self.clamp_hidden_states = False + self.residual_stream_fp32 = False self.fused_qkv_altpack = False @@ -371,7 +372,8 @@ def __init__(self, arch_string, read_config): self.pre_post_layernorm = True self.alternating_swa = True self.eager_attn_only = True - self.clamp_hidden_states = True + self.clamp_hidden_states = False + self.residual_stream_fp32 = True # StarCoder2 @@ -740,6 +742,11 @@ def __init__(self, arch_string, read_config): self.expect_keys.remove(["lm_head"]) self.lm_head_key = "model.embed_tokens" + # Sanity checks + + if self.residual_stream_fp32: + assert self.norm_key_1_post and self.norm_key_2_post, \ + "FP32 residual stream only implement for arch with post layernorms" def make_fused_mlp(self): diff --git a/exllamav2/attn.py b/exllamav2/attn.py index 8541657d..15b465a2 100644 --- a/exllamav2/attn.py +++ b/exllamav2/attn.py @@ -449,6 +449,7 @@ def load(self): k_norm, post_norm_weight, post_norm_bias, + cfg.arch.residual_stream_fp32 ) @@ -1085,13 +1086,15 @@ def forward_torch(self, # Post layernorm if self.post_layernorm: - attn_proj = self.post_layernorm.forward(attn_proj) + attn_proj = self.post_layernorm.forward(attn_proj, output_fp32 = cfg.arch.residual_stream_fp32) # Add residual connection 
hidden_states = (attn_proj + residual) if self.has_residual else attn_proj - if cfg.arch.clamp_hidden_states: + if cfg.arch.residual_stream_fp32: + hidden_states = hidden_states.float() + elif cfg.arch.clamp_hidden_states: hidden_states.clamp_(-65504, 65504) if intermediates: diff --git a/exllamav2/embedding.py b/exllamav2/embedding.py index 8b111c8e..411dccf9 100644 --- a/exllamav2/embedding.py +++ b/exllamav2/embedding.py @@ -89,6 +89,8 @@ def forward(self, loras = None, **kwargs) -> torch.Tensor | dict[str: torch.Tensor]: + cfg = self.model.config + # If input IDs contain negative values, assume they are padding tokens from a model with not pad_token_id # defined @@ -111,7 +113,7 @@ def forward(self, # Create combined tensor on the target device batch_size, seq_len = input_ids.shape - hidden_size = self.model.config.hidden_size + hidden_size = cfg.hidden_size combined_embeddings = torch.empty(batch_size, seq_len, hidden_size, device = indexed_embeddings.device, dtype = indexed_embeddings.dtype) @@ -133,8 +135,10 @@ def forward(self, # Normalization - if self.model.config.arch.normalize_embeddings: - combined_embeddings *= self.model.config.hidden_size ** 0.5 + if cfg.arch.residual_stream_fp32: + combined_embeddings = combined_embeddings.float() + if cfg.arch.normalize_embeddings: + combined_embeddings *= cfg.hidden_size ** 0.5 # Extract indexed embeddings and insert in-place @@ -152,8 +156,10 @@ def forward(self, else: hidden_states = self.embedding(hidden_states) - if self.model.config.arch.normalize_embeddings: - hidden_states *= self.model.config.hidden_size ** 0.5 + if cfg.arch.residual_stream_fp32: + hidden_states = hidden_states.float() + if cfg.arch.normalize_embeddings: + hidden_states *= cfg.hidden_size ** 0.5 if intermediates: return {"hidden_states": hidden_states} diff --git a/exllamav2/exllamav2_ext/cuda/q_attn.cu b/exllamav2/exllamav2_ext/cuda/q_attn.cu index d536ce04..5a165fcc 100644 --- a/exllamav2/exllamav2_ext/cuda/q_attn.cu +++ b/exllamav2/exllamav2_ext/cuda/q_attn.cu @@ -95,7 +95,8 @@ QAttn::QAttn half* _q_norm, half* _k_norm, half* _post_layernorm, - half* _post_layernorm_bias + half* _post_layernorm_bias, + bool _residual_fp32 ): layernorm(_layernorm), layernorm_bias(_layernorm_bias), @@ -121,7 +122,8 @@ QAttn::QAttn q_norm(_q_norm), k_norm(_k_norm), post_layernorm(_post_layernorm), - post_layernorm_bias(_post_layernorm_bias) + post_layernorm_bias(_post_layernorm_bias), + residual_fp32(_residual_fp32) { } @@ -132,7 +134,7 @@ QAttn::~QAttn() void QAttn::forward_cuda_1 ( cublasHandle_t cublas_handle, - half* x, + void* x, int batch_size, int q_len, int past_len, @@ -146,14 +148,14 @@ void QAttn::forward_cuda_1 half* lora_temp ) { - half* norm_state = x; + half* norm_state = (half*) x; if (layernorm) { if (layernorm_is_rms) - rms_norm_cuda(x, layernorm, temp_state, norm_epsilon, q_len * batch_size, hidden_size); + rms_norm_cuda(x, layernorm, temp_state, norm_epsilon, q_len * batch_size, hidden_size, false, residual_fp32, false); else - layer_norm_cuda(x, layernorm, layernorm_bias, temp_state, norm_epsilon, q_len * batch_size, hidden_size); + layer_norm_cuda((half*)x, layernorm, layernorm_bias, temp_state, norm_epsilon, q_len * batch_size, hidden_size); norm_state = temp_state; } @@ -199,7 +201,7 @@ void QAttn::forward_cuda_2 ( cublasHandle_t cublas_handle, const half* attn_output, - half* hidden_state, + void* hidden_state, int q_len, int batch_size, const std::vector& loras, @@ -208,16 +210,16 @@ void QAttn::forward_cuda_2 { if (!post_layernorm) { - 
gemm_half_q_half_cuda(cublas_handle, attn_output, o_proj, hidden_state, q_len * batch_size, o_proj->width, o_proj->height, !has_residual, temp_dq); + gemm_half_q_half_cuda(cublas_handle, attn_output, o_proj, (half*) hidden_state, q_len * batch_size, o_proj->width, o_proj->height, !has_residual, temp_dq); } else { gemm_half_q_half_cuda(cublas_handle, attn_output, o_proj, temp_state, q_len * batch_size, o_proj->width, o_proj->height, true, temp_dq); if (layernorm_is_rms) - rms_norm_cuda(temp_state, post_layernorm, hidden_state, norm_epsilon, q_len * batch_size, hidden_size, true); + rms_norm_cuda(temp_state, post_layernorm, hidden_state, norm_epsilon, q_len * batch_size, hidden_size, true, false, residual_fp32); else - layer_norm_cuda(temp_state, post_layernorm, post_layernorm_bias, hidden_state, norm_epsilon, q_len * batch_size, hidden_size, true); + layer_norm_cuda(temp_state, post_layernorm, post_layernorm_bias, (half*) hidden_state, norm_epsilon, q_len * batch_size, hidden_size, true); } - apply_loras_cuda(cublas_handle, o_proj_lora, loras, o_proj, attn_output, hidden_state, lora_temp, q_len * batch_size); + apply_loras_cuda(cublas_handle, o_proj_lora, loras, o_proj, attn_output, (half*) hidden_state, lora_temp, q_len * batch_size); } diff --git a/exllamav2/exllamav2_ext/cuda/q_attn.cuh b/exllamav2/exllamav2_ext/cuda/q_attn.cuh index 8a856b76..da9abd43 100644 --- a/exllamav2/exllamav2_ext/cuda/q_attn.cuh +++ b/exllamav2/exllamav2_ext/cuda/q_attn.cuh @@ -52,6 +52,7 @@ public: std::unordered_map> o_proj_lora; bool has_residual; + bool residual_fp32; int rope_style; QAttn @@ -80,7 +81,8 @@ public: half* _q_norm, half* _k_norm, half* _post_layernorm, - half* _post_layernorm_bias + half* _post_layernorm_bias, + bool _residual_fp32 ); ~QAttn(); @@ -88,7 +90,7 @@ public: void forward_cuda_1 ( cublasHandle_t cublas_handle, - half* x, + void* x, int batch_size, int q_len, int past_len, @@ -106,7 +108,7 @@ public: ( cublasHandle_t cublas_handle, const half* attn_output, - half* hidden_state, + void* hidden_state, int q_len, int batch_size, const std::vector& loras, diff --git a/exllamav2/exllamav2_ext/cuda/q_mlp.cu b/exllamav2/exllamav2_ext/cuda/q_mlp.cu index ba8368b7..810aebdb 100644 --- a/exllamav2/exllamav2_ext/cuda/q_mlp.cu +++ b/exllamav2/exllamav2_ext/cuda/q_mlp.cu @@ -34,7 +34,8 @@ QMLP::QMLP bool _act_gelu, bool _has_residual, half* _post_layernorm, - half* _post_layernorm_bias + half* _post_layernorm_bias, + bool _residual_fp32 ): layernorm(_layernorm), layernorm_bias(_layernorm_bias), @@ -51,7 +52,8 @@ QMLP::QMLP act_gelu(_act_gelu), has_residual(_has_residual), post_layernorm(_post_layernorm), - post_layernorm_bias(_post_layernorm_bias) + post_layernorm_bias(_post_layernorm_bias), + residual_fp32(_residual_fp32) { } @@ -61,7 +63,7 @@ QMLP::~QMLP() { void QMLP::forward_ ( cublasHandle_t cublas_handle, - half* x, + void* x, int rows, int columns, const std::vector& loras, @@ -81,14 +83,14 @@ void QMLP::forward_ // Layernorm - half* norm_state = x; + half* norm_state = (half*) x; if (layernorm) { if (layernorm_is_rms) - rms_norm_cuda(x, layernorm, temp_state, norm_epsilon, rows, columns); + rms_norm_cuda(x, layernorm, temp_state, norm_epsilon, rows, columns, false, residual_fp32, false); else - layer_norm_cuda(x, layernorm, layernorm_bias, temp_state, norm_epsilon, rows, columns); + layer_norm_cuda((half*) x, layernorm, layernorm_bias, temp_state, norm_epsilon, rows, columns); norm_state = temp_state; } @@ -122,7 +124,7 @@ void QMLP::forward_ if (!post_layernorm) { - 
gemm_half_q_half_cuda(cublas_handle, temp_a, down, x, rows, columns, intermediate_size, !has_residual, temp_dq); + gemm_half_q_half_cuda(cublas_handle, temp_a, down, (half*) x, rows, columns, intermediate_size, !has_residual, temp_dq); } // Down proj with post_layernorm @@ -131,12 +133,12 @@ void QMLP::forward_ { gemm_half_q_half_cuda(cublas_handle, temp_a, down, temp_state, rows, columns, intermediate_size, true, temp_dq); if (layernorm_is_rms) - rms_norm_cuda(temp_state, post_layernorm, x, norm_epsilon, rows, columns, true); + rms_norm_cuda(temp_state, post_layernorm, x, norm_epsilon, rows, columns, true, false, residual_fp32); else - layer_norm_cuda(temp_state, post_layernorm, post_layernorm_bias, x, norm_epsilon, rows, columns, true); + layer_norm_cuda(temp_state, post_layernorm, post_layernorm_bias, (half*) x, norm_epsilon, rows, columns, true); } - apply_loras_cuda(cublas_handle, down_proj_lora, loras, down, temp_a, x, lora_temp, rows); + apply_loras_cuda(cublas_handle, down_proj_lora, loras, down, temp_a, (half*) x, lora_temp, rows); } diff --git a/exllamav2/exllamav2_ext/cuda/q_mlp.cuh b/exllamav2/exllamav2_ext/cuda/q_mlp.cuh index 2d8b0907..bb56d36e 100644 --- a/exllamav2/exllamav2_ext/cuda/q_mlp.cuh +++ b/exllamav2/exllamav2_ext/cuda/q_mlp.cuh @@ -38,6 +38,7 @@ public: bool act_gelu; bool has_residual; + bool residual_fp32; QMLP ( @@ -56,7 +57,8 @@ public: bool _act_gelu, bool _has_residual, half* _post_layernorm, - half* _post_layernorm_bias + half* _post_layernorm_bias, + bool _residual_fp32 ); ~QMLP(); @@ -64,7 +66,7 @@ public: void forward_ ( cublasHandle_t cublas_handle, - half* x, + void* x, int rows, int columns, const std::vector& loras, diff --git a/exllamav2/exllamav2_ext/cuda/rms_norm.cu b/exllamav2/exllamav2_ext/cuda/rms_norm.cu index ae8697c3..f72bda11 100644 --- a/exllamav2/exllamav2_ext/cuda/rms_norm.cu +++ b/exllamav2/exllamav2_ext/cuda/rms_norm.cu @@ -17,34 +17,36 @@ typedef void (*fp_rms_norm_kernel) ( + const void*, const half*, - const half*, - half*, + void*, const float, const float, const int, const int, + const bool, + const bool, const bool ); template __global__ void rms_norm_kernel ( - const half* __restrict__ x, + const void* __restrict__ x, const half* __restrict__ w, - half* __restrict__ y, + void* __restrict__ y, const float epsilon, const float r_dim, const int rows, const int dim, - const bool add_residual + const bool add_residual, + const bool input_fp32, + const bool output_fp32 ) { int warp_id = threadIdx.x / WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE; int row = blockIdx.x; - const half2* x_row = (const half2*) (x + row * dim); - half2* y_row = (half2*) (y + row * dim); const half2* w2 = (const half2*) w; // Compute sum of squares for each block @@ -52,21 +54,45 @@ __global__ void rms_norm_kernel float sum = 0.0f; float itemf[blocks_per_warp][2]; - #pragma unroll - for (int i = 0; i < blocks_per_warp; i++) + if (!input_fp32) + { + const half2* x_row = (const half2*) (((half*)x) + row * dim); + + #pragma unroll + for (int i = 0; i < blocks_per_warp; i++) + { + int column = warp_id * WARP_SIZE + lane_id + NUM_THREADS * i; + if (column >= dim / 2) break; + + half2 x2 = x_row[column]; + float f0 = __half2float(__low2half(x2)); + float f1 = __half2float(__high2half(x2)); + f0 = fmaxf(-65504.0f, fminf(f0, 65504.0f)); + f1 = fmaxf(-65504.0f, fminf(f1, 65504.0f)); + itemf[i][0] = f0; + itemf[i][1] = f1; + sum = fma(f0, f0, sum); + sum = fma(f1, f1, sum); + } + } + else { - int column = warp_id * WARP_SIZE + lane_id + NUM_THREADS * i; - if 
(column >= dim / 2) break; - - half2 x2 = x_row[column]; - float f0 = __half2float(__low2half(x2)); - float f1 = __half2float(__high2half(x2)); - f0 = fmaxf(-65504.0f, fminf(f0, 65504.0f)); - f1 = fmaxf(-65504.0f, fminf(f1, 65504.0f)); - itemf[i][0] = f0; - itemf[i][1] = f1; - sum = fma(f0, f0, sum); - sum = fma(f1, f1, sum); + const float2* x_row = (const float2*) (((float*)x) + row * dim); + + #pragma unroll + for (int i = 0; i < blocks_per_warp; i++) + { + int column = warp_id * WARP_SIZE + lane_id + NUM_THREADS * i; + if (column >= dim / 2) break; + + float2 x2 = x_row[column]; + float f0 = x2.x; + float f1 = x2.y; + itemf[i][0] = f0; + itemf[i][1] = f1; + sum = fma(f0, f0, sum); + sum = fma(f1, f1, sum); + } } // Shuffle to sum across lanes @@ -92,23 +118,58 @@ __global__ void rms_norm_kernel // Normalize x, scaling by w - #pragma unroll - for (int i = 0; i < blocks_per_warp; i++) + if (!output_fp32) + { + half2* y_row = (half2*) (((half*)y) + row * dim); + + #pragma unroll + for (int i = 0; i < blocks_per_warp; i++) + { + int column = warp_id * WARP_SIZE + lane_id + NUM_THREADS * i; + if (column >= dim / 2) return; + half2 w2_ = w2[column]; + + float x_itemf0 = itemf[i][0]; + float x_itemf1 = itemf[i][1]; + float w_itemf0 = __half2float(__low2half(w2_)); + float w_itemf1 = __half2float(__high2half(w2_)); + float n0 = x_itemf0 * w_itemf0 * rmf; + float n1 = x_itemf1 * w_itemf1 * rmf; + if (add_residual) + y_row[column] = __hadd2(y_row[column], __halves2half2(__float2half_rn(n0), __float2half_rn(n1))); + else + y_row[column] = __halves2half2(__float2half_rn(n0), __float2half_rn(n1)); + } + } + else { - int column = warp_id * WARP_SIZE + lane_id + NUM_THREADS * i; - if (column >= dim / 2) return; - half2 w2_ = w2[column]; - - float x_itemf0 = itemf[i][0]; - float x_itemf1 = itemf[i][1]; - float w_itemf0 = __half2float(__low2half(w2_)); - float w_itemf1 = __half2float(__high2half(w2_)); - float n0 = x_itemf0 * w_itemf0 * rmf; - float n1 = x_itemf1 * w_itemf1 * rmf; - if (add_residual) - y_row[column] = __hadd2(y_row[column], __halves2half2(__float2half_rn(n0), __float2half_rn(n1))); - else - y_row[column] = __halves2half2(__float2half_rn(n0), __float2half_rn(n1)); + float2* y_row = (float2*) (((float*)y) + row * dim); + + #pragma unroll + for (int i = 0; i < blocks_per_warp; i++) + { + int column = warp_id * WARP_SIZE + lane_id + NUM_THREADS * i; + if (column >= dim / 2) return; + half2 w2_ = w2[column]; + + float x_itemf0 = itemf[i][0]; + float x_itemf1 = itemf[i][1]; + float w_itemf0 = __half2float(__low2half(w2_)); + float w_itemf1 = __half2float(__high2half(w2_)); + float n0 = x_itemf0 * w_itemf0 * rmf; + float n1 = x_itemf1 * w_itemf1 * rmf; + if (add_residual) + { + float2 y2 = y_row[column]; + y2.x += n0; + y2.y += n1; + y_row[column] = y2; + } + else + { + y_row[column] = make_float2(n0, n1); + } + } } } @@ -138,13 +199,15 @@ fp_rms_norm_kernel pick_rms_norm_kernel(const int blocks_per_warp) void rms_norm_cuda ( - const half* x, + const void* x, const half* w, - half* y, + void* y, const float epsilon, const int rows, const int dim, - const bool add_residual + const bool add_residual, + const bool input_fp32, + const bool output_fp32 ) { dim3 blockDim, gridDim; @@ -157,5 +220,5 @@ void rms_norm_cuda int blocks_per_warp = DIVIDE(dim, NUM_THREADS * 2); fp_rms_norm_kernel kernel = pick_rms_norm_kernel(blocks_per_warp); - kernel<<>>(x, w, y, epsilon, r_dim, rows, dim, add_residual); + kernel<<>>(x, w, y, epsilon, r_dim, rows, dim, add_residual, input_fp32, output_fp32); } diff --git 
a/exllamav2/exllamav2_ext/cuda/rms_norm.cuh b/exllamav2/exllamav2_ext/cuda/rms_norm.cuh index 786b1e55..2168c0f8 100644 --- a/exllamav2/exllamav2_ext/cuda/rms_norm.cuh +++ b/exllamav2/exllamav2_ext/cuda/rms_norm.cuh @@ -8,13 +8,15 @@ void rms_norm_cuda ( - const half* x, + const void* x, const half* w, - half* y, + void* y, const float epsilon, const int rows, const int dim, - const bool add_residual = false + const bool add_residual = false, + const bool input_fp32 = false, + const bool output_fp32 = false ); #endif \ No newline at end of file diff --git a/exllamav2/exllamav2_ext/ext_norm.cpp b/exllamav2/exllamav2_ext/ext_norm.cpp index c424491e..5b7377c6 100644 --- a/exllamav2/exllamav2_ext/ext_norm.cpp +++ b/exllamav2/exllamav2_ext/ext_norm.cpp @@ -28,9 +28,9 @@ void rms_norm float epsilon ) { - TORCH_CHECK_DTYPE(x, kHalf); + bool input_fp32 = x.dtype() == torch::kFloat; + bool output_fp32 = y.dtype() == torch::kFloat; TORCH_CHECK_DTYPE(w, kHalf); - TORCH_CHECK_DTYPE(y, kHalf); TORCH_CHECK_SHAPES(x, 1, w, 0, 1); TORCH_CHECK_SHAPES(x, 0, y, 0, 1); TORCH_CHECK_SHAPES(x, 1, y, 1, 1); @@ -42,12 +42,15 @@ void rms_norm rms_norm_cuda ( - (half*) x.data_ptr(), + (void*) x.data_ptr(), (half*) w.data_ptr(), - (half*) y.data_ptr(), + (void*) y.data_ptr(), epsilon, rows, - dim + dim, + false, + input_fp32, + output_fp32 ); } diff --git a/exllamav2/exllamav2_ext/ext_qattn.cpp b/exllamav2/exllamav2_ext/ext_qattn.cpp index fda3b43d..a9351a09 100644 --- a/exllamav2/exllamav2_ext/ext_qattn.cpp +++ b/exllamav2/exllamav2_ext/ext_qattn.cpp @@ -41,7 +41,8 @@ uintptr_t make_q_attn torch::Tensor q_norm, torch::Tensor k_norm, torch::Tensor post_layernorm, - torch::Tensor post_layernorm_bias + torch::Tensor post_layernorm_bias, + bool residual_fp32 ) { QMatrix* qm_q_proj = reinterpret_cast (q_q_proj); @@ -83,7 +84,8 @@ uintptr_t make_q_attn (half*) q_norm.is_meta() ? NULL : (half*) q_norm.data_ptr(), (half*) k_norm.is_meta() ? NULL : (half*) k_norm.data_ptr(), (half*) post_layernorm.is_meta() ? NULL : (half*) post_layernorm.data_ptr(), - (half*) post_layernorm_bias.is_meta() ? NULL : (half*) post_layernorm_bias.data_ptr() + (half*) post_layernorm_bias.is_meta() ? 
NULL : (half*) post_layernorm_bias.data_ptr(),
+        residual_fp32
     );
 
     return reinterpret_cast<uintptr_t> (attn);
@@ -116,7 +118,9 @@ void q_attn_forward_1
 )
 {
     QAttn* attn = reinterpret_cast<QAttn*> (q_attn);
-    TORCH_CHECK_DTYPE(x, kHalf);
+    if (attn->residual_fp32) { TORCH_CHECK_DTYPE(x, kFloat); }
+    else { TORCH_CHECK_DTYPE(x, kHalf); }
+
     TORCH_CHECK_DTYPE_OPT(past_lens, kInt);
 
     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
@@ -152,7 +156,8 @@ void q_attn_forward_2
 )
 {
     QAttn* attn = reinterpret_cast<QAttn*> (q_attn);
-    TORCH_CHECK_DTYPE(x, kHalf);
+    if (attn->residual_fp32) { TORCH_CHECK_DTYPE(x, kFloat); }
+    else { TORCH_CHECK_DTYPE(x, kHalf); }
 
     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
     cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
diff --git a/exllamav2/exllamav2_ext/ext_qattn.h b/exllamav2/exllamav2_ext/ext_qattn.h
index b8c9dc18..938e90b1 100644
--- a/exllamav2/exllamav2_ext/ext_qattn.h
+++ b/exllamav2/exllamav2_ext/ext_qattn.h
@@ -25,7 +25,8 @@ uintptr_t make_q_attn
     torch::Tensor q_norm,
     torch::Tensor k_norm,
     torch::Tensor post_layernorm,
-    torch::Tensor post_layernorm_bias
+    torch::Tensor post_layernorm_bias,
+    bool residual_fp32
 );
 
 void free_q_attn
diff --git a/exllamav2/exllamav2_ext/ext_qmlp.cpp b/exllamav2/exllamav2_ext/ext_qmlp.cpp
index ad62389c..d089a5c0 100644
--- a/exllamav2/exllamav2_ext/ext_qmlp.cpp
+++ b/exllamav2/exllamav2_ext/ext_qmlp.cpp
@@ -32,7 +32,8 @@ uintptr_t make_q_mlp
     bool act_gelu,
     bool has_residual,
     torch::Tensor post_layernorm,
-    torch::Tensor post_layernorm_bias
+    torch::Tensor post_layernorm_bias,
+    bool residual_fp32
 )
 {
     QMatrix* qm_gate = reinterpret_cast<QMatrix*> (q_gate);
@@ -61,7 +62,8 @@ uintptr_t make_q_mlp
         act_gelu,
         has_residual,
         (half*) post_layernorm.is_meta() ? NULL : (half*) post_layernorm.data_ptr(),
-        (half*) post_layernorm_bias.is_meta() ? NULL : (half*) post_layernorm_bias.data_ptr()
+        (half*) post_layernorm_bias.is_meta() ?
NULL : (half*) post_layernorm_bias.data_ptr(), + residual_fp32 ); return reinterpret_cast (mlp); @@ -85,7 +87,8 @@ void q_mlp_forward_ ) { QMLP* mlp = reinterpret_cast (q_mlp); - TORCH_CHECK_DTYPE(x, kHalf); + if (mlp->residual_fp32) { TORCH_CHECK_DTYPE(x, kFloat); } + else { TORCH_CHECK_DTYPE(x, kHalf); } const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); @@ -98,7 +101,7 @@ void q_mlp_forward_ mlp->forward_ ( at::cuda::getCurrentCUDABlasHandle(), - (half*) x.data_ptr(), + (void*) x.data_ptr(), rows, dim, loras, diff --git a/exllamav2/exllamav2_ext/ext_qmlp.h b/exllamav2/exllamav2_ext/ext_qmlp.h index e76bd227..3269bfae 100644 --- a/exllamav2/exllamav2_ext/ext_qmlp.h +++ b/exllamav2/exllamav2_ext/ext_qmlp.h @@ -16,7 +16,8 @@ uintptr_t make_q_mlp bool act_gelu, bool has_residual, torch::Tensor post_layernorm, - torch::Tensor post_layernorm_bias + torch::Tensor post_layernorm_bias, + bool residual_fp32 ); void free_q_mlp diff --git a/exllamav2/layernorm.py b/exllamav2/layernorm.py index 39209e6f..112ee8f0 100644 --- a/exllamav2/layernorm.py +++ b/exllamav2/layernorm.py @@ -100,6 +100,7 @@ def forward(self, past_len = None, intermediates: bool = False, loras = None, + output_fp32 = False, # TODO: **kwargs) -> torch.Tensor | dict[str: torch.Tensor]: output_shape = hidden_states.shape @@ -126,6 +127,7 @@ def forward_torch(self, past_len = None, intermediates: bool = False, loras = None, + output_fp32 = False, # TODO: **kwargs) -> torch.Tensor | dict[str: torch.Tensor]: hidden_states = self.layernorm(hidden_states) diff --git a/exllamav2/lora.py b/exllamav2/lora.py index 133f01c1..1dd99cbe 100644 --- a/exllamav2/lora.py +++ b/exllamav2/lora.py @@ -56,6 +56,11 @@ def __init__(self, self.embed_tokens = None self.lm_head = None + # Compatibility check + + assert not self.model.config.arch.residual_stream_fp32, \ + "LoRAs not (yet) supported for models with FP32 residual stream" + # Grab relevant items from LoRA config with open(lora_config_path, encoding = "utf8") as f: diff --git a/exllamav2/mlp.py b/exllamav2/mlp.py index f41ea967..58a644aa 100644 --- a/exllamav2/mlp.py +++ b/exllamav2/mlp.py @@ -161,7 +161,8 @@ def load(self): cfg.arch.mlp_act_func == "gelu", self.has_residual, post_norm_weight, - post_norm_bias) + post_norm_bias, + cfg.arch.residual_stream_fp32) def unload(self): @@ -316,10 +317,12 @@ def forward_torch(self, down = self.down_proj.forward(y, loras = loras) if self.post_layernorm: - down = self.post_layernorm.forward(down) + down = self.post_layernorm.forward(down, output_fp32 = cfg.arch.residual_stream_fp32) hidden_states = down + residual if self.has_residual else down - if cfg.arch.clamp_hidden_states: + if cfg.arch.residual_stream_fp32: + hidden_states = hidden_states.float() + elif cfg.arch.clamp_hidden_states: hidden_states = hidden_states.clamp(-65504, 65504) if intermediates: diff --git a/exllamav2/rmsnorm.py b/exllamav2/rmsnorm.py index fa3b7e29..77d13a48 100644 --- a/exllamav2/rmsnorm.py +++ b/exllamav2/rmsnorm.py @@ -97,11 +97,17 @@ def forward(self, past_len = None, intermediates: bool = False, loras = None, + output_fp32 = False, **kwargs) -> torch.Tensor | dict[str: torch.Tensor]: output_shape = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - norm = torch.empty_like(hidden_states) + + if not output_fp32: + norm = torch.empty_like(hidden_states, dtype = torch.half) + else: + norm = torch.empty_like(hidden_states, dtype = torch.float) + ext_c.rms_norm(hidden_states, self.weight, norm, self.variance_epsilon) 
hidden_states = norm.view(output_shape) @@ -118,13 +124,17 @@ def forward_torch(self, past_len = None, intermediates: bool = False, loras = None, + output_fp32 = False, **kwargs) -> torch.Tensor | dict[str: torch.Tensor]: - hidden_states.clamp_(-65504.0, 65504.0) + # hidden_states.clamp_(-65504.0, 65504.0) variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim = True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - hidden_states = hidden_states.to(self.weight.dtype) + + if not output_fp32: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states *= self.weight if intermediates: From 587b7410be90b560f76767923385e0c3495981ee Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Tue, 9 Jul 2024 07:33:19 +0200 Subject: [PATCH 31/40] Update examples --- examples/dynamic_gen.py | 2 ++ examples/inference.py | 1 + examples/inference_async.py | 1 + examples/inference_banned_strings.py | 1 + examples/inference_cfg.py | 1 + examples/inference_dedup.py | 1 + examples/inference_json.py | 1 + examples/inference_lora.py | 1 + examples/inference_speculative.py | 2 ++ examples/inference_stream.py | 1 + 10 files changed, 12 insertions(+) diff --git a/examples/dynamic_gen.py b/examples/dynamic_gen.py index ec0e3192..bfeb9d96 100644 --- a/examples/dynamic_gen.py +++ b/examples/dynamic_gen.py @@ -136,6 +136,7 @@ def main(): if use_draft_model: draft_config = ExLlamaV2Config(draft_model_dir) + draft_config.arch_compat_overrides() draft_model = ExLlamaV2(draft_config) draft_cache = ExLlamaV2Cache( @@ -155,6 +156,7 @@ def main(): # 2048, which will also be the limit of the chunk size for prefill used by the dynamic generator. config = ExLlamaV2Config(model_dir) + config.arch_compat_overrides() config.max_input_len = max_chunk_size config.max_attention_size = max_chunk_size ** 2 model = ExLlamaV2(config) diff --git a/examples/inference.py b/examples/inference.py index 0353d45c..9ac28d63 100644 --- a/examples/inference.py +++ b/examples/inference.py @@ -7,6 +7,7 @@ model_dir = "/mnt/str/models/mistral-7b-exl2/4.0bpw" config = ExLlamaV2Config(model_dir) +config.arch_compat_overrides() model = ExLlamaV2(config) cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True) model.load_autosplit(cache, progress = True) diff --git a/examples/inference_async.py b/examples/inference_async.py index 94629195..c12c6408 100644 --- a/examples/inference_async.py +++ b/examples/inference_async.py @@ -9,6 +9,7 @@ async def main(): model_dir = "/mnt/str/models/llama3-8b-exl2/4.0bpw" config = ExLlamaV2Config(model_dir) + config.arch_compat_overrides() model = ExLlamaV2(config) cache = ExLlamaV2Cache(model, lazy = True) model.load_autosplit(cache, progress = True) diff --git a/examples/inference_banned_strings.py b/examples/inference_banned_strings.py index c6ad1885..7648b0e6 100644 --- a/examples/inference_banned_strings.py +++ b/examples/inference_banned_strings.py @@ -9,6 +9,7 @@ model_dir = "/mnt/str/models/llama3-8b-instruct-exl2/6.0bpw/" config = ExLlamaV2Config(model_dir) +config.arch_compat_overrides() model = ExLlamaV2(config) cache = ExLlamaV2Cache(model, lazy = True) model.load_autosplit(cache, progress = True) diff --git a/examples/inference_cfg.py b/examples/inference_cfg.py index 4ecb5452..b1c86e71 100644 --- a/examples/inference_cfg.py +++ b/examples/inference_cfg.py @@ -8,6 +8,7 @@ model_dir = "/mnt/str/models/llama3-8b-instruct-exl2/4.0bpw" config = ExLlamaV2Config(model_dir) +config.arch_compat_overrides() model = 
ExLlamaV2(config) cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True) model.load_autosplit(cache, progress = True) diff --git a/examples/inference_dedup.py b/examples/inference_dedup.py index 335f2289..bd39fd6a 100644 --- a/examples/inference_dedup.py +++ b/examples/inference_dedup.py @@ -8,6 +8,7 @@ model_dir = "/mnt/str/models/llama3-8b-instruct-exl2/4.0bpw" config = ExLlamaV2Config(model_dir) +config.arch_compat_overrides() model = ExLlamaV2(config) cache = ExLlamaV2Cache(model, max_seq_len = 8192, lazy = True) model.load_autosplit(cache, progress = True) diff --git a/examples/inference_json.py b/examples/inference_json.py index b6e7608d..eb53262f 100644 --- a/examples/inference_json.py +++ b/examples/inference_json.py @@ -13,6 +13,7 @@ model_dir = "/mnt/str/models/mistral-7b-exl2/4.0bpw" config = ExLlamaV2Config(model_dir) +config.arch_compat_overrides() model = ExLlamaV2(config) cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True) model.load_autosplit(cache, progress = True) diff --git a/examples/inference_lora.py b/examples/inference_lora.py index 310f0213..33292fee 100644 --- a/examples/inference_lora.py +++ b/examples/inference_lora.py @@ -7,6 +7,7 @@ model_dir = "/mnt/str/models/llama2-7b-exl2/5.0bpw" config = ExLlamaV2Config(model_dir) +config.arch_compat_overrides() model = ExLlamaV2(config) cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True) model.load_autosplit(cache, progress = True) diff --git a/examples/inference_speculative.py b/examples/inference_speculative.py index 51ea311b..53123dbd 100644 --- a/examples/inference_speculative.py +++ b/examples/inference_speculative.py @@ -12,12 +12,14 @@ draft_model_dir = "/mnt/str/models/qwen2-1.5b-instruct-exl2/4.0bpw" draft_config = ExLlamaV2Config(draft_model_dir) +draft_config.arch_compat_overrides() draft_model = ExLlamaV2(draft_config) draft_cache = ExLlamaV2Cache(draft_model, max_seq_len = total_cache_tokens, lazy = True) draft_model.load_autosplit(draft_cache, progress = True) model_dir = "/mnt/str/models/qwen2-72b-instruct-exl2/6.0bpw" config = ExLlamaV2Config(model_dir) +config.arch_compat_overrides() model = ExLlamaV2(config) cache = ExLlamaV2Cache(model, max_seq_len = total_cache_tokens, lazy = True) model.load_autosplit(cache, progress = True) diff --git a/examples/inference_stream.py b/examples/inference_stream.py index 3f9cfa80..ce94bc4d 100644 --- a/examples/inference_stream.py +++ b/examples/inference_stream.py @@ -8,6 +8,7 @@ model_dir = "/mnt/str/models/mistral-7b-exl2/4.0bpw" config = ExLlamaV2Config(model_dir) +config.arch_compat_overrides() model = ExLlamaV2(config) cache = ExLlamaV2Cache(model, lazy = True) model.load_autosplit(cache, progress = True) From 052f06048d46f08f9aecded870628d18df43b54d Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Tue, 9 Jul 2024 07:34:04 +0200 Subject: [PATCH 32/40] Dynamic Gen: compute max_batch_size automatically by default --- exllamav2/generator/dynamic.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/exllamav2/generator/dynamic.py b/exllamav2/generator/dynamic.py index b86c65bb..14c68858 100644 --- a/exllamav2/generator/dynamic.py +++ b/exllamav2/generator/dynamic.py @@ -239,10 +239,10 @@ def __init__( model: ExLlamaV2, cache: ExLlamaV2CacheBase, tokenizer: ExLlamaV2Tokenizer, - max_batch_size: int = 16, + max_batch_size: int = None, max_seq_len: int | None = None, max_chunk_size: int | None = None, - max_q_size: int = 16, + max_q_size: int = 8, 
draft_model: ExLlamaV2 | None = None, draft_cache: ExLlamaV2CacheBase | None = None, num_draft_tokens: int = 4, @@ -268,7 +268,7 @@ def __init__( :param max_batch_size: The maximum number of sequences to process in parallel. The generator will also limit this - dynamically considering the available cache space. + dynamically considering the available cache space. Specify None to calculate automatically :param max_seq_len: Maximum length of each individual sequence. Defaults to the model's max_seq_len. @@ -325,7 +325,13 @@ def __init__( self.draft_model = draft_model self.draft_cache = draft_cache - self.num_draft_tokens = num_draft_tokens if (draft_model or use_ngram_draft) else 0 + + if draft_model or use_ngram_draft: + assert num_draft_tokens <= max_q_size, \ + "num_draft_tokens cannot be larger than max_q_size." + self.num_draft_tokens = num_draft_tokens + else: + self.num_draft_tokens = 0 if draft_model: assert draft_cache is not None, \ @@ -344,12 +350,16 @@ def __init__( assert not isinstance(cache, ExLlamaV2Cache_8bit), \ "Dynamic generator does not currently work with 8-bit cache. Use either FP16 or Q4." - model_max_q = cfg.max_batch_size * cfg.max_input_len - req_max_q = max_q_size * max_batch_size - assert req_max_q <= model_max_q, \ - f"Model has max_batch_size * max_input_len = {cfg.max_batch_size} * {cfg.max_input_len} tokens, " + \ - f"generator requires max_batch_size * max_q_size = {max_batch_size} * {max_q_size} tokens." - self.max_batch_size = max_batch_size + if not max_batch_size: + max_batch_size = cfg.max_input_len // max_q_size + self.max_batch_size = max_batch_size + else: + model_max_q = cfg.max_batch_size * cfg.max_input_len + req_max_q = max_q_size * max_batch_size + assert req_max_q <= model_max_q, \ + f"Model has max_batch_size * max_input_len = {cfg.max_batch_size} * {cfg.max_input_len} tokens, " + \ + f"generator requires max_batch_size * max_q_size = {max_batch_size} * {max_q_size} tokens." 
+ self.max_batch_size = max_batch_size if max_seq_len is not None: assert max_seq_len <= model.config.max_seq_len, \ From f121a1ff1dbacea31ed847e4d8d2621fcab909dc Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Tue, 9 Jul 2024 07:35:34 +0200 Subject: [PATCH 33/40] Support softcap via flash-attn when available --- exllamav2/architecture.py | 2 - exllamav2/attn.py | 86 +++++++++++++++++++++++---------------- exllamav2/config.py | 31 ++++++++++++++ exllamav2/model_init.py | 4 +- 4 files changed, 84 insertions(+), 39 deletions(-) diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py index 39481858..30d94697 100644 --- a/exllamav2/architecture.py +++ b/exllamav2/architecture.py @@ -371,8 +371,6 @@ def __init__(self, arch_string, read_config): self.scale_attn_weights = False self.pre_post_layernorm = True self.alternating_swa = True - self.eager_attn_only = True - self.clamp_hidden_states = False self.residual_stream_fp32 = True # StarCoder2 diff --git a/exllamav2/attn.py b/exllamav2/attn.py index 15b465a2..56ce3860 100644 --- a/exllamav2/attn.py +++ b/exllamav2/attn.py @@ -28,6 +28,7 @@ has_flash_attn = False has_flash_attn_with_paged = False has_flash_attn_with_window = False +has_flash_attn_with_softcap = False if 'EXLLAMA_NO_FLASH_ATTN' not in os.environ: try: @@ -44,12 +45,14 @@ if [2, 5, 7] <= flash_attn_ver: from flash_attn import flash_attn_func, flash_attn_with_kvcache - import flash_attn_2_cuda as flash_attn_cuda + # import flash_attn_2_cuda as flash_attn_cuda has_flash_attn = True has_flash_attn_with_paged = True - has_flash_attn_with_window = "window_size" in list(inspect.signature(flash_attn_func).parameters) + signature = list(inspect.signature(flash_attn_func).parameters) + has_flash_attn_with_window = "window_size" in signature + has_flash_attn_with_softcap = "softcap" in signature except ModuleNotFoundError: pass @@ -688,32 +691,31 @@ def forward_paged(self, if cache.q_block == 1: cache.get_kv_state(self.layer_idx, batch_size, 0, attn_params.max_cache_seqlen, page_size, cache_seqlens, block_table) - # attn_output = flash_attn_with_kvcache( - # q = q, - # k = k, - # v = v, - # k_cache = k_cache, - # v_cache = v_cache, - # cache_seqlens = cache_seqlens_a, - # block_table = block_table, - # causal = True - # ) - window_size = -1 if not self.sliding_window else self.sliding_window - - attn_output, _ = flash_attn_cuda.fwd_kvcache( - q, k_cache, v_cache, k, v, - cache_seqlens_a, - None, None, - None, - block_table, - None, - None, - self.scaling, - True, - window_size, window_size, - True, - 0, + flash_kwargs = {} + if self.sliding_window: + # assert has_flash_attn_with_window, \ + # "Installed version of flash-attn does not support sliding window" + if has_flash_attn_with_window: + flash_kwargs["window_size"] = (self.sliding_window, self.sliding_window) + if cfg.attn_logit_softcapping: + # assert has_flash_attn_with_softcap, \ + # "Installed version of flash-attn does not support softcapping" + if has_flash_attn_with_softcap: + flash_kwargs["softcap"] = cfg.attn_logit_softcapping + + attn_output = flash_attn_with_kvcache( + q = q, + k = k, + v = v, + k_cache = k_cache, + v_cache = v_cache, + cache_seqlens = cache_seqlens_a, + block_table = block_table, + causal = True, + softmax_scale = self.scaling, + **flash_kwargs ) + attn_output = attn_output.view((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim)) cache.store_kv_state(self.layer_idx, batch_size, 0, q_len, page_size, cache_seqlens, block_table) @@ -746,7 +748,9 @@ 
def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para k_states = k_states.transpose(1, 2) v_states = v_states.transpose(1, 2) - if has_lower_right_sdpa and attn_params.is_causal() and not cfg.no_sdpa: + # SDPA + + if has_lower_right_sdpa and attn_params.is_causal() and not cfg.no_sdpa and not cfg.attn_logit_softcapping: k_states = self.repeat_kv(k_states, cfg.num_key_value_groups) v_states = self.repeat_kv(v_states, cfg.num_key_value_groups) @@ -764,6 +768,8 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para scale = self.scaling ) + # Matmul attn + else: k_states = self.repeat_kv(k_states, cfg.num_key_value_groups) @@ -794,12 +800,17 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para def _attn_flash(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg): - assert has_flash_attn_with_window or not self.sliding_window, \ - "Installed version of flash-attn does not support sliding window" - - flash_kwargs = { - "window_size": (self.sliding_window, self.sliding_window) - } if self.sliding_window else {} + flash_kwargs = {} + if self.sliding_window: + # assert has_flash_attn_with_window, \ + # "Installed version of flash-attn does not support sliding window" + if has_flash_attn_with_window: + flash_kwargs["window_size"] = (self.sliding_window, self.sliding_window) + if cfg.attn_logit_softcapping: + # assert has_flash_attn_with_softcap, \ + # "Installed version of flash-attn does not support softcapping" + if has_flash_attn_with_softcap: + flash_kwargs["softcap"] = cfg.attn_logit_softcapping attn_output = flash_attn_func( q_states, @@ -815,8 +826,11 @@ def _attn_flash(self, batch_size, q_len, q_states, k_states, v_states, attn_para def _attn_xformers(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg): - assert not self.sliding_window, \ - "Sliding window not currently supported for xformers" + # assert not self.sliding_window, \ + # "Sliding window not currently supported for xformers" + + # assert not cfg.attn_logit_softcapping, \ + # "Softcap not yet supported for xformers" # xformers memory_efficient_attention, could be beneficial if your device's architecture is less than Date: Tue, 9 Jul 2024 07:42:34 +0200 Subject: [PATCH 34/40] Fix warnings --- exllamav2/architecture.py | 2 +- exllamav2/config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py index 30d94697..9bb7eca3 100644 --- a/exllamav2/architecture.py +++ b/exllamav2/architecture.py @@ -744,7 +744,7 @@ def __init__(self, arch_string, read_config): if self.residual_stream_fp32: assert self.norm_key_1_post and self.norm_key_2_post, \ - "FP32 residual stream only implement for arch with post layernorms" + "FP32 residual stream only implemented for arch with post layernorms" def make_fused_mlp(self): diff --git a/exllamav2/config.py b/exllamav2/config.py index cad66be0..dc7f199c 100644 --- a/exllamav2/config.py +++ b/exllamav2/config.py @@ -387,7 +387,7 @@ def arch_compat_overrides(self, quiet: bool = False, warn_only = False): warnings.append(" !! Warning: model requires SWA, not supported in xformers") disable = True if disable and not warn_only: - warnings.append(" !! Warning: disabling flash-attn") + warnings.append(" !! 
Warning: disabling xformers") self.no_xformers = True if not quiet: From 1e31fbf5d322600491f568554a2b09c89a7a0611 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Tue, 9 Jul 2024 08:05:39 +0200 Subject: [PATCH 35/40] HumanEval: add Gemma template --- eval/humaneval.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/eval/humaneval.py b/eval/humaneval.py index 8283f7ea..af510d61 100644 --- a/eval/humaneval.py +++ b/eval/humaneval.py @@ -52,6 +52,13 @@ "<|start_header_id|>assistant<|end_header_id|>\n\n" "Sure! Here is how you might implement the function:\n\n```python\n{{problem}} ", " " + ), + "gemma": ( + "user\n" + "Complete the following Python function:\n\n{{problem}}<|eot_id|>" + "model\n" + "```python\n{{problem}} ", + " " ) } From 0122b1192f3557193e5ac86d5734334606b38460 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Wed, 10 Jul 2024 02:51:50 +0200 Subject: [PATCH 36/40] Option to launch eval script automatically after HumanEval test --- eval/humaneval.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/eval/humaneval.py b/eval/humaneval.py index af510d61..3240d57b 100644 --- a/eval/humaneval.py +++ b/eval/humaneval.py @@ -5,7 +5,7 @@ from exllamav2 import model_init from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_Q6, ExLlamaV2Cache_Q8 from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler -import argparse, contextlib +import argparse, contextlib, subprocess import util # Args @@ -20,6 +20,7 @@ parser.add_argument("--max_tokens", type = int, default = 768, help = "Max number of tokens for each completion") parser.add_argument("-pf", "--prompt_format", type = str, help = "Instruct format to apply. Default is raw completion (for base models) ") parser.add_argument("-v", "--verbose", action = "store_true", help = "Spam completions to console while generating") +parser.add_argument("-e", "--eval", action = "store_true", help = "Run evaluation script on output file after sampling") model_init.add_args(parser) args = parser.parse_args() @@ -199,3 +200,8 @@ print(f" -- Saving: {args.output}") write_jsonl(args.output, samples) +# Optionally launch eval script + +if args.eval: + subprocess.run(["evaluate_functional_correctness", args.output]) + From 1179b8a5e5a62b31dcd07ba540132451af5d2e47 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Wed, 10 Jul 2024 08:05:57 +0200 Subject: [PATCH 37/40] Fix ppl test for long seq lengths --- exllamav2/model.py | 8 ++++++++ test_inference.py | 8 +++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/exllamav2/model.py b/exllamav2/model.py index 624d7ac4..d21031df 100644 --- a/exllamav2/model.py +++ b/exllamav2/model.py @@ -681,6 +681,7 @@ def forward(self, return_last_state: bool = False, position_offsets: torch.Tensor | None = None, abort_event: threading.Event | None = None, + cpu_logits: bool = False, **kwargs) \ -> torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None: """ @@ -717,6 +718,11 @@ def forward(self, :param abort_event: Optional event that, if set, will abort the forward pass. Function will return None if aborted. + :param cpu_logits: + If True, logits are collected and returned in system RAM. This is somewhat slower but can prevent + out-of-memory errors when computing logits for all positions in a long sequence, such as during a + perplexity test. 
+ :return: FP16 logits tensor, shape (batch_size, q_len, vocab_size) (optional) state tensor, shape (batch_size, q_len, hidden_size) @@ -819,6 +825,8 @@ def forward(self, if abort_event and abort_event.is_set(): return if not _preprocess_only: + if cpu_logits: + r["logits"] = r["logits"].cpu() result = r["logits"] if result is None else torch.cat((result, r["logits"]), dim = 1) chunk_begin = chunk_end diff --git a/test_inference.py b/test_inference.py index 376b3f98..2c6f67ce 100644 --- a/test_inference.py +++ b/test_inference.py @@ -292,6 +292,8 @@ def ppl(input_ids__, logits__, lengths__, bins = False): + logits_device = model.modules[-1].device() + if bins: num_bins = (max(lengths__) + 255) // 256 logprob_sum_ = [0.0] * num_bins @@ -317,8 +319,8 @@ def ppl(input_ids__, logits__, lengths__, bins = False): a_ = b_ b_ = min(b_ + chunksize, logits_.shape[1]) - logits_f = logits_[:, a_:b_, :].float() + 1e-10 - target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_.device) + logits_f = logits_[:, a_:b_, :].to(logits_device).float() + 1e-10 + target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_f.device) log_probs = F.log_softmax(logits_f, dim=-1) token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1) @@ -398,7 +400,7 @@ def ppl(input_ids__, logits__, lengths__, bins = False): input_ids = input_ids[:, :] if cache is not None: cache.current_seq_len = 0 - logits = model.forward(input_ids, cache) + logits = model.forward(input_ids, cache, cpu_logits = input_ids.numel() > 2048) logits = logits[:, :-1, :] logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[i:i+1], args.eval_context_lens) From 25e46d8b305d5dd43bdb4ff59ee3d4e9b5cd26f5 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Thu, 11 Jul 2024 01:08:00 +0200 Subject: [PATCH 38/40] Dynamic gen: Fix return_logits=True when using CFG --- exllamav2/generator/dynamic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exllamav2/generator/dynamic.py b/exllamav2/generator/dynamic.py index 14c68858..52924421 100644 --- a/exllamav2/generator/dynamic.py +++ b/exllamav2/generator/dynamic.py @@ -1932,7 +1932,7 @@ def emit( self.held_k_tokens.append(next_k_tokens) self.held_k_probs.append(next_k_probs) if self.return_logits: - self.held_logits.append(logits) + self.held_logits.append(logits[:1, :, :]) # Stop if we reach max_new_tokens From a8e3f77a838463d0bef0b185d867b91e07e1e071 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Thu, 11 Jul 2024 14:40:41 +0200 Subject: [PATCH 39/40] Add chunk_size cmdline arg to model_init --- exllamav2/model_init.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/exllamav2/model_init.py b/exllamav2/model_init.py index fd85cbbb..17730ce2 100644 --- a/exllamav2/model_init.py +++ b/exllamav2/model_init.py @@ -22,6 +22,8 @@ def add_args(parser): parser.add_argument("-lq4", "--load_q4", action = "store_true", help = "Load weights in Q4 mode") parser.add_argument("-fst", "--fast_safetensors", action = "store_true", help = "Use alternative safetensors loader (with direct I/O when available)") parser.add_argument("-ic", "--ignore_compatibility", action = "store_true", help = "Do not override model config options in case of compatibility issues") + parser.add_argument("-chunk", "--chunk_size", type = int, help = "Chunk size ('input length')") + def print_options(args): @@ -41,6 +43,7 @@ def print_options(args): if args.experts_per_token is not None: print_opts += [f"experts_per_token: 
{args.experts_per_token}"] if args.load_q4: print_opts += ["load_q4"] if args.ignore_compatibility: print_opts += ["ignore_compatibility"] + if args.chunk_size is not None: print_opts += [f"chunk_size: {args.chunk_size}"] print(f" -- Options: {print_opts}") @@ -107,6 +110,10 @@ def init(args, if args.low_mem: config.set_low_mem() if args.load_q4: config.load_in_q4 = True + if args.chunk_size is not None: + config.max_input_len = args.chunk_size + config.max_attention_size = args.chunk_size ** 2 + # Compatibility warnings config.arch_compat_overrides(warn_only = args.ignore_compatibility) From b3e07eefe35197aa2e306fd2bdb26b426820cfc4 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Thu, 11 Jul 2024 15:11:30 +0200 Subject: [PATCH 40/40] Bump to 0.1.7 --- exllamav2/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exllamav2/version.py b/exllamav2/version.py index 32efefd0..283b03a0 100644 --- a/exllamav2/version.py +++ b/exllamav2/version.py @@ -1 +1 @@ -__version__ = "0.1.6" \ No newline at end of file +__version__ = "0.1.7" \ No newline at end of file