diff --git a/exllamav2/attn.py b/exllamav2/attn.py
index fe7da4a8..f68faf7a 100644
--- a/exllamav2/attn.py
+++ b/exllamav2/attn.py
@@ -267,7 +267,7 @@ def scratch_space(self):

     def temp_state_size(self):

-        return self.model.config.max_input_len * self.model.config.max_batch_size * self.model.config.hidden_size * 2 + 128
+        return self.model.config.max_input_len * self.model.config.max_batch_size * self.model.config.num_attention_heads * self.model.config.head_dim * 2 + 128


     def temp_q_size(self):

@@ -465,7 +465,7 @@ def forward(self, hidden_states, cache = None, attn_params = None, past_len = No
             v_states = None

             attn_output = attn_output.transpose(1, 2)
-            attn_output = attn_output.reshape((batch_size, q_len, hidden_size))
+            attn_output = attn_output.reshape((batch_size, q_len, self.model.config.num_attention_heads * self.model.config.head_dim))

         # Flash Attention 2

@@ -473,7 +473,7 @@ def forward(self, hidden_states, cache = None, attn_params = None, past_len = No
             # TODO: Enable flash-attn with input mask

             attn_output = flash_attn_func(q_states, k_states, v_states, causal = True)
-            attn_output = attn_output.reshape((batch_size, q_len, hidden_size))
+            attn_output = attn_output.reshape((batch_size, q_len, self.model.config.num_attention_heads * self.model.config.head_dim))

         # xformers memory_efficient_attention

@@ -661,17 +661,16 @@ def forward_torch(self, hidden_states, cache = None, attn_params = None, past_le
             attn_output = torch.matmul(attn_weights, value_states)

             attn_output = attn_output.transpose(1, 2)
-            attn_output = attn_output.reshape((batch_size, q_len, hidden_size))
+            attn_output = attn_output.reshape((batch_size, q_len, self.model.config.num_attention_heads * self.model.config.head_dim))

         # Flash Attention 2

         else:

             attn_output = flash_attn_func(query_states, key_states, value_states, causal = True)
-            attn_output = attn_output.reshape((batch_size, q_len, hidden_size))
+            attn_output = attn_output.reshape((batch_size, q_len, self.model.config.num_attention_heads * self.model.config.head_dim))

         # Update 8-bit cache
-        # TODO: Only update changed positions of the cache

         if cache is not None:
             cache.store_kv_state(self.layer_idx, batch_size, past_len, q_len)

diff --git a/exllamav2/config.py b/exllamav2/config.py
index caddf04c..07df7e10 100644
--- a/exllamav2/config.py
+++ b/exllamav2/config.py
@@ -97,6 +97,9 @@ def prepare(self, no_tensors = False):
         expect_keys_llama = [["lm_head"],
                              ["model.norm"],
                              ["model.embed_tokens"]]

+        expect_keys_gemma = [["model.norm"],
+                             ["model.embed_tokens"]]
+
         if "LlamaForCausalLM" in read_config["architectures"]:

             self.architecture = "Llama"
@@ -157,6 +160,15 @@ def prepare(self, no_tensors = False):
             self.attention_bias_qkv = True
             self.attention_bias_o = False

+        elif "GemmaForCausalLM" in read_config["architectures"]:
+
+            self.architecture = "Gemma"
+            layer_keys += \
+                layer_keys_llama_norms + \
+                layer_keys_llama_attn + \
+                layer_keys_llama_mlp
+            expect_keys += \
+                expect_keys_gemma
+
         else:

             print(f" !! Warning, unknown architecture: {repr(read_config['architectures'])}")
             print(f" !! Loading as LlamaForCausalLM")
Loading as LlamaForCausalLM") @@ -206,7 +218,10 @@ def prepare(self, no_tensors = False): # Model dimensions - self.head_dim = self.hidden_size // self.num_attention_heads + if "head_dim" in read_config: + self.head_dim = read_config["head_dim"] + else: + self.head_dim = self.hidden_size // self.num_attention_heads # Create map of model tensors diff --git a/exllamav2/embedding.py b/exllamav2/embedding.py index 8b5add89..0ec2433c 100644 --- a/exllamav2/embedding.py +++ b/exllamav2/embedding.py @@ -68,6 +68,10 @@ def forward(self, hidden_states, cache = None, attn_params = None, past_len = No hidden_states = self.embedding.forward(hidden_states) + # Normalize the input embeddings for Gemma + if self.model.config.architecture == "Gemma": + hidden_states = hidden_states * (self.model.config.hidden_size ** 0.5) + if intermediates: return {"hidden_states": hidden_states} else: diff --git a/exllamav2/exllamav2_ext/cuda/q_attn.cu b/exllamav2/exllamav2_ext/cuda/q_attn.cu index 158fffd4..ce576394 100644 --- a/exllamav2/exllamav2_ext/cuda/q_attn.cu +++ b/exllamav2/exllamav2_ext/cuda/q_attn.cu @@ -161,7 +161,7 @@ void QAttn::forward_cuda_2 half* lora_temp ) { - gemm_half_q_half_cuda(cublas_handle, attn_output, o_proj, hidden_state, q_len * batch_size, o_proj->width, hidden_size, false, temp_dq); + gemm_half_q_half_cuda(cublas_handle, attn_output, o_proj, hidden_state, q_len * batch_size, o_proj->width, o_proj->height, false, temp_dq); apply_loras_cuda(cublas_handle, o_proj_lora, loras, o_proj, attn_output, hidden_state, lora_temp, q_len * batch_size); } diff --git a/exllamav2/exllamav2_ext/cuda/q_mlp.cu b/exllamav2/exllamav2_ext/cuda/q_mlp.cu index 1de12a9a..56950bb9 100644 --- a/exllamav2/exllamav2_ext/cuda/q_mlp.cu +++ b/exllamav2/exllamav2_ext/cuda/q_mlp.cu @@ -44,6 +44,20 @@ __device__ __forceinline__ half2 silu(half2 x) return result; } +__device__ __forceinline__ half gelu(half x) +{ + float xf = __half2float(x); + const float c = 0.797884560803f; // sqrt(2/Pi) + float tanh_arg = c * (xf + 0.044715f * pow(xf, 3)); + xf = 0.5f * xf * (1.0 + tanh(tanh_arg)); + return __float2half_rn(xf); +} + +__device__ __forceinline__ half2 gelu(half2 x) +{ + return __halves2half2(gelu(x.x), gelu(x.y)); +} + typedef void (*fp_silu_mul_kernel) ( half*, @@ -54,7 +68,7 @@ typedef void (*fp_silu_mul_kernel) const int ); -template +template __global__ void silu_mul_kernel ( half* __restrict__ x, @@ -90,7 +104,11 @@ __global__ void silu_mul_kernel half2 x_item = x_.item_half2(row, column); half2 y_item = y_.item_half2(row, column); - x_item = silu(x_item); + if constexpr (act_fn_gelu) + x_item = gelu(x_item); + else + x_item = silu(x_item); + x_item = __hmul2(x_item, y_item); x_.set_half2(row, column, x_item); @@ -100,19 +118,33 @@ __global__ void silu_mul_kernel half x_item = x_.item(row, column); half y_item = y_.item(row, column); - x_item = silu(x_item); + if constexpr (act_fn_gelu) + x_item = gelu(x_item); + else + x_item = silu(x_item); + x_item = __hmul(x_item, y_item); x_.set(row, column, x_item); } } -fp_silu_mul_kernel pick_silu_mul_kernel(bool use_half2, bool mul_r_weights) +fp_silu_mul_kernel pick_silu_mul_kernel(bool use_half2, bool mul_r_weights, bool act_fn_gelu) { - if ( use_half2 && !mul_r_weights) return silu_mul_kernel< true, false>; - if ( use_half2 && mul_r_weights) return silu_mul_kernel< true, true>; - if (!use_half2 && !mul_r_weights) return silu_mul_kernel; - if (!use_half2 && mul_r_weights) return silu_mul_kernel; + if (act_fn_gelu) + { + if ( use_half2 && !mul_r_weights) return 
+        if ( use_half2 && !mul_r_weights) return silu_mul_kernel< true, false, true>;
+        if ( use_half2 && mul_r_weights) return silu_mul_kernel< true, true, true>;
+        if (!use_half2 && !mul_r_weights) return silu_mul_kernel<false, false, true>;
+        if (!use_half2 && mul_r_weights) return silu_mul_kernel<false, true, true>;
+    }
+    else
+    {
+        if ( use_half2 && !mul_r_weights) return silu_mul_kernel< true, false, false>;
+        if ( use_half2 && mul_r_weights) return silu_mul_kernel< true, true, false>;
+        if (!use_half2 && !mul_r_weights) return silu_mul_kernel<false, false, false>;
+        if (!use_half2 && mul_r_weights) return silu_mul_kernel<false, true, false>;
+    }

     return NULL;
 };

@@ -129,7 +161,8 @@ QMLP::QMLP
     half* _temp_a,
     half* _temp_b,
     half* _temp_dq,
-    int _max_rows
+    int _max_rows,
+    bool _act_gelu
 ):
     layernorm(_layernorm),
     layernorm_bias(_layernorm_bias),
