Apply flash attention on vision encoder #339

Merged · 2 commits · Jul 31, 2024
README.md (6 changes: 2 additions & 4 deletions)
@@ -196,10 +196,8 @@ python3 chatglm_cpp/convert.py -i THUDM/glm-4-9b-chat -t q4_0 -o models/chatglm4
You may use `-vt <vision_type>` to set the quantization type for the vision encoder. It is recommended to run GLM4V on GPU, since vision encoding is too slow on CPU even with 4-bit quantization.
```sh
 python3 chatglm_cpp/convert.py -i THUDM/glm-4v-9b -t q4_0 -vt q4_0 -o models/chatglm4v-ggml.bin
-./build/bin/main -m models/chatglm4v-ggml.bin --image examples/03-Confusing-Pictures.jpg -p "这张图片有什么不寻常之处" --temp 0
-# What is unusual about this picture is that a man is standing on the tailgate of a yellow SUV, ironing clothes on an ironing board.
-# Normally, ironing is done indoors with a household iron, not outdoors with a car tailgate as the workbench.
-# Moreover, he appears to be on a busy city street, surrounded by moving vehicles and buildings, which adds to the absurdity of the scene.
+./build/bin/main -m models/chatglm4v-ggml.bin --image examples/03-Confusing-Pictures.jpg -p "这张图片有什么不寻常的地方" --temp 0
+# The unusual thing about this picture is that a man is ironing clothes on the back of a yellow taxi. Normally, ironing is done at home or at a laundromat, not on a vehicle. Moreover, the taxi is in motion while the man irons steadily, which adds to the absurdity of the scene.
```

</details>
chatglm.cpp (177 changes: 118 additions & 59 deletions)
Expand Up @@ -538,11 +538,9 @@ static ggml_tensor *apply_rotary_emb_basic(ModelContext *mctx, ggml_tensor *laye
// tensor a (activation) is of shape [s, #h, d]
// tensor b (position_ids) is of shape [s]
ggml_context *ctx = mctx->ctx_b.get();
#ifdef GGML_USE_CUDA
if (!ggml_is_contiguous(layer)) {
if (ggml_cpu_has_cuda() && !ggml_is_contiguous(layer)) {
layer = ggml_cont(ctx, layer);
}
#endif
const int head_size = layer->ne[0];
layer = ggml_rope_ext_inplace(ctx, layer, position_ids, nullptr, head_size, (int)rope_type, 0, rope_theta, 1.0f,
0.0f, 1.0f, 0.0f, 0.0f); // [s, #h, d]
@@ -568,18 +566,20 @@ static ggml_tensor *apply_rotary_emb_glm(ModelContext *mctx, ggml_tensor *layer,
 
     ggml_tensor *a1_rope = a1;
     ggml_tensor *a2_rope = a2;
-#ifdef GGML_USE_CUDA
-    a1_rope = ggml_cont(ctx, a1_rope);
-    a2_rope = ggml_cont(ctx, a2_rope);
-#endif
+
+    if (ggml_cpu_has_cuda()) {
+        a1_rope = ggml_cont(ctx, a1_rope);
+        a2_rope = ggml_cont(ctx, a2_rope);
+    }
 
     a1_rope = ggml_rope_inplace(ctx, a1_rope, b1, rope_dim, (int)RopeType::NEOX); // [s, #h, d/2]
     a2_rope = ggml_rope_inplace(ctx, a2_rope, b2, rope_dim, (int)RopeType::NEOX); // [s, #h, d/2]
 
-#ifdef GGML_USE_CUDA
-    a1_rope = ggml_cpy(ctx, a1_rope, a1);
-    a2_rope = ggml_cpy(ctx, a2_rope, a2);
-#endif
+    if (ggml_cpu_has_cuda()) {
+        a1_rope = ggml_cpy(ctx, a1_rope, a1);
+        a2_rope = ggml_cpy(ctx, a2_rope, a2);
+    }
 
     ggml_build_forward_expand(mctx->gf, a1_rope);
     ggml_build_forward_expand(mctx->gf, a2_rope);
@@ -599,15 +599,15 @@ static ggml_tensor *apply_rotary_emb_glm2(ModelContext *mctx, ggml_tensor *layer
         ggml_view_3d(ctx, layer, rope_dim, layer->ne[1], layer->ne[2], layer->nb[1], layer->nb[2], 0);
 
     ggml_tensor *half_layer = half_layer_view;
-#ifdef GGML_USE_CUDA
-    half_layer = ggml_cont(ctx, half_layer);
-#endif
+    if (ggml_cpu_has_cuda()) {
+        half_layer = ggml_cont(ctx, half_layer);
+    }
     ggml_tensor *roped_half_layer =
         ggml_rope_ext_inplace(ctx, half_layer, position_ids, nullptr, rope_dim, (int)RopeType::GPTJ, 0, rope_theta,
                               1.0f, 0.0f, 1.0f, 0.0f, 0.0f); // [s, #h, d]
-#ifdef GGML_USE_CUDA
-    roped_half_layer = ggml_cpy(ctx, roped_half_layer, half_layer_view);
-#endif
+    if (ggml_cpu_has_cuda()) {
+        roped_half_layer = ggml_cpy(ctx, roped_half_layer, half_layer_view);
+    }
     ggml_build_forward_expand(mctx->gf, roped_half_layer);
 
     return layer;
Expand Down Expand Up @@ -677,6 +677,7 @@ ggml_tensor *BasicAttention::forward(ModelContext *mctx, ggml_tensor *hidden_sta
key_layer = ggml_permute(ctx, key_layer, 0, 2, 1, 3); // [#kvh, s, d]
value_layer = ggml_permute(ctx, value_layer, 1, 2, 0, 3); // [#kvh, d, s]

ggml_tensor *context_layer;
if (k_cache && v_cache) {
// store key & value to cache
ggml_tensor *k_cache_view =
@@ -695,46 +696,47 @@ ggml_tensor *BasicAttention::forward(ModelContext *mctx, ggml_tensor *hidden_sta
         value_layer = ggml_view_3d(ctx, v_cache, num_virtual_tokens + n_past + qlen, head_size, num_key_value_heads,
                                    v_cache->nb[1], v_cache->nb[2],
                                    0); // [#kvh, d, kvs]
-    } else {
-        key_layer = ggml_cont(ctx, key_layer);
-        value_layer = ggml_cont(ctx, value_layer);
-    }
 
-    // attention
-    query_layer = ggml_scale_inplace(ctx, query_layer, 1.f / std::sqrt(head_size));
-    ggml_tensor *attn_scores = ggml_mul_mat(ctx, key_layer, query_layer); // [#kvh, (#h/#kvh) * s, kvs]
+        // attention
+        query_layer = ggml_scale_inplace(ctx, query_layer, 1.f / std::sqrt(head_size));
+        ggml_tensor *attn_scores = ggml_mul_mat(ctx, key_layer, query_layer); // [#kvh, (#h/#kvh) * s, kvs]
 
-    if (n_past == 0) {
-        // build attention mask for context input
-        if (num_shared_q_heads > 1) {
-            attn_scores = ggml_reshape_3d(ctx, attn_scores, num_virtual_tokens + n_past + qlen, qlen,
-                                          num_attention_heads); // [#h, s, kvs]
-        }
+        if (n_past == 0) {
+            // build attention mask for context input
+            if (num_shared_q_heads > 1) {
+                attn_scores = ggml_reshape_3d(ctx, attn_scores, num_virtual_tokens + n_past + qlen, qlen,
+                                              num_attention_heads); // [#h, s, kvs]
+            }
 
-        if (attn_mask_type == AttentionMaskType::BIDIRECTIONAL) {
-            // pass
-        } else if (attn_mask_type == AttentionMaskType::CAUSAL) {
-            attn_scores = ggml_diag_mask_inf_inplace(ctx, attn_scores, num_virtual_tokens + n_past);
-        } else {
-            attn_scores = ggml_add_inplace(ctx, attn_scores, attention_mask);
-        }
+            if (attention_mask) {
+                attn_scores = ggml_add_inplace(ctx, attn_scores, attention_mask);
+            }
 
-        if (num_shared_q_heads > 1) {
-            attn_scores =
-                ggml_reshape_3d(ctx, attn_scores, num_virtual_tokens + n_past + qlen, num_shared_q_heads * qlen,
-                                num_key_value_heads); // [#kvh, (#h/#kvh) * s, kvs]
-        }
-    }
+            if (num_shared_q_heads > 1) {
+                attn_scores =
+                    ggml_reshape_3d(ctx, attn_scores, num_virtual_tokens + n_past + qlen, num_shared_q_heads * qlen,
+                                    num_key_value_heads); // [#kvh, (#h/#kvh) * s, kvs]
+            }
+        }
 
-    ggml_tensor *attn_probs = ggml_soft_max_inplace(ctx, attn_scores); // [#kvh, (#h/#kvh) * s, kvs]
+        ggml_tensor *attn_probs = ggml_soft_max_inplace(ctx, attn_scores); // [#kvh, (#h/#kvh) * s, kvs]
 
-    ggml_tensor *context_layer = ggml_mul_mat(ctx, value_layer, attn_probs); // [#kvh, (#h/#kvh) * s, d]
-    if (num_shared_q_heads > 1) {
-        context_layer = ggml_reshape_3d(ctx, context_layer, head_size, qlen,
-                                        num_attention_heads); // [#h, s, d]
-    }
-    context_layer = ggml_cont(ctx, ggml_permute(ctx, context_layer, 0, 2, 1, 3)); // [s, #h, d]
+        context_layer = ggml_mul_mat(ctx, value_layer, attn_probs); // [#kvh, (#h/#kvh) * s, d]
+        if (num_shared_q_heads > 1) {
+            context_layer = ggml_reshape_3d(ctx, context_layer, head_size, qlen,
+                                            num_attention_heads); // [#h, s, d]
+        }
+        context_layer = ggml_cont(ctx, ggml_permute(ctx, context_layer, 0, 2, 1, 3)); // [s, #h, d]
+    } else {
+        // qkv must be correctly padded
+        key_layer = ggml_cast(ctx, key_layer, GGML_TYPE_F16);                                    // [#kvh, s, d]
+        value_layer = ggml_cast(ctx, ggml_permute(ctx, value_layer, 1, 0, 2, 3), GGML_TYPE_F16); // [#kvh, s, d]
+        context_layer = ggml_flash_attn_ext(ctx, query_layer, key_layer, value_layer, attention_mask,
+                                            1.f / std::sqrt(head_size), 0);
+        ggml_flash_attn_ext_set_prec(context_layer, GGML_PREC_F32);
+    }
 
-    context_layer = ggml_reshape_2d(ctx, context_layer, hidden_size, qlen); // [s, #h * d]
+    context_layer = ggml_reshape_2d(ctx, context_layer, hidden_size, qlen); // [s, #h * d]
 
     ggml_tensor *attn_output = dense.forward(mctx, context_layer);
     return attn_output;
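The new `else` branch is the point of this PR: when there is no KV cache (the vision-encoder path), attention goes through ggml's fused flash-attention kernel instead of the explicit mul_mat/soft_max pipeline above it. As a reading aid, a minimal sketch of that call pattern, factored into a hypothetical helper (`flash_attn_sketch` is not in the PR; tensor layouts are assumed to match the code above):

```cpp
#include <cmath>

#include "ggml.h"

// Hypothetical helper mirroring the no-cache branch of BasicAttention::forward.
// q: F32 [#h, s, d]; k: [#kvh, s, d]; v: pre-permuted [#kvh, d, s], as above.
static ggml_tensor *flash_attn_sketch(ggml_context *ctx, ggml_tensor *q, ggml_tensor *k,
                                      ggml_tensor *v, ggml_tensor *mask, int head_size) {
    // ggml_flash_attn_ext wants F16 K/V, and q/k/v plus the mask must already be
    // padded to the kernel's alignment -- hence "qkv must be correctly padded".
    k = ggml_cast(ctx, k, GGML_TYPE_F16);                                 // [#kvh, s, d]
    v = ggml_cast(ctx, ggml_permute(ctx, v, 1, 0, 2, 3), GGML_TYPE_F16);  // [#kvh, s, d]
    ggml_tensor *out =
        ggml_flash_attn_ext(ctx, q, k, v, mask, 1.f / std::sqrt((float)head_size), 0.f);
    // accumulate in F32: long-sequence softmax overflows half precision
    ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
    return out;
}
```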
@@ -1341,6 +1343,19 @@ void ChatGLM2Model::set_graph_inputs(ggml_cgraph *gf, const std::vector<int> &in
     std::vector<int> position_ids_buffer(position_ids->ne[0]);
     std::iota(position_ids_buffer.begin(), position_ids_buffer.end(), n_past);
     ggml_backend_tensor_set(position_ids, position_ids_buffer.data(), 0, position_ids_buffer.size() * sizeof(int));
+
+    ggml_tensor *attention_mask = ggml_graph_get_tensor(gf, "attention_mask");
+    if (attention_mask) {
+        const int kvlen = attention_mask->ne[0];
+        const int qlen = attention_mask->ne[1];
+        std::vector<float> mask_buf(qlen * kvlen);
+        for (int i = 0; i < qlen; i++) {
+            for (int j = 0; j < kvlen; j++) {
+                mask_buf[i * kvlen + j] = (i < j + qlen - kvlen) ? -INFINITY : 0.f;
+            }
+        }
+        ggml_backend_tensor_set(attention_mask, mask_buf.data(), 0, ggml_nbytes(attention_mask));
+    }
 }
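The causal mask written into `attention_mask` here replaces the old `AttentionMaskType` enum logic in `BasicAttention::forward`. The formula is easier to read with concrete numbers: query row `i` sits at global position `i + kvlen - qlen`, and key `j` is visible only when `j` is at or before that position. A standalone toy (a sketch only; the qlen/kvlen values are made up):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Standalone illustration of the causal mask filled by set_graph_inputs above.
// qlen = 2 new tokens attending over kvlen = 4 cached-plus-new positions.
int main() {
    const int qlen = 2, kvlen = 4; // illustrative sizes, not from the PR
    std::vector<float> mask_buf(qlen * kvlen);
    for (int i = 0; i < qlen; i++) {
        for (int j = 0; j < kvlen; j++) {
            // same predicate as the PR: mask key j for query i iff j > i + kvlen - qlen
            mask_buf[i * kvlen + j] = (i < j + qlen - kvlen) ? -INFINITY : 0.f;
        }
    }
    for (int i = 0; i < qlen; i++) {
        for (int j = 0; j < kvlen; j++) {
            std::printf("%6s", std::isinf(mask_buf[i * kvlen + j]) ? "-inf" : "0");
        }
        std::printf("\n");
    }
    // Output:
    //      0     0     0  -inf
    //      0     0     0     0
    return 0;
}
```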

StateDict ChatGLM2ForCausalLM::state_dict() const {
@@ -1827,14 +1842,14 @@ EVA2CLIPTransformer::EVA2CLIPTransformer(ModelContext *mctx, const VisionModelCo
     for (int layer_id = 0; layer_id < config.num_hidden_layers; layer_id++) {
         layers.emplace_back(mctx, config.dtype, config.hidden_size, config.num_attention_heads,
                             config.num_attention_heads, config.intermediate_size, config.num_positions, config.norm_eps,
-                            config.hidden_act, true, true, false, RopeType::DISABLED, -1,
-                            AttentionMaskType::BIDIRECTIONAL, 0, false);
+                            config.hidden_act, true, true, false, RopeType::DISABLED, -1, 0, false);
     }
 }
 
-ggml_tensor *EVA2CLIPTransformer::forward(ModelContext *mctx, ggml_tensor *hidden_states) const {
+ggml_tensor *EVA2CLIPTransformer::forward(ModelContext *mctx, ggml_tensor *hidden_states,
+                                          ggml_tensor *attention_mask) const {
     for (const auto &layer : layers) {
-        hidden_states = layer.forward(mctx, hidden_states, nullptr, nullptr, 0);
+        hidden_states = layer.forward(mctx, hidden_states, attention_mask, nullptr, 0);
     }
     return hidden_states;
 }
@@ -1843,17 +1858,29 @@ ggml_tensor *EVA2CLIPModel::forward(ModelContext *mctx, ggml_tensor *input) cons
     ggml_context *ctx = mctx->ctx_b.get();
 
     ggml_tensor *hidden_states = patch_embedding.forward(mctx, input);
-    hidden_states = transformer.forward(mctx, hidden_states); // [s, hd]
 
-    const int grid_size = std::round(std::sqrt(hidden_states->ne[1] - 1));
+    // padding for flash attn
+    const int pad_to_multiple_of = ggml_cpu_has_cuda() ? 256 : GGML_KQ_MASK_PAD;
+    const int pad_size = GGML_PAD(hidden_states->ne[1], pad_to_multiple_of) - hidden_states->ne[1];
+    if (pad_size) {
+        hidden_states = ggml_pad(ctx, hidden_states, 0, pad_size, 0, 0);
+    }
+
+    ggml_tensor *encoder_attention_mask =
+        ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hidden_states->ne[1], hidden_states->ne[1]);
+    ggml_set_input(encoder_attention_mask);
+    ggml_set_name(encoder_attention_mask, "encoder_attention_mask");
+
+    encoder_attention_mask = ggml_cast(ctx, encoder_attention_mask, GGML_TYPE_F16);
+    hidden_states = transformer.forward(mctx, hidden_states, encoder_attention_mask); // [s, hd]
+
+    const int grid_size = std::round(std::sqrt(hidden_states->ne[1] - pad_size - 1));
     hidden_states = ggml_view_3d(ctx, hidden_states, hidden_states->ne[0], grid_size, grid_size, hidden_states->nb[1],
                                  grid_size * hidden_states->nb[1], hidden_states->nb[1]); // [g, g, hd]
-    // TODO: must use this cont?
     hidden_states = ggml_cont(ctx, ggml_permute(ctx, hidden_states, 2, 0, 1, 3)); // [hd, g, g]
     hidden_states = conv.forward(mctx, hidden_states);                            // [hd, g/2, g/2]
     hidden_states = ggml_reshape_2d(ctx, hidden_states, hidden_states->ne[0] * hidden_states->ne[1],
-                                    hidden_states->ne[2]); // [hd, s]
-    // TODO: this cont?
+                                    hidden_states->ne[2]);                         // [hd, s]
     hidden_states = ggml_cont(ctx, ggml_permute(ctx, hidden_states, 1, 0, 2, 3)); // [s, hd]
 
     hidden_states = linear_proj.forward(mctx, hidden_states);
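Flash attention requires the KQ mask (and hence the token dimension) to be padded to an alignment boundary, so the encoder sequence is padded before the transformer and the pad is subtracted again when recovering the patch grid. A quick arithmetic check of `pad_size` (sketch only; the 6401-token count is an assumption about GLM4V's 80x80 patch grid plus one class token, not stated in this diff):

```cpp
#include <cstdio>

// Same rounding as ggml's GGML_PAD for power-of-two alignments:
// round x up to the next multiple of n.
static int pad_up(int x, int n) { return (x + n - 1) / n * n; }

int main() {
    const int n_tokens = 6401;          // assumed: 80 * 80 patches + 1 class token
    const int pad_to_multiple_of = 256; // CUDA flash-attn alignment used above
    const int pad_size = pad_up(n_tokens, pad_to_multiple_of) - n_tokens;
    std::printf("padded length = %d, pad_size = %d\n",
                pad_up(n_tokens, pad_to_multiple_of), pad_size);
    // padded length = 6656, pad_size = 255
    return 0;
}
```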
@@ -1967,6 +1994,38 @@ void ChatGLM4VModel::set_graph_inputs(ggml_cgraph *gf, const std::vector<int> &i
         // copy to tensor
         ggml_backend_tensor_set(image_tensor, pixels_f32.data(), 0, ggml_nbytes(image_tensor));
     }
+
+    // attention_mask
+    ggml_tensor *attention_mask = ggml_graph_get_tensor(gf, "attention_mask");
+    if (attention_mask) {
+        const int kvlen = attention_mask->ne[0];
+        const int qlen = attention_mask->ne[1];
+        std::vector<float> mask_buf(qlen * kvlen);
+        for (int i = 0; i < qlen; i++) {
+            for (int j = 0; j < kvlen; j++) {
+                mask_buf[i * kvlen + j] = (i < j + qlen - kvlen) ? -INFINITY : 0.f;
+            }
+        }
+        ggml_backend_tensor_set(attention_mask, mask_buf.data(), 0, ggml_nbytes(attention_mask));
+    }
+
+    // encoder_attention_mask
+    ggml_tensor *encoder_attention_mask = ggml_graph_get_tensor(gf, "encoder_attention_mask");
+    if (encoder_attention_mask) {
+        const int valid_tokens = vision.patch_embedding.num_positions();
+        const int M = encoder_attention_mask->ne[1];
+        const int N = encoder_attention_mask->ne[0];
+        std::vector<float> encoder_mask_f32(M * N);
+        CHATGLM_CHECK((size_t)ggml_nelements(encoder_attention_mask) == encoder_mask_f32.size());
+        for (int i = 0; i < M; i++) {
+            for (int j = 0; j < N; j++) {
+                encoder_mask_f32[i * N + j] =
+                    (i < valid_tokens && j < valid_tokens) ? 0.f : -65504.f; // -INFINITY causes nan/inf logits
+            }
+        }
+        ggml_backend_tensor_set(encoder_attention_mask, encoder_mask_f32.data(), 0,
+                                ggml_nbytes(encoder_attention_mask));
+    }
 }
 
 int ChatGLM4VForCausalLM::count_tokens(const std::vector<int> &input_ids, const std::optional<Image> &image) const {
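One subtlety in the encoder mask: because `EVA2CLIPModel::forward` casts it to F16 before handing it to flash attention, the padding positions use `-65504.f` (the most negative finite F16 value) rather than `-INFINITY`. The padding rows are fully masked, and a row of all `-inf` softmaxes to NaN, which then propagates into the logits, consistent with the inline comment; a large finite negative keeps softmax well defined while still zeroing attention from real tokens to padding. A toy rendering of the resulting block mask (sizes are illustrative):

```cpp
#include <cstdio>
#include <vector>

// Toy version of the encoder mask built in set_graph_inputs above:
// a valid_tokens x valid_tokens block of zeros, padding masked everywhere else.
int main() {
    const int n = 6, valid_tokens = 4; // real sizes: padded seq len and num_positions
    std::vector<float> mask(n * n);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            // -65504.f is the most negative finite F16; true -INFINITY would make
            // the all-masked padding rows produce NaN in softmax.
            mask[i * n + j] = (i < valid_tokens && j < valid_tokens) ? 0.f : -65504.f;
        }
    }
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            std::printf("%9.0f", mask[i * n + j]);
        }
        std::printf("\n");
    }
    return 0;
}
```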