diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py
index e6726268..bb1f21b9 100644
--- a/vllm_hpu_extension/ops.py
+++ b/vllm_hpu_extension/ops.py
@@ -47,6 +47,27 @@ def block2batch(tensor, block_mapping, matmul_op=torch.matmul):
     return b2b_impl(tensor, block_mapping.t(), matmul_op)
 
 
+def group_sum(partial_sum, block_mapping, block_groups, type="mme"):
+    if type == "mme":
+        sums = block2batch(partial_sum, block_mapping)
+        sums = batch2block(sums, block_mapping)
+    elif type == "reduce_sum":
+        # [num_blocks, 1, kv_heads, gqa] * [num_blocks, batch_size, 1, 1]
+        sums = partial_sum.unsqueeze(1) * block_mapping.view(block_mapping.shape[0], -1, 1, 1)
+        sums = sums.sum(dim=0)
+        # [batch_size, kv_heads, gqa] -> [num_blocks, kv_heads, gqa]
+        sums = sums.index_select(0, block_groups)
+    elif type == "reduce_sum_T":
+        partial_sum_T = partial_sum.permute(*range(1, partial_sum.dim()), 0)
+        block_mapping_T = block_mapping.t().view(block_mapping.shape[-1], 1, 1, block_mapping.shape[0])
+        # [1, gqa, kv_heads, num_blocks] * [batch_size, 1, 1, num_blocks]
+        sums = partial_sum_T.unsqueeze(0) * block_mapping_T
+        sums = sums.sum(dim=-1)
+        # [batch_size, kv_heads, gqa] -> [num_blocks, kv_heads, gqa]
+        sums = sums.index_select(0, block_groups)
+    return sums
+
+
 def pipelined_pa(attn, value, block_groups, block_mapping, block_scales, batch_size,
                  matmul_av_op, batch2block_matmul_op, block2batch_matmul_op):
     # Normalize the attention scores
@@ -63,8 +84,7 @@ def pipelined_pa(attn, value, block_groups, block_mapping, block_scales, batch_s
     block_adjustment = (block_max - group_max).exp()
     sum_adjusted = block_sums.mul(block_adjustment)
     # Sum block's sums that belongs to the same sequeneces
-    group_sum_adjusted = block2batch(sum_adjusted, block_mapping, block2batch_matmul_op)
-    group_sum_adjusted = batch2block(group_sum_adjusted, block_mapping, batch2block_matmul_op)
+    group_sum_adjusted = group_sum(sum_adjusted, block_mapping, block_groups, type="reduce_sum_T")
     sum_adjusted = sum_adjusted.view(*adjustment_target_shape)
     group_sum_adjusted = group_sum_adjusted.view(*adjustment_target_shape)
     block_adjustment = block_adjustment.view(*adjustment_target_shape)
@@ -169,7 +189,7 @@ def prompt_attention(
         softmax_op=torch.softmax,
         matmul_av_op=torch.matmul,
         valid_seq_lengths: Optional[torch.Tensor] = None,
-        fsdpa_op = None,
+        fsdpa_op=None,
 ) -> torch.Tensor:
     query = query.transpose(1, 2)
     key = key.transpose(1, 2)
@@ -204,8 +224,8 @@ def prompt_attention(
             softmax_mode = 'fast'
             recompute_mode = True
             attn_weights = fsdpa_op(query, key, value, None, 0.0, True,
-                                        scale, softmax_mode, recompute_mode,
-                                        valid_seq_lengths, 'right')
+                                    scale, softmax_mode, recompute_mode,
+                                    valid_seq_lengths, 'right')
 
     attn_weights = attn_weights.transpose(1, 2)
     return attn_weights
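
Below is a standalone sanity check, not part of the patch, that the three group_sum modes agree. It assumes block_mapping is the one-hot [num_blocks, batch_size] matrix that batch2block/block2batch already consume and that block_groups holds each block's sequence index (both read off the shape comments in the diff); the toy sizes and tensors are made up for illustration, and it needs a build of vllm_hpu_extension with this patch applied.

# Sanity check for the new group_sum modes (illustrative, not part of ops.py).
# Assumptions: block_mapping is one-hot [num_blocks, batch_size];
# block_groups[i] is the sequence that block i belongs to; a patched
# vllm_hpu_extension is importable. Toy sizes are arbitrary.
import torch
from vllm_hpu_extension.ops import group_sum

num_blocks, batch_size, kv_heads, gqa = 8, 3, 4, 2
# block_groups[i] = index of the sequence owning block i
block_groups = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])
# One-hot [num_blocks, batch_size] mapping, as consumed by batch2block/block2batch
block_mapping = torch.nn.functional.one_hot(block_groups, batch_size).float()
# Per-block partial softmax sums: [num_blocks, kv_heads, gqa]
partial_sum = torch.rand(num_blocks, kv_heads, gqa)

# "mme" is the old two-matmul path; the reduce_sum variants must match it
ref = group_sum(partial_sum, block_mapping, block_groups, type="mme")
for mode in ("reduce_sum", "reduce_sum_T"):
    out = group_sum(partial_sum, block_mapping, block_groups, type=mode)
    torch.testing.assert_close(out, ref)

pipelined_pa now selects "reduce_sum_T", which permutes num_blocks to the last dimension and reduces there instead of going through the block2batch/batch2block matmul pair; all three modes compute the same per-sequence sums broadcast back to blocks.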