@@ -142,7 +175,8 @@ QMLP::QMLP
     temp_a(_temp_a),
     temp_b(_temp_b),
     temp_dq(_temp_dq),
-    max_rows(_max_rows)
+    max_rows(_max_rows),
+    act_gelu(_act_gelu)
 {
 }

@@ -179,7 +213,7 @@ void QMLP::forward_
     gridDim.x = DIVIDE(up->width, THREADS_X) / (use_half2 ? 2 : 1);
     gridDim.y = DIVIDE(rows, THREADS_Y);

-    fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, false);
+    fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, false, act_gelu);
     kernel<<<gridDim, blockDim>>>(temp_a, temp_b, rows, intermediate_size, NULL, 0);

     gemm_half_q_half_cuda(cublas_handle, temp_a, down, x, rows, columns, intermediate_size, false, temp_dq);
@@ -207,7 +241,8 @@ QMoEMLP::QMoEMLP
     half* _temp_logits,
     half* _temp_dq,
     int _max_rows,
-    int _hidden_dim
+    int _hidden_dim,
+    bool _act_gelu
 ):
     layernorm(_layernorm),
     layernorm_bias(_layernorm_bias),
@@ -226,7 +261,8 @@ QMoEMLP::QMoEMLP
     temp_logits(_temp_logits),
     temp_dq(_temp_dq),
     max_rows(_max_rows),
