Skip to content

Commit

Permalink
reuse workspace for some metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
xinhaoc committed Sep 16, 2023
1 parent f6e06fa commit 5a84fd3
Showing 1 changed file with 38 additions and 30 deletions.
68 changes: 38 additions & 30 deletions src/ops/inc_multihead_self_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1116,41 +1116,47 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
tokeninfo_size * sizeof(BatchConfig::PerTokenInfo) +
complex_size * sizeof(cuFloatComplex); // more components will
// be added here later
// memory can be shared across layers
size_t totalSharedSize =
infer_mode == TREE_VERIFY_MODE
? totalSize -
(key_cache_size + value_cache_size + qkv_max_proj_size) *
size_of_dt
: totalSize - (key_cache_size + value_cache_size) * size_of_dt;
// memory can't be shared across layers.
size_t instance_size =
size_of_dt *
(infer_mode == TREE_VERIFY_MODE
? key_cache_size + value_cache_size + qkv_max_proj_size
: key_cache_size + value_cache_size);
if (offload) {
// assert that we have enough reserved work space left
// memory can be shared across layers
size_t totalSharedSize =
infer_mode == TREE_VERIFY_MODE
? totalSize -
(key_cache_size + value_cache_size + qkv_max_proj_size) *
size_of_dt
: totalSize - (key_cache_size + value_cache_size) * size_of_dt;
// memory can't be shared across layers.
size_t instance_size =
size_of_dt *
(infer_mode == TREE_VERIFY_MODE
? key_cache_size + value_cache_size + qkv_max_proj_size
: key_cache_size + value_cache_size);
if (quantization_type != DT_NONE) {
totalSharedSize += quantized_weightSize;
}
assert(gpu_mem_allocator.reserved_total_size -
gpu_mem_allocator.reserved_allocated_size >=
totalSharedSize);
gpu_mem_allocator.create_legion_instance(reserveInst, instance_size);
} else {
gpu_mem_allocator.create_legion_instance(reserveInst, totalSize);
assert(handle.workSpaceSize >= totalSharedSize);
}
// in tree_verify, enable devQKVProjArray;
if (!offload || infer_mode == TREE_VERIFY_MODE) {
gpu_mem_allocator.create_legion_instance(reserveInst, instance_size);
// workspace for shared memory across layers
char *work_space_start_ptr = (char *)handle.workSpace;
// QKV needs to be persistent in Tree_kernel.
if (infer_mode == TREE_VERIFY_MODE) {
devQKVProjArray = gpu_mem_allocator.allocate_instance_untyped(
qkv_max_proj_size * size_of_dt);
} else {
} else if (offload) {
devQKVProjArray = gpu_mem_allocator.allocate_reserved_untyped(
qkv_max_proj_size * size_of_dt);
// offset += qkv_max_proj_size * size_of_dt;
} else {
// spec/inc + non-offload
devQKVProjArray = work_space_start_ptr;
work_space_start_ptr += qkv_max_proj_size * size_of_dt;
}
// use key value cache in all modes.
Expand All @@ -1177,17 +1183,19 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
gpu_mem_allocator.allocate_reserved<cuFloatComplex>(complex_size);
// offset += complex_size * sizeof(cuFloatComplex);
} else {
token_infos =
gpu_mem_allocator.allocate_instance<BatchConfig::PerTokenInfo>(
tokeninfo_size);
qk_prods = gpu_mem_allocator.allocate_instance_untyped(qk_prod_size *
size_of_dt);
qk_prods_softmax = gpu_mem_allocator.allocate_instance_untyped(
qk_prod_size * size_of_dt);
attn_heads = gpu_mem_allocator.allocate_instance_untyped(attn_heads_size *
size_of_dt);
token_infos = static_cast<BatchConfig::PerTokenInfo *>(
(void *)work_space_start_ptr);
work_space_start_ptr +=
sizeof(BatchConfig::PerTokenInfo) * tokeninfo_size;
qk_prods = work_space_start_ptr;
work_space_start_ptr += qk_prod_size * size_of_dt;
qk_prods_softmax = work_space_start_ptr;
work_space_start_ptr += qk_prod_size * size_of_dt;
attn_heads = work_space_start_ptr;
work_space_start_ptr += attn_heads_size * size_of_dt;
complex_input =
gpu_mem_allocator.allocate_instance<cuFloatComplex>(complex_size);
static_cast<cuFloatComplex *>((void *)work_space_start_ptr);
work_space_start_ptr += sizeof(cuFloatComplex) * complex_size;
}
// allocate more size for quantization data
Expand Down

0 comments on commit 5a84fd3

Please sign in to comment.