diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu
index c947d4ebef..25eda080ec 100644
--- a/src/ops/inc_multihead_self_attention.cu
+++ b/src/ops/inc_multihead_self_attention.cu
@@ -1116,41 +1116,47 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
         tokeninfo_size * sizeof(BatchConfig::PerTokenInfo) +
         complex_size * sizeof(cuFloatComplex); // more components will
                                                // be added here later
+    // memory can be shared across layers
+    size_t totalSharedSize =
+        infer_mode == TREE_VERIFY_MODE
+            ? totalSize -
+                  (key_cache_size + value_cache_size + qkv_max_proj_size) *
+                      size_of_dt
+            : totalSize - (key_cache_size + value_cache_size) * size_of_dt;
+    // memory can't be shared across layers.
+    size_t instance_size =
+        size_of_dt *
+        (infer_mode == TREE_VERIFY_MODE
+             ? key_cache_size + value_cache_size + qkv_max_proj_size
+             : key_cache_size + value_cache_size);
+
     if (offload) {
       // assert that we have enough reserved work space left
-      // memory can be shared across layers
-      size_t totalSharedSize =
-          infer_mode == TREE_VERIFY_MODE
-              ? totalSize -
-                    (key_cache_size + value_cache_size + qkv_max_proj_size) *
-                        size_of_dt
-              : totalSize - (key_cache_size + value_cache_size) * size_of_dt;
-      // memory can't be shared across layers.
-      size_t instance_size =
-          size_of_dt *
-          (infer_mode == TREE_VERIFY_MODE
-               ? key_cache_size + value_cache_size + qkv_max_proj_size
-               : key_cache_size + value_cache_size);
-
       if (quantization_type != DT_NONE) {
         totalSharedSize += quantized_weightSize;
       }
       assert(gpu_mem_allocator.reserved_total_size -
                  gpu_mem_allocator.reserved_allocated_size >=
              totalSharedSize);
-      gpu_mem_allocator.create_legion_instance(reserveInst, instance_size);
     } else {
-      gpu_mem_allocator.create_legion_instance(reserveInst, totalSize);
+      assert(handle.workSpaceSize >= totalSharedSize);
     }
-    // in tree_verify, enable devQKVProjArray;
-    if (!offload || infer_mode == TREE_VERIFY_MODE) {
+    gpu_mem_allocator.create_legion_instance(reserveInst, instance_size);
+    // workspace for shared memory across layers
+    char *work_space_start_ptr = (char *)handle.workSpace;
+
+    // QKV need to be persistent in Tree_kernel.
+    if (infer_mode == TREE_VERIFY_MODE) {
       devQKVProjArray = gpu_mem_allocator.allocate_instance_untyped(
           qkv_max_proj_size * size_of_dt);
-    } else {
+    } else if (offload) {
       devQKVProjArray = gpu_mem_allocator.allocate_reserved_untyped(
           qkv_max_proj_size * size_of_dt);
-      // offset += qkv_max_proj_size * size_of_dt;
+    } else {
+      // spec/inc + non-offload
+      devQKVProjArray = work_space_start_ptr;
+      work_space_start_ptr += qkv_max_proj_size * size_of_dt;
     }
     // use key value cache in all mode.
@@ -1177,17 +1183,19 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
           gpu_mem_allocator.allocate_reserved<cuFloatComplex>(complex_size);
       // offset += complex_size * sizeof(cuFloatComplex);
     } else {
-      token_infos =
-          gpu_mem_allocator.allocate_instance<BatchConfig::PerTokenInfo>(
-              tokeninfo_size);
-      qk_prods = gpu_mem_allocator.allocate_instance_untyped(qk_prod_size *
-                                                             size_of_dt);
-      qk_prods_softmax = gpu_mem_allocator.allocate_instance_untyped(
-          qk_prod_size * size_of_dt);
-      attn_heads = gpu_mem_allocator.allocate_instance_untyped(attn_heads_size *
-                                                               size_of_dt);
+      token_infos = static_cast<BatchConfig::PerTokenInfo *>(
+          (void *)work_space_start_ptr);
+      work_space_start_ptr +=
+          sizeof(BatchConfig::PerTokenInfo) * tokeninfo_size;
+      qk_prods = work_space_start_ptr;
+      work_space_start_ptr += qk_prod_size * size_of_dt;
+      qk_prods_softmax = work_space_start_ptr;
+      work_space_start_ptr += qk_prod_size * size_of_dt;
+      attn_heads = work_space_start_ptr;
+      work_space_start_ptr += attn_heads_size * size_of_dt;
       complex_input =
-          gpu_mem_allocator.allocate_instance<cuFloatComplex>(complex_size);
+          static_cast<cuFloatComplex *>((void *)work_space_start_ptr);
+      work_space_start_ptr += sizeof(cuFloatComplex) * complex_size;
     }
     // allocate more size for quantization data
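For context, the non-offload path above now sub-allocates every per-step buffer out of the runtime's shared workspace (`handle.workSpace`) with a bump pointer, instead of creating a Legion instance for each buffer; only the KV cache (and, in tree-verify mode, the QKV projection) stays in a per-layer instance. Below is a minimal standalone sketch of that bump-pointer pattern, assuming made-up sizes; `WorkspaceCursor` and `kWorkspaceBytes` are illustrative names, not part of the FlexFlow API.

```cpp
// Minimal sketch of bump-pointer sub-allocation from a shared workspace.
// All names here (WorkspaceCursor, kWorkspaceBytes, ...) are illustrative.
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <cstdlib>

struct WorkspaceCursor {
  char *ptr;        // next free byte in the shared workspace
  size_t remaining; // bytes left, so over-carving trips an assert

  // Carve `count` elements of T off the front of the workspace and
  // advance the cursor, mirroring `work_space_start_ptr +=` in the patch.
  template <typename T> T *carve(size_t count) {
    size_t bytes = sizeof(T) * count;
    assert(remaining >= bytes && "workspace too small");
    T *out = reinterpret_cast<T *>(static_cast<void *>(ptr));
    ptr += bytes;
    remaining -= bytes;
    return out;
  }
};

int main() {
  // Stand-in for handle.workSpace / handle.workSpaceSize; in the real code
  // this is device memory owned by the runtime and reused across layers.
  size_t const kWorkspaceBytes = 1 << 20;
  char *workspace = static_cast<char *>(std::malloc(kWorkspaceBytes));
  WorkspaceCursor cursor{workspace, kWorkspaceBytes};

  // Same carving order as the patch; the element counts are made up.
  float *devQKVProjArray = cursor.carve<float>(3 * 4096);
  float *qk_prods = cursor.carve<float>(1024);
  float *qk_prods_softmax = cursor.carve<float>(1024);
  float *attn_heads = cursor.carve<float>(4096);
  (void)devQKVProjArray;
  (void)qk_prods;
  (void)qk_prods_softmax;
  (void)attn_heads;

  std::printf("used %zu of %zu bytes\n",
              kWorkspaceBytes - cursor.remaining,
              kWorkspaceBytes);
  std::free(workspace);
  return 0;
}
```

One thing to note about this pattern: the patch bumps the cursor by raw byte sizes with no padding, so each carved pointer is suitably aligned only if the preceding sizes happen to be multiples of the next type's alignment; a more defensive carve would round each offset up to `alignof(T)` before casting.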