-    hidden_dim(_hidden_dim)
+    hidden_dim(_hidden_dim),
+    act_gelu(_act_gelu)
 {
//     for (int i = 0; i < num_experts; ++i)
//     {
@@ -299,7 +335,7 @@ void QMoEMLP::forward_
     if (rows <= MAX_Q_GEMM_WEIGHTS)
     {
         int intermediate_size = w1[0]->width;
-        fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, true);
+        fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, true, act_gelu);

         for (int i = 0; i < num_experts; i++)
         {
diff --git a/exllamav2/exllamav2_ext/cuda/q_mlp.cuh b/exllamav2/exllamav2_ext/cuda/q_mlp.cuh
index e23d1d56..52249e0c 100644
--- a/exllamav2/exllamav2_ext/cuda/q_mlp.cuh
+++ b/exllamav2/exllamav2_ext/cuda/q_mlp.cuh
@@ -34,6 +34,8 @@ public:
     std::unordered_map<uintptr_t, std::tuple<half*, half*, int>> up_proj_lora;
     std::unordered_map<uintptr_t, std::tuple<half*, half*, int>> down_proj_lora;

+    bool act_gelu;
+
     QMLP
     (
         half* _layernorm,
@@ -47,7 +49,8 @@ public:
         half* _temp_a,
         half* _temp_b,
         half* _temp_dq,
-        int _max_rows
+        int _max_rows,
+        bool _act_gelu
     );

     ~QMLP();
@@ -94,6 +97,8 @@ public:
     int max_rows;
     int hidden_dim;

+    bool act_gelu;
+
//     std::vector<std::unordered_map<uintptr_t, std::tuple<half*, half*, int>>> w1_lora;
//     std::vector<std::unordered_map<uintptr_t, std::tuple<half*, half*, int>>> w2_lora;
//     std::vector<std::unordered_map<uintptr_t, std::tuple<half*, half*, int>>> w3_lora;
@@ -117,7 +122,8 @@ public:
         half* _temp_logits,
         half* _temp_dq,
         int _max_rows,
-        int _hidden_dim
+        int _hidden_dim,
+        bool _act_gelu
     );

     ~QMoEMLP();
