Skip to content

Commit

Permalink
replicate key&value
Browse files Browse the repository at this point in the history
  • Loading branch information
xinhaoc committed Oct 13, 2023
1 parent ffa5168 commit 552d49f
Show file tree
Hide file tree
Showing 14 changed files with 289 additions and 620 deletions.
2 changes: 1 addition & 1 deletion include/flexflow/ops/inc_multihead_self_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta {
size_t weights_params, weightSize, biasSize, reserveSpaceSize,
quantized_weightSize;
int qSize, kSize, vSize, qProjSize, kProjSize, vProjSize, oProjSize;
int global_num_q_heads, global_num_kv_heads, num_q_heads, num_kv_heads;
int global_num_q_heads, global_num_kv_heads, num_q_heads, num_kv_heads, hidden_size;
bool *has_load_weights;
bool *apply_rotary_embedding;
bool *qkv_bias;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#ifndef _FLEXFLOW_OPS_KERNELS_INC_MULTIHEAD_SELF_ATTENTION_KERNELS_H
#define _FLEXFLOW_OPS_KERNELS_INC_MULTIHEAD_SELF_ATTENTION_KERNELS_H

#define QKV_WEIGHT_NUM 3
#define KV_WEIGHT_NUM 2

#include "flexflow/batch_config.h"
#include "flexflow/device.h"
#include "flexflow/fftype.h"
Expand Down
59 changes: 47 additions & 12 deletions inference/file_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,11 @@ void load_attention_bias_v2(DT *ptr,
std::string weight_filepath = join_path({weights_folder, filename});

int n_heads = file_index == 0 ? num_heads : num_kv_heads;

int replicate_num = num_heads / num_kv_heads;

size_t qkv_partial_size = qkv_inner_dim * n_heads;
size_t qkv_replicate_size = qkv_inner_dim * num_heads;
size_t out_partial_size = hidden_dim;
size_t partial_size =
(file_index < 3) ? qkv_partial_size : out_partial_size;
Expand All @@ -178,13 +182,24 @@ void load_attention_bias_v2(DT *ptr,

size_t data_index = 0;

for (int i = 0; i < partial_size; i++) {
ptr[idx + i] = host_array.at(data_index);
data_index++;
// q, o
if (file_index == 0 || file_index == 3) {
for (int i = 0; i < partial_size; i++) {
ptr[idx + i] = host_array.at(data_index);
data_index++;
}
} else {
// k, v
for (int i = 0; i < partial_size; i++) {
for (int j = 0; j < replicate_num; j++) {
ptr[idx + j * partial_size + i] = host_array.at(data_index);
}
data_index++;
}
}

file_index++;
idx += qkv_partial_size;
idx += qkv_replicate_size;

in.close();
}
Expand Down Expand Up @@ -220,9 +235,14 @@ void load_attention_weights_v2(DT *ptr,
size_t k_size = single_proj_size * num_kv_heads,
v_size = single_proj_size * num_kv_heads;

size_t k_replicate_size = one_weight_file_size;
size_t v_replicate_size = one_weight_file_size;

int replicate_num = num_heads / num_kv_heads;

// stride for q, k, v, o
size_t stride_size =
(q_size + v_size + k_size + o_size) / tensor_parallelism_degree;
size_t stride_size = (q_size + v_replicate_size + k_replicate_size + o_size) /
tensor_parallelism_degree;
for (auto filename : weight_filenames) {
std::cout << "Loading weight file " << filename << std::endl;
std::string weight_filepath = join_path({weights_folder, filename});
Expand All @@ -231,7 +251,8 @@ void load_attention_weights_v2(DT *ptr,
size_t partial_size = (file_index == 0 || file_index == 3)
? one_weight_file_size
: single_proj_size * num_kv_heads;
size_t one_partition_size = partial_size / tensor_parallelism_degree;
size_t one_partition_size =
one_weight_file_size / tensor_parallelism_degree;

std::ifstream in(weight_filepath, std::ios::in | std::ios::binary);
if (!in.good()) {
Expand All @@ -252,16 +273,30 @@ void load_attention_weights_v2(DT *ptr,
assert(false && "data size mismatch");
}
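// wq is copied per tensor-parallel partition as-is; wk and wv are replicated
// so that every query head gets a copy of its KV head (wo is loaded below)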
for (int i = 0; i < tensor_parallelism_degree; i++) {
for (int j = 0; j < one_partition_size; j++) {
ptr[base_index + i * stride_size + j] = host_array.at(data_index++);
if (file_index == 0) {
for (int i = 0; i < tensor_parallelism_degree; i++) {
for (int j = 0; j < one_partition_size; j++) {
ptr[base_index + i * stride_size + j] = host_array.at(data_index++);
}
}
} else {
for (int i = 0; i < num_heads; i++) {
int kv_idx = i / (num_heads / num_kv_heads);
int head_idx = i % (num_heads / tensor_parallelism_degree);
int tp_idx = (i / (num_heads / tensor_parallelism_degree));
for (int j = 0; j < single_proj_size; j++) {
ptr[base_index + tp_idx * stride_size + single_proj_size * head_idx +
j] = host_array.at(kv_idx * single_proj_size + j);
}
}
}
assert(data_index == partial_size);

// assert(data_index == partial_size);
base_index += one_partition_size;
file_index++;
}
assert(base_index == (q_size + k_size + v_size) / tensor_parallelism_degree);
assert(base_index == (q_size + k_replicate_size + v_replicate_size) /
tensor_parallelism_degree);

{
std::cout << "Loading weight file " << o_file << std::endl;
Expand Down
4 changes: 1 addition & 3 deletions inference/models/falcon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ void FALCON::create_falcon_model(FFModel &ff,
falcon_config.print();

if (ff.config.tensor_parallelism_degree > falcon_config.n_head ||
falcon_config.n_head % ff.config.tensor_parallelism_degree != 0 ||
ff.config.tensor_parallelism_degree > falcon_config.n_head_kv ||
falcon_config.n_head_kv % ff.config.tensor_parallelism_degree != 0) {
falcon_config.n_head % ff.config.tensor_parallelism_degree != 0) {
assert(false && "The number of attention heads is smaller, or it is not "
"divisible by the tensor parallelism degree");
}
Expand Down
8 changes: 0 additions & 8 deletions python/flexflow/serve/models/falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,6 @@ def __init__(
raise ValueError(
f"Number of q attention heads ({self.falcon_config.n_head}) is smaller, or not divisible by tensor parallelism degree ({self.ffconfig.tensor_parallelism_degree})"
)
if (
self.falcon_config.n_head_kv < self.ffconfig.tensor_parallelism_degree
or self.falcon_config.n_head_kv % self.ffconfig.tensor_parallelism_degree
!= 0
):
raise ValueError(
f"Number of k/v attention heads ({self.falcon_config.n_head_kv}) is smaller, or not divisible by tensor parallelism degree ({self.ffconfig.tensor_parallelism_degree})"
)

self.build_model(max_tokens_per_batch)

Expand Down
8 changes: 0 additions & 8 deletions python/flexflow/serve/models/starcoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,6 @@ def __init__(
raise ValueError(
f"Number of attention heads ({self.starcoder_config.num_attention_heads}) is smaller, or not divisible by tensor parallelism degree ({self.ffconfig.tensor_parallelism_degree})"
)
if (
self.starcoder_config.n_head_kv < self.ffconfig.tensor_parallelism_degree
or self.starcoder_config.n_head_kv % self.ffconfig.tensor_parallelism_degree
!= 0
):
raise ValueError(
f"Number of k/v attention heads ({self.starcoder_config.n_head_kv}) is smaller, or not divisible by tensor parallelism degree ({self.ffconfig.tensor_parallelism_degree})"
)

self.build_model(max_tokens_per_batch)

Expand Down
16 changes: 9 additions & 7 deletions src/ops/inc_multihead_self_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,10 @@ Tensor FFModel::inc_multiquery_self_attention(const Tensor input,
int kParas = kProjSize * kSize;
int vParas = vProjSize * vSize;
int oParas = oProjSize * (vProjSize > 0 ? vProjSize : vSize);
int weight_size = qParas * num_q_heads + kParas * num_kv_heads +
vParas * num_kv_heads + oParas * num_q_heads;

// allocate num_q_heads for key, value for replication
int weight_size = qParas * num_q_heads + kParas * num_q_heads +
vParas * num_q_heads + oParas * num_q_heads;
int one_head_size = qParas + kParas + vParas + oParas;

{
Expand All @@ -177,7 +179,7 @@ Tensor FFModel::inc_multiquery_self_attention(const Tensor input,
if (qkv_bias || final_bias) {
// q, k, v, o
int qkv_bias_size =
qProjSize * num_q_heads + (kProjSize + vProjSize) * num_kv_heads;
qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads;
int dims[1] = {(qkv_bias ? qkv_bias_size : 0) +
(final_bias ? oProjSize : 0)};
li->weights[1] = create_weight_legion_ordering(1,
Expand Down Expand Up @@ -348,7 +350,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention(
dims[0].size = dims[0].degree;
dims[1] = inputs[0]->dims[num_dims - 1];
dims[1].size = this->num_q_heads * (qParas + oParas) +
this->num_kv_heads * (kParas + vParas);
this->num_q_heads * (kParas + vParas);
dims[1].is_replica_dim = false;

if (quantization_type != DT_NONE) {
Expand All @@ -367,7 +369,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention(
if (qkv_bias || final_bias) {
ParallelTensorShape bias_shape = _input->get_shape();
int qkv_bias_size =
qProjSize * num_q_heads + (kProjSize + vProjSize) * num_kv_heads;
qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads;
bias_shape.dims[0].size =
(qkv_bias ? qkv_bias_size : 0) + (final_bias ? oProjSize : 0);
bias_shape.dims[1].size = bias_shape.dims[2].size = 1;
Expand Down Expand Up @@ -461,7 +463,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention(
dims[0].size = dims[0].degree;
dims[1] = inputs[0]->dims[num_dims - 1];
dims[1].size = this->num_q_heads * (qParas + oParas) +
this->num_kv_heads * (kParas + vParas);
this->num_q_heads * (kParas + vParas);
dims[1].is_replica_dim = false;
// dims[2].size = this->num_q_heads * (qParas + oParas) + this->num_kv_heads
// * (kParas + vParas);
Expand All @@ -481,7 +483,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention(
if (qkv_bias || final_bias) {
ParallelTensorShape bias_shape = _input->get_shape();
int qkv_bias_size =
qProjSize * num_q_heads + (kProjSize + vProjSize) * num_kv_heads;
qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads;
bias_shape.dims[0].size =
(qkv_bias ? qkv_bias_size : 0) + (final_bias ? oProjSize : 0);
bias_shape.dims[1].size = bias_shape.dims[2].size = 1;
Expand Down
Loading

0 comments on commit 552d49f

Please sign in to comment.