Support Gemma
turboderp committed Feb 22, 2024
1 parent a19a2ec commit cc1094a
Showing 15 changed files with 170 additions and 74 deletions.
11 changes: 5 additions & 6 deletions exllamav2/attn.py
@@ -267,7 +267,7 @@ def scratch_space(self):

def temp_state_size(self):

return self.model.config.max_input_len * self.model.config.max_batch_size * self.model.config.hidden_size * 2 + 128
return self.model.config.max_input_len * self.model.config.max_batch_size * self.model.config.num_attention_heads * self.model.config.head_dim * 2 + 128


def temp_q_size(self):
@@ -465,15 +465,15 @@ def forward(self, hidden_states, cache = None, attn_params = None, past_len = No
v_states = None

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape((batch_size, q_len, hidden_size))
attn_output = attn_output.reshape((batch_size, q_len, self.model.config.num_attention_heads * self.model.config.head_dim))

# Flash Attention 2

else:

# TODO: Enable flash-attn with input mask
attn_output = flash_attn_func(q_states, k_states, v_states, causal = True)
attn_output = attn_output.reshape((batch_size, q_len, hidden_size))
attn_output = attn_output.reshape((batch_size, q_len, self.model.config.num_attention_heads * self.model.config.head_dim))

# xformers memory_efficient_attention

@@ -661,17 +661,16 @@ def forward_torch(self, hidden_states, cache = None, attn_params = None, past_le
attn_output = torch.matmul(attn_weights, value_states)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape((batch_size, q_len, hidden_size))
attn_output = attn_output.reshape((batch_size, q_len, self.model.config.num_attention_heads * self.model.config.head_dim))

# Flash Attention 2

else:

attn_output = flash_attn_func(query_states, key_states, value_states, causal = True)
attn_output = attn_output.reshape((batch_size, q_len, hidden_size))
attn_output = attn_output.reshape((batch_size, q_len, self.model.config.num_attention_heads * self.model.config.head_dim))

# Update 8-bit cache
# TODO: Only update changed positions of the cache

if cache is not None:
cache.store_kv_state(self.layer_idx, batch_size, past_len, q_len)
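
The reshape and scratch-size changes above matter because Gemma's head_dim comes from config.json rather than being derived as hidden_size // num_attention_heads, so num_attention_heads * head_dim can differ from hidden_size. A minimal sketch of the output reshape, assuming Gemma-7B-style dimensions (3072 / 16 / 256 are illustrative values, not taken from this commit):

    import torch

    # Sketch: with an explicit head_dim, heads * head_dim need not equal hidden_size,
    # so the attention output is reshaped to heads * head_dim (4096 here), not hidden_size.
    hidden_size, num_heads, head_dim = 3072, 16, 256   # assumed Gemma-7B-style values
    batch_size, q_len = 2, 8

    attn_output = torch.zeros(batch_size, num_heads, q_len, head_dim)
    attn_output = attn_output.transpose(1, 2).reshape(batch_size, q_len, num_heads * head_dim)
    assert attn_output.shape == (batch_size, q_len, 4096) and hidden_size != 4096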
17 changes: 16 additions & 1 deletion exllamav2/config.py
@@ -97,6 +97,9 @@ def prepare(self, no_tensors = False):
expect_keys_llama = [["lm_head"],
["model.norm"],
["model.embed_tokens"]]
expect_keys_gemma = [["model.norm"],
["model.embed_tokens"]]


if "LlamaForCausalLM" in read_config["architectures"]:
self.architecture = "Llama"
@@ -157,6 +160,15 @@ def prepare(self, no_tensors = False):
self.attention_bias_qkv = True
self.attention_bias_o = False

elif "GemmaForCausalLM" in read_config["architectures"]:
self.architecture = "Gemma"
layer_keys += \
layer_keys_llama_norms + \
layer_keys_llama_attn + \
layer_keys_llama_mlp
expect_keys += \
expect_keys_gemma

else:
print(f" !! Warning, unknown architecture: {repr(read_config['architectures'])}")
print(f" !! Loading as LlamaForCausalLM")
@@ -206,7 +218,10 @@ def prepare(self, no_tensors = False):

# Model dimensions

self.head_dim = self.hidden_size // self.num_attention_heads
if "head_dim" in read_config:
self.head_dim = read_config["head_dim"]
else:
self.head_dim = self.hidden_size // self.num_attention_heads

# Create map of model tensors

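Gemma's config.json ships an explicit head_dim that is not hidden_size // num_attention_heads, which is why the loader now prefers the explicit field and only falls back to the Llama-style derivation. A self-contained sketch of that fallback (the sample configs are illustrative assumptions, not taken from this commit):

    # Sketch of the head_dim fallback added above; sample values are assumptions.
    def resolve_head_dim(read_config: dict) -> int:
        hidden_size = read_config["hidden_size"]
        num_heads = read_config["num_attention_heads"]
        # Prefer an explicit head_dim (present in Gemma configs); otherwise derive it Llama-style.
        return read_config.get("head_dim", hidden_size // num_heads)

    print(resolve_head_dim({"hidden_size": 4096, "num_attention_heads": 32}))                   # 128
    print(resolve_head_dim({"hidden_size": 3072, "num_attention_heads": 16, "head_dim": 256}))  # 256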
4 changes: 4 additions & 0 deletions exllamav2/embedding.py
@@ -68,6 +68,10 @@ def forward(self, hidden_states, cache = None, attn_params = None, past_len = No

hidden_states = self.embedding.forward(hidden_states)

# Normalize the input embeddings for Gemma
if self.model.config.architecture == "Gemma":
hidden_states = hidden_states * (self.model.config.hidden_size ** 0.5)

if intermediates:
return {"hidden_states": hidden_states}
else:
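Gemma multiplies the token embeddings by sqrt(hidden_size) before the first decoder layer, so the embedding module applies that scale when the architecture is Gemma. A small sketch of the effect, with assumed toy dimensions:

    import torch

    # Sketch of the Gemma embedding scaling; sizes here are toy values for illustration.
    hidden_size = 3072
    embedding = torch.nn.Embedding(16, hidden_size)
    input_ids = torch.tensor([[1, 2, 3]])

    hidden_states = embedding(input_ids)
    hidden_states = hidden_states * (hidden_size ** 0.5)   # scale factor ~55.4 for hidden_size 3072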
2 changes: 1 addition & 1 deletion exllamav2/exllamav2_ext/cuda/q_attn.cu
@@ -161,7 +161,7 @@ void QAttn::forward_cuda_2
half* lora_temp
)
{
gemm_half_q_half_cuda(cublas_handle, attn_output, o_proj, hidden_state, q_len * batch_size, o_proj->width, hidden_size, false, temp_dq);
gemm_half_q_half_cuda(cublas_handle, attn_output, o_proj, hidden_state, q_len * batch_size, o_proj->width, o_proj->height, false, temp_dq);

apply_loras_cuda(cublas_handle, o_proj_lora, loras, o_proj, attn_output, hidden_state, lora_temp, q_len * batch_size);
}
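
With Gemma, o_proj is no longer square: it maps the attention width (num_heads * head_dim) back down to hidden_size, so the GEMM's inner dimension has to be the projection's input size (o_proj->height) rather than hidden_size. A rough shape sketch under the same assumed Gemma-7B-style dimensions:

    import torch

    # Shape sketch only (assumed dims): o_proj maps heads * head_dim -> hidden_size,
    # so the matmul's inner dimension is the attention width, not hidden_size.
    hidden_size, num_heads, head_dim = 3072, 16, 256
    o_proj = torch.nn.Linear(num_heads * head_dim, hidden_size, bias=False)

    attn_output = torch.zeros(2, 8, num_heads * head_dim)   # (batch, q_len, heads * head_dim)
    hidden = o_proj(attn_output)                            # (batch, q_len, hidden_size)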
64 changes: 50 additions & 14 deletions exllamav2/exllamav2_ext/cuda/q_mlp.cu
@@ -44,6 +44,20 @@ __device__ __forceinline__ half2 silu(half2 x)
return result;
}

__device__ __forceinline__ half gelu(half x)
{
float xf = __half2float(x);
const float c = 0.797884560803f; // sqrt(2/Pi)
float tanh_arg = c * (xf + 0.044715f * pow(xf, 3));
xf = 0.5f * xf * (1.0 + tanh(tanh_arg));
return __float2half_rn(xf);
}

__device__ __forceinline__ half2 gelu(half2 x)
{
return __halves2half2(gelu(x.x), gelu(x.y));
}

typedef void (*fp_silu_mul_kernel)
(
half*,
@@ -54,7 +68,7 @@ typedef void (*fp_silu_mul_kernel)
const int
);

template <bool use_half2, bool use_r_weights>
template <bool use_half2, bool use_r_weights, bool act_fn_gelu>
__global__ void silu_mul_kernel
(
half* __restrict__ x,
@@ -90,7 +104,11 @@ __global__ void silu_mul_kernel
half2 x_item = x_.item_half2(row, column);
half2 y_item = y_.item_half2(row, column);

x_item = silu(x_item);
if constexpr (act_fn_gelu)
x_item = gelu(x_item);
else
x_item = silu(x_item);

x_item = __hmul2(x_item, y_item);

x_.set_half2(row, column, x_item);
@@ -100,19 +118,33 @@ __global__ void silu_mul_kernel
half x_item = x_.item(row, column);
half y_item = y_.item(row, column);

x_item = silu(x_item);
if constexpr (act_fn_gelu)
x_item = gelu(x_item);
else
x_item = silu(x_item);

x_item = __hmul(x_item, y_item);

x_.set(row, column, x_item);
}
}

fp_silu_mul_kernel pick_silu_mul_kernel(bool use_half2, bool mul_r_weights)
fp_silu_mul_kernel pick_silu_mul_kernel(bool use_half2, bool mul_r_weights, bool act_fn_gelu)
{
if ( use_half2 && !mul_r_weights) return silu_mul_kernel< true, false>;
if ( use_half2 && mul_r_weights) return silu_mul_kernel< true, true>;
if (!use_half2 && !mul_r_weights) return silu_mul_kernel<false, false>;
if (!use_half2 && mul_r_weights) return silu_mul_kernel<false, true>;
if (act_fn_gelu)
{
if ( use_half2 && !mul_r_weights) return silu_mul_kernel< true, false, true>;
if ( use_half2 && mul_r_weights) return silu_mul_kernel< true, true, true>;
if (!use_half2 && !mul_r_weights) return silu_mul_kernel<false, false, true>;
if (!use_half2 && mul_r_weights) return silu_mul_kernel<false, true, true>;
}
else
{
if ( use_half2 && !mul_r_weights) return silu_mul_kernel< true, false, false>;
if ( use_half2 && mul_r_weights) return silu_mul_kernel< true, true, false>;
if (!use_half2 && !mul_r_weights) return silu_mul_kernel<false, false, false>;
if (!use_half2 && mul_r_weights) return silu_mul_kernel<false, true, false>;
}
return NULL;
};

@@ -129,7 +161,8 @@ QMLP::QMLP
half* _temp_a,
half* _temp_b,
half* _temp_dq,
int _max_rows
int _max_rows,
bool _act_gelu
):
layernorm(_layernorm),
layernorm_bias(_layernorm_bias),
@@ -142,7 +175,8 @@ QMLP::QMLP
temp_a(_temp_a),
temp_b(_temp_b),
temp_dq(_temp_dq),
max_rows(_max_rows)
max_rows(_max_rows),
act_gelu(_act_gelu)
{
}

@@ -179,7 +213,7 @@ void QMLP::forward_
gridDim.x = DIVIDE(up->width, THREADS_X) / (use_half2 ? 2 : 1);
gridDim.y = DIVIDE(rows, THREADS_Y);

fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, false);
fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, false, act_gelu);
kernel<<<gridDim, blockDim>>>(temp_a, temp_b, rows, intermediate_size, NULL, 0);

gemm_half_q_half_cuda(cublas_handle, temp_a, down, x, rows, columns, intermediate_size, false, temp_dq);
@@ -207,7 +241,8 @@ QMoEMLP::QMoEMLP
half* _temp_logits,
half* _temp_dq,
int _max_rows,
int _hidden_dim
int _hidden_dim,
bool _act_gelu
):
layernorm(_layernorm),
layernorm_bias(_layernorm_bias),
@@ -226,7 +261,8 @@ QMoEMLP::QMoEMLP
temp_logits(_temp_logits),
temp_dq(_temp_dq),
max_rows(_max_rows),
hidden_dim(_hidden_dim)
hidden_dim(_hidden_dim),
act_gelu(_act_gelu)
{
// for (int i = 0; i < num_experts; ++i)
// {
@@ -299,7 +335,7 @@ void QMoEMLP::forward_
if (rows <= MAX_Q_GEMM_WEIGHTS)
{
int intermediate_size = w1[0]->width;
fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, true);
fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, true, act_gelu);

for (int i = 0; i < num_experts; i++)
{
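The gelu device functions added above implement the tanh approximation gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), which the kernel now selects instead of SiLU when act_fn_gelu is set. A reference sketch in Python, checked against PyTorch's tanh-approximate GELU (this assumes a PyTorch version that accepts approximate="tanh"):

    import math
    import torch

    # Reference sketch of the tanh GELU approximation used by the CUDA kernel above.
    def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
        c = math.sqrt(2.0 / math.pi)                        # 0.7978845608...
        return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * x ** 3)))

    x = torch.randn(64)
    ref = torch.nn.functional.gelu(x, approximate="tanh")   # available in newer PyTorch releases
    assert torch.allclose(gelu_tanh(x), ref, atol=1e-6)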
10 changes: 8 additions & 2 deletions exllamav2/exllamav2_ext/cuda/q_mlp.cuh
@@ -34,6 +34,8 @@ public:
std::unordered_map<uintptr_t, std::tuple<half*, half*, int>> up_proj_lora;
std::unordered_map<uintptr_t, std::tuple<half*, half*, int>> down_proj_lora;

bool act_gelu;

QMLP
(
half* _layernorm,
@@ -47,7 +49,8 @@ public:
half* _temp_a,
half* _temp_b,
half* _temp_dq,
int _max_rows
int _max_rows,
bool _act_gelu
);

~QMLP();
@@ -94,6 +97,8 @@ public:
int max_rows;
int hidden_dim;

bool act_gelu;

// std::vector<std::unordered_map<uintptr_t, std::tuple<half*, half*, int>>> w1_lora;
// std::vector<std::unordered_map<uintptr_t, std::tuple<half*, half*, int>>> w2_lora;
// std::vector<std::unordered_map<uintptr_t, std::tuple<half*, half*, int>>> w3_lora;
@@ -117,7 +122,8 @@ public:
half* _temp_logits,
half* _temp_dq,
int _max_rows,
int _hidden_dim
int _hidden_dim,
bool _act_gelu
);

~QMoEMLP();
2 changes: 1 addition & 1 deletion exllamav2/exllamav2_ext/ext_qattn.cpp
@@ -48,7 +48,7 @@ uintptr_t make_q_attn
if (qm_q_proj && !layernorm.is_meta()) TORCH_CHECK(qm_q_proj->height == layernorm.size(0), "q_proj is wrong shape")
if (qm_k_proj && !layernorm.is_meta()) TORCH_CHECK(qm_k_proj->height == layernorm.size(0), "k_proj is wrong shape")
if (qm_v_proj && !layernorm.is_meta()) TORCH_CHECK(qm_v_proj->height == layernorm.size(0), "v_proj is wrong shape")
if (!layernorm.is_meta()) TORCH_CHECK(qm_o_proj->height == layernorm.size(0), "o_proj is wrong shape")
if (!layernorm.is_meta()) TORCH_CHECK(qm_o_proj->width == layernorm.size(0), "o_proj is wrong shape")

QAttn* attn = new QAttn
(
12 changes: 8 additions & 4 deletions exllamav2/exllamav2_ext/ext_qmlp.cpp
@@ -28,7 +28,8 @@ uintptr_t make_q_mlp
torch::Tensor temp_a,
torch::Tensor temp_b,
torch::Tensor temp_dq,
int max_rows
int max_rows,
bool act_gelu
)
{
QMatrix* qm_gate = reinterpret_cast<QMatrix*> (q_gate);
@@ -52,7 +53,8 @@ uintptr_t make_q_mlp
(half*) temp_a.data_ptr(),
(half*) temp_b.data_ptr(),
(half*) temp_dq.data_ptr(),
max_rows
max_rows,
act_gelu
);

return reinterpret_cast<uintptr_t> (mlp);
@@ -163,7 +165,8 @@ uintptr_t make_q_moe_mlp
torch::Tensor temp_b,
torch::Tensor temp_logits,
torch::Tensor temp_dq,
int max_rows
int max_rows,
bool act_gelu
)
{
std::vector<QMatrix*> qm_w1;
@@ -202,7 +205,8 @@ uintptr_t make_q_moe_mlp
(half*) temp_logits.data_ptr(),
(half*) temp_dq.data_ptr(),
max_rows,
hidden_dim
hidden_dim,
act_gelu
);

return reinterpret_cast<uintptr_t> (moe_mlp);
6 changes: 4 additions & 2 deletions exllamav2/exllamav2_ext/ext_qmlp.h
@@ -12,7 +12,8 @@ uintptr_t make_q_mlp
torch::Tensor temp_a,
torch::Tensor temp_b,
torch::Tensor temp_dq,
int max_rows
int max_rows,
bool act_gelu
);

void free_q_mlp
Expand Down Expand Up @@ -57,7 +58,8 @@ uintptr_t make_q_moe_mlp
torch::Tensor temp_b,
torch::Tensor temp_logits,
torch::Tensor temp_dq,
int max_rows
int max_rows,
bool act_gelu
);

void free_q_moe_mlp
5 changes: 3 additions & 2 deletions exllamav2/mlp.py
@@ -91,7 +91,8 @@ def load(self):
device_tensors.get_scratch_slice(self.temp_a_size()),
device_tensors.get_scratch_slice(self.temp_b_size()),
device_tensors.get_scratch_slice(self.temp_dq_size()),
self.model.config.max_input_len * self.model.config.max_batch_size)
self.model.config.max_input_len * self.model.config.max_batch_size,
self.model.config.architecture == "Gemma")


def unload(self):
@@ -195,7 +196,7 @@ def forward_torch(self, hidden_states, cache = None, attn_params = None, interme
post_norm = self.post_attention_layernorm.forward(hidden_states)

gate = self.gate_proj.forward(post_norm, loras = loras)
y = F.silu(gate)
y = F.gelu(gate) if self.model.config.architecture == "Gemma" else F.silu(gate)
up = self.up_proj.forward(post_norm, loras = loras)
y *= up
y.clamp_(min = -65504.0, max = 65504.0)
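The Torch fallback path now picks the activation per architecture: GELU for Gemma, SiLU for everything else, around the usual gate/up/down structure (the quantized CUDA path gets the same switch via the act_gelu flag passed in load() above). A condensed sketch of that forward pass; the three Linear layers are stand-ins, not the quantized ExLlamaV2 modules:

    import torch
    import torch.nn.functional as F

    # Condensed sketch of the gated MLP with the per-architecture activation switch.
    # gate_proj / up_proj / down_proj are stand-in Linear layers for illustration.
    def mlp_forward(x, gate_proj, up_proj, down_proj, architecture):
        gate = gate_proj(x)
        y = F.gelu(gate) if architecture == "Gemma" else F.silu(gate)
        y = y * up_proj(x)
        y = y.clamp(min=-65504.0, max=65504.0)   # keep intermediates inside fp16 range
        return down_proj(y)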
5 changes: 4 additions & 1 deletion exllamav2/model.py
@@ -164,10 +164,13 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):
self.modules_dict[self.modules[-1].key] = self.modules[-1]

self.head_layer_idx = len(self.modules)

self.modules.append(ExLlamaV2Linear(self, "lm_head", self.config.hidden_size, self.config.vocab_size, False))
self.modules_dict[self.modules[-1].key] = self.modules[-1]
if self.config.architecture == "Gemma":
self.modules[-1].alt_key = "model.embed_tokens"

# Find last layer that affects k/v cache

layer_idx = len(self.modules)
while True:
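Gemma checkpoints do not ship a separate lm_head tensor (which is also why expect_keys_gemma in config.py omits it); the output head reuses the embedding weights, and the alt_key fallback lets the lm_head module load from model.embed_tokens. A minimal weight-tying sketch with small stand-in sizes (real Gemma dimensions are much larger):

    import torch

    # Minimal weight-tying sketch: the output head shares the (vocab, hidden) embedding matrix,
    # so "lm_head" can be sourced from the "model.embed_tokens" tensor when no lm_head key exists.
    vocab_size, hidden_size = 1000, 64                     # stand-in sizes for illustration
    embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)
    lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
    lm_head.weight = embed_tokens.weight                   # both point at the same tensor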