diff --git a/exllamav2/exllamav2_ext/ext_qattn.cpp b/exllamav2/exllamav2_ext/ext_qattn.cpp
index 04b8a094..5945c3aa 100644
--- a/exllamav2/exllamav2_ext/ext_qattn.cpp
+++ b/exllamav2/exllamav2_ext/ext_qattn.cpp
@@ -48,7 +48,7 @@ uintptr_t make_q_attn
     if (qm_q_proj && !layernorm.is_meta()) TORCH_CHECK(qm_q_proj->height == layernorm.size(0), "q_proj is wrong shape")
     if (qm_k_proj && !layernorm.is_meta()) TORCH_CHECK(qm_k_proj->height == layernorm.size(0), "k_proj is wrong shape")
     if (qm_v_proj && !layernorm.is_meta()) TORCH_CHECK(qm_v_proj->height == layernorm.size(0), "v_proj is wrong shape")
is wrong shape") + if (!layernorm.is_meta()) TORCH_CHECK(qm_o_proj->width == layernorm.size(0), "o_proj is wrong shape") QAttn* attn = new QAttn ( diff --git a/exllamav2/exllamav2_ext/ext_qmlp.cpp b/exllamav2/exllamav2_ext/ext_qmlp.cpp index f91d2455..5125a21c 100644 --- a/exllamav2/exllamav2_ext/ext_qmlp.cpp +++ b/exllamav2/exllamav2_ext/ext_qmlp.cpp @@ -28,7 +28,8 @@ uintptr_t make_q_mlp torch::Tensor temp_a, torch::Tensor temp_b, torch::Tensor temp_dq, - int max_rows + int max_rows, + bool act_gelu ) { QMatrix* qm_gate = reinterpret_cast (q_gate); @@ -52,7 +53,8 @@ uintptr_t make_q_mlp (half*) temp_a.data_ptr(), (half*) temp_b.data_ptr(), (half*) temp_dq.data_ptr(), - max_rows + max_rows, + act_gelu ); return reinterpret_cast (mlp); @@ -163,7 +165,8 @@ uintptr_t make_q_moe_mlp torch::Tensor temp_b, torch::Tensor temp_logits, torch::Tensor temp_dq, - int max_rows + int max_rows, + bool act_gelu ) { std::vector qm_w1; @@ -202,7 +205,8 @@ uintptr_t make_q_moe_mlp (half*) temp_logits.data_ptr(), (half*) temp_dq.data_ptr(), max_rows, - hidden_dim + hidden_dim, + act_gelu ); return reinterpret_cast (moe_mlp); diff --git a/exllamav2/exllamav2_ext/ext_qmlp.h b/exllamav2/exllamav2_ext/ext_qmlp.h index 5119be47..6b8cac4b 100644 --- a/exllamav2/exllamav2_ext/ext_qmlp.h +++ b/exllamav2/exllamav2_ext/ext_qmlp.h @@ -12,7 +12,8 @@ uintptr_t make_q_mlp torch::Tensor temp_a, torch::Tensor temp_b, torch::Tensor temp_dq, - int max_rows + int max_rows, + bool act_gelu ); void free_q_mlp @@ -57,7 +58,8 @@ uintptr_t make_q_moe_mlp torch::Tensor temp_b, torch::Tensor temp_logits, torch::Tensor temp_dq, - int max_rows + int max_rows, + bool act_gelu ); void free_q_moe_mlp diff --git a/exllamav2/mlp.py b/exllamav2/mlp.py index 451bcd9c..0c4ee859 100644 --- a/exllamav2/mlp.py +++ b/exllamav2/mlp.py @@ -91,7 +91,8 @@ def load(self): device_tensors.get_scratch_slice(self.temp_a_size()), device_tensors.get_scratch_slice(self.temp_b_size()), device_tensors.get_scratch_slice(self.temp_dq_size()), - self.model.config.max_input_len * self.model.config.max_batch_size) + self.model.config.max_input_len * self.model.config.max_batch_size, + self.model.config.architecture == "Gemma") def unload(self): @@ -195,7 +196,7 @@ def forward_torch(self, hidden_states, cache = None, attn_params = None, interme post_norm = self.post_attention_layernorm.forward(hidden_states) gate = self.gate_proj.forward(post_norm, loras = loras) - y = F.silu(gate) + y = F.gelu(gate) if self.model.config.architecture == "Gemma" else F.silu(gate) up = self.up_proj.forward(post_norm, loras = loras) y *= up y.clamp_(min = -65504.0, max = 65504.0) diff --git a/exllamav2/model.py b/exllamav2/model.py index b608628d..b3ee604d 100644 --- a/exllamav2/model.py +++ b/exllamav2/model.py @@ -164,10 +164,13 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False): self.modules_dict[self.modules[-1].key] = self.modules[-1] self.head_layer_idx = len(self.modules) + self.modules.append(ExLlamaV2Linear(self, "lm_head", self.config.hidden_size, self.config.vocab_size, False)) self.modules_dict[self.modules[-1].key] = self.modules[-1] + if self.config.architecture == "Gemma": + self.modules[-1].alt_key = "model.embed_tokens" - # Find last layer that affects k/v cache + # Find last layer that affects k/v cache layer_idx = len(self.modules) while True: diff --git a/exllamav2/module.py b/exllamav2/module.py index 748c65c4..dd38a871 100644 --- a/exllamav2/module.py +++ b/exllamav2/module.py @@ -29,6 +29,7 @@ class ExLlamaV2Module: model = None config: 
     config: ExLlamaV2Config
     key: str
+    alt_key: str = None
     device_idx: int
     footprint: int

@@ -49,14 +50,14 @@ def device(self):
         return _torch_device(self.device_idx)


-    def load_multi(self, keys, override_key = None, measure = False):
+    def load_multi(self, key, keys, override_key = None, measure = False):

         tensors = {}
         submap = {}
         submap_i = {}
         size = 0

-        key = self.key if override_key is None else override_key
+        # key = self.key if override_key is None else override_key

         for k in keys:
             ck = key + "." + k
@@ -88,38 +89,45 @@ def load_multi(self, keys, override_key = None, measure = False):

     def load_weight(self, override_key = None):

-        key = self.key if override_key is None else override_key
+        if override_key is not None:
+            keys = [override_key]
+        else:
+            keys = [self.key]
+            if self.alt_key is not None:
+                keys += [self.alt_key]

-        # EXL2
+        for key in keys:

-        if key + ".q_weight" in self.model.config.tensor_file_map:
-            qtensors = self.load_multi(["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm", "bias"], override_key = override_key)
-            qtensors["q_perm"] = torch.argsort(qtensors["q_invperm"]).to(torch.int)
-            return qtensors
+            # EXL2
+
+            if key + ".q_weight" in self.model.config.tensor_file_map:
+                qtensors = self.load_multi(key, ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm", "bias"], override_key = override_key)
+                qtensors["q_perm"] = torch.argsort(qtensors["q_invperm"]).to(torch.int)
+                return qtensors

-        # GPTQ
+            # GPTQ

-        if key + ".qweight" in self.model.config.tensor_file_map:
-            qtensors = self.load_multi(["qweight", "qzeros", "scales", "g_idx", "bias"], override_key = override_key)
-            if "bias" in qtensors and torch.all(qtensors["bias"].eq(0)):
-                del qtensors["bias"]
-            qtensors["scales"] = qtensors["scales"].half()
-            return qtensors
+            if key + ".qweight" in self.model.config.tensor_file_map:
+                qtensors = self.load_multi(key, ["qweight", "qzeros", "scales", "g_idx", "bias"], override_key = override_key)
+                if "bias" in qtensors and torch.all(qtensors["bias"].eq(0)):
+                    del qtensors["bias"]
+                qtensors["scales"] = qtensors["scales"].half()
+                return qtensors

-        # Torch
+            # Torch

-        if key + ".weight" in self.model.config.tensor_file_map:
-            if key + ".bias" in self.model.config.tensor_file_map:
-                tensors = self.load_multi(["weight", "bias"], override_key = override_key)
-                tensor = tensors["weight"].half()
-                bias = tensors["bias"].half()
-                return nn.Parameter(tensor), nn.Parameter(bias)
-            else:
-                tensors = self.load_multi(["weight"], override_key = override_key)
-                tensor = tensors["weight"].half()
-                return nn.Parameter(tensor)
+            if key + ".weight" in self.model.config.tensor_file_map:
+                if key + ".bias" in self.model.config.tensor_file_map:
+                    tensors = self.load_multi(key, ["weight", "bias"], override_key = override_key)
+                    tensor = tensors["weight"].half()
+                    bias = tensors["bias"].half()
+                    return nn.Parameter(tensor), nn.Parameter(bias)
+                else:
+                    tensors = self.load_multi(key, ["weight"], override_key = override_key)
+                    tensor = tensors["weight"].half()
+                    return nn.Parameter(tensor)

-        # No weights found for key
+        # No weights found for key

         return None
@@ -128,24 +136,33 @@ def weight_footprint(self):

         if self.footprint == -1:

-            # EXL2
+            keys = [self.key]
+            if self.alt_key is not None:
+                keys += [self.alt_key]

-            if self.key + ".q_weight" in self.model.config.tensor_file_map:
-                self.footprint = self.load_multi(["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm", "q_perm"], measure = True)
+            for key in keys:

-            # GPTQ
+                # EXL2

".qweight" in self.model.config.tensor_file_map: - self.footprint = self.load_multi(["qweight", "qzeros", "scales", "g_idx"], measure = True) + if key + ".q_weight" in self.model.config.tensor_file_map: + self.footprint = self.load_multi(key, ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm", "q_perm"], measure = True) - # Torch + # GPTQ + + elif key + ".qweight" in self.model.config.tensor_file_map: + self.footprint = self.load_multi(key, ["qweight", "qzeros", "scales", "g_idx"], measure = True) + + # Torch + + elif key + ".weight" in self.model.config.tensor_file_map: + self.footprint = self.load_multi(key, ["weight"], measure = True) - elif self.key + ".weight" in self.model.config.tensor_file_map: - self.footprint = self.load_multi(["weight"], measure = True) + if self.footprint != -1: break # Error - else: raise ValueError("Unknown tensor type: " + self.key) + if self.footprint == -1: + raise ValueError("Unknown tensor type: " + self.key) return self.footprint diff --git a/exllamav2/moe_mlp.py b/exllamav2/moe_mlp.py index 7860b2ff..007edd21 100644 --- a/exllamav2/moe_mlp.py +++ b/exllamav2/moe_mlp.py @@ -84,7 +84,8 @@ def load(self): device_tensors.get_scratch_slice(self.temp_b_size()), device_tensors.get_scratch_slice(self.temp_logit_size()), device_tensors.get_scratch_slice(self.temp_dq_size()), - self.model.config.max_input_len * self.model.config.max_batch_size) + self.model.config.max_input_len * self.model.config.max_batch_size, + self.model.config.architecture == "Gemma") def unload(self): @@ -237,7 +238,7 @@ def forward_torch(self, hidden_states, cache = None, attn_params = None, interme gate = self.w1[expert_idx].forward(current_state, loras = loras) up = self.w3[expert_idx].forward(current_state, loras = loras) - current_hidden_states = F.silu(gate) * up + current_hidden_states = (F.gelu(gate) if self.model.config.architecture == "Gemma" else F.silu(gate)) * up if intermediates: result[f"pre_down.{expert_idx}"] = current_hidden_states current_hidden_states = self.w2[expert_idx].forward(current_hidden_states, loras = loras) diff --git a/exllamav2/rmsnorm.py b/exllamav2/rmsnorm.py index 3144bb06..89a11a8b 100644 --- a/exllamav2/rmsnorm.py +++ b/exllamav2/rmsnorm.py @@ -36,6 +36,10 @@ def load(self): self.variance_epsilon = self.model.config.rms_norm_eps + # Gemma adds 1 to the norm tensor for some reason + if self.model.config.architecture == "Gemma": + self.weight += 1 + def unload(self): @@ -50,6 +54,10 @@ def unload(self): def get_weight(self): + # Make sure to return the original weight tensor for Gemma + if self.model.config.architecture == "Gemma": + return self.weight.data - 1 + return self.weight.data diff --git a/exllamav2/tokenizers/spm.py b/exllamav2/tokenizers/spm.py index 6a34aa01..f0e0a7e7 100644 --- a/exllamav2/tokenizers/spm.py +++ b/exllamav2/tokenizers/spm.py @@ -28,7 +28,7 @@ def enumerate_tokens(self): for i in range(self.vocab_size()): p = self.spm.id_to_piece(i) d = self.spm.decode(i) - if p.startswith(self.space_char()): d = " " + d + if p.startswith(self.space_char()) and not d.startswith(" "): d = " " + d self.vocab.append(d) return enumerate(self.vocab)