Simplify SP - Opportunity to improve SP scalability #301

Open · wants to merge 2 commits into base: main
221 changes: 112 additions & 109 deletions megatron/model/transformer.py
@@ -577,18 +577,24 @@ def __init__(self, config, layer_number,
         else:
             local_attn = CoreAttention(self.layer_number, config, self.attn_mask_type)
 
-        self.enable_ds_sequence_parallel = parallel_state.get_sequence_parallel_world_size() > 1 \
-            or args.force_ds_sequence_parallel
-        if self.enable_ds_sequence_parallel:
+        if parallel_state.get_sequence_parallel_world_size() > 1 \
+                or args.force_ds_sequence_parallel:
             assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version'
             assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0
-            self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group())
+            # self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group())
+            self.compute_attn_sp = DistributedAttention(self.compute_attn,
+                                                        parallel_state.get_sequence_parallel_group(),
+                                                        scatter_idx=2,
+                                                        gather_idx=0,
+                                                        hidden_size_per_attention_head=hidden_size_per_attention_head,
+                                                        num_q_per_kv=self.num_key_value_groups if projection_size != kv_projection_size else -1)
+            self.compute_attn = lambda mixed_x_layer, *args, **kwargs: self.compute_attn_sp(mixed_x_layer, *args, **kwargs)
 
-            if self.use_flash_attn:
-                self.core_attention_flash = local_attn
-            else:
-                self.core_attention = local_attn
-                self.checkpoint_core_attention = config.recompute_granularity == 'selective'
+        if self.use_flash_attn:
+            self.core_attention_flash = local_attn
+        else:
+            self.core_attention = local_attn
+            self.checkpoint_core_attention = config.recompute_granularity == 'selective'
 
         # Output.
         self.dense = tensor_parallel.RowParallelLinear(
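The `scatter_idx=2, gather_idx=0` pair in the new `DistributedAttention` call corresponds to the head and sequence dimensions of the `[sq, b, np, hn]` layout used throughout this file: DeepSpeed's sequence parallelism trades a sequence split for a head split with an all-to-all, runs local attention on the full sequence for a subset of heads, then reverses the exchange. A single-process sketch of that layout change with simulated ranks (the real implementation communicates over the sequence-parallel process group; `all_to_all_sim` and all sizes below are illustrative stand-ins, not the DeepSpeed API):

import torch

def all_to_all_sim(per_rank, scatter_dim, gather_dim):
    """Simulate an all-to-all: rank i sends shard j of its local tensor to
    rank j, and every rank concatenates what it receives along gather_dim."""
    p = len(per_rank)
    shards = [t.chunk(p, dim=scatter_dim) for t in per_rank]
    return [torch.cat([shards[i][j] for i in range(p)], dim=gather_dim)
            for j in range(p)]

p, sq_local, b, np_, hn = 4, 4, 2, 8, 16
# before: each rank holds all np heads for a 1/p slice of the sequence
ranks = [torch.randn(sq_local, b, np_, hn) for _ in range(p)]
# scatter heads (dim 2), gather sequence (dim 0): afterwards each rank holds
# the full sequence for np/p heads, so exact local attention is possible
ranks = all_to_all_sim(ranks, scatter_dim=2, gather_dim=0)
assert ranks[0].shape == (sq_local * p, b, np_ // p, hn)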
@@ -650,78 +656,28 @@ def split_tensor(self, mixed_x_layer):

         return query_layer, key_layer, value_layer
 
-    def forward(self, hidden_states, attention_mask,
-                encoder_output=None, inference_params=None,
-                rotary_pos_emb=None):
-        # hidden_states: [sq, b, h]
-
-        # =================================================
-        # Pre-allocate memory for key-values for inference.
-        # =================================================
-        is_first_step = False
-        if inference_params:
-            if self.layer_number not in inference_params.key_value_memory_dict:
-                inf_max_seq_len = inference_params.max_sequence_len
-                inf_max_batch_size = inference_params.max_batch_size
-                inference_key_memory = self._allocate_memory(
-                    inf_max_seq_len, inf_max_batch_size)
-                inference_value_memory = self._allocate_memory(
-                    inf_max_seq_len, inf_max_batch_size)
-                inference_params.key_value_memory_dict[self.layer_number] = (
-                    inference_key_memory, inference_value_memory)
-                is_first_step = True
-            else:
-                inference_key_memory, inference_value_memory = \
-                    inference_params.key_value_memory_dict[self.layer_number]
-
-        # =====================
-        # Query, Key, and Value
-        # =====================
-
-        if self.attention_type == AttnType.self_attn:
-            # Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)]
-            mixed_x_layer, _ = self.query_key_value(hidden_states)
-
-            # [sq, b, ((nq + 2 * nkv) * hn)] --> [sq, b, nkv, (nq // nkv + 2), hn]
-            new_tensor_shape = mixed_x_layer.size()[:-1] + \
-                (-1, (self.num_key_value_groups + 2),
-                 self.hidden_size_per_attention_head)
-            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
-
-            # [sq, b, nkv, (nq // nkv + 2), hn] --> 3 [sq, b, np, hn]
-            (query_layer,
-             key_layer,
-             value_layer) = self.split_tensor(mixed_x_layer)
-
-            # Repeat kv
-            if self.use_gqa:
-                key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
-                value_layer = self.repeat_kv(value_layer,
-                                             self.num_key_value_groups)
-        else:
-            assert not self.use_gqa, 'GQA + cross-attn not tested yet'
-
-            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
-            mixed_kv_layer, _ = self.key_value(encoder_output)
-
-            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
-            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
-                (self.num_attention_heads_per_partition,
-                 2 * self.hidden_size_per_attention_head)
-            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
-
-            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
-            (key_layer,
-             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
-
-            # Attention head [sq, b, h] --> [sq, b, hp]
-            query_layer, _ = self.query(hidden_states)
-            # [sq, b, hp] --> [sq, b, np, hn]
-            new_tensor_shape = query_layer.size()[:-1] + \
-                (self.num_attention_heads_per_partition,
-                 self.hidden_size_per_attention_head)
-            query_layer = query_layer.view(*new_tensor_shape)
-
+    def compute_attn(self,
+                     mixed_x_layer,
+                     attention_mask,
+                     inference_params=None,
+                     rotary_pos_emb=None):
+
+        # [sq, b, ((nq + 2 * nkv) * hn)] --> [sq, b, nkv, (nq // nkv + 2), hn]
+        new_tensor_shape = mixed_x_layer.size()[:-1] + \
+            (-1, (self.num_key_value_groups + 2),
+             self.hidden_size_per_attention_head)
+        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+        # [sq, b, nkv, (nq // nkv + 2), hn] --> 3 [sq, b, np, hn]
+        (query_layer,
+         key_layer,
+         value_layer) = self.split_tensor(mixed_x_layer)
+
+        # Repeat kv
+        if self.use_gqa:
+            key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
+            value_layer = self.repeat_kv(value_layer,
+                                         self.num_key_value_groups)
         # ==================================
         # Adjust key and value for inference
         # ==================================
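The shape comments above describe how the new `compute_attn` unpacks the fused QKV tensor for grouped-query attention. A standalone sketch of that view-and-split step, with assumed toy sizes (`nq`, `nkv`, `hn` and the slot layout are illustrative stand-ins for what `split_tensor` does):

import torch

sq, b, nq, nkv, hn = 8, 2, 16, 4, 32      # toy sizes; nq % nkv == 0
groups = nq // nkv                         # queries sharing each kv head

# fused projection output: [sq, b, (nq + 2*nkv) * hn]
mixed = torch.randn(sq, b, (nq + 2 * nkv) * hn)

# [sq, b, (nq + 2*nkv)*hn] --> [sq, b, nkv, (nq // nkv + 2), hn]
mixed = mixed.view(sq, b, nkv, groups + 2, hn)

# per kv head, the first `groups` slots are queries, the last two are k and v
query = mixed[:, :, :, :groups, :].reshape(sq, b, nq, hn)
key   = mixed[:, :, :, groups, :]          # [sq, b, nkv, hn]
value = mixed[:, :, :, groups + 1, :]      # [sq, b, nkv, hn]
print(query.shape, key.shape, value.shape)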
@@ -786,43 +742,90 @@ def forward(self, hidden_states, attention_mask,
         # otherwise, only relative positional embedding takes effect
         # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)
 
-        if self.enable_ds_sequence_parallel:
-            if self.use_flash_attn:
-                if not self.use_flash_attn_triton:
-                    query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
-                                                           for x in (query_layer, key_layer, value_layer)]
-
-                context_layer = self.dist_attn(query_layer, key_layer, value_layer)
-
-                if not self.use_flash_attn_triton:
-                    context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
-            else:
-                context_layer = self.dist_attn(query_layer, key_layer, value_layer, attention_mask)
-        else:
-            if self.use_flash_attn:
-                if not self.use_flash_attn_triton:
-                    query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
-                                                           for x in (query_layer, key_layer, value_layer)]
-
-                if self.sequence_parallel:
-                    context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
-                else:
-                    with tensor_parallel.get_cuda_rng_tracker().fork():
-                        context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
-
-                if not self.use_flash_attn_triton:
-                    context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
-            else:
-                if self.checkpoint_core_attention:
-                    context_layer = self._checkpointed_attention_forward(
-                        query_layer, key_layer, value_layer, attention_mask)
-                else:
-                    context_layer = self.core_attention(
-                        query_layer, key_layer, value_layer, attention_mask)
-
-        # =================
-        # Output. [sq, b, h]
-        # =================
+        if self.use_flash_attn:
+            if not self.use_flash_attn_triton:
+                query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
+                                                       for x in (query_layer, key_layer, value_layer)]
+
+            if self.sequence_parallel:
+                context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
+            else:
+                with tensor_parallel.get_cuda_rng_tracker().fork():
+                    context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
+
+            if not self.use_flash_attn_triton:
+                context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
+        else:
+            if self.checkpoint_core_attention:
+                context_layer = self._checkpointed_attention_forward(
+                    query_layer, key_layer, value_layer, attention_mask)
+            else:
+                context_layer = self.core_attention(
+                    query_layer, key_layer, value_layer, attention_mask)
+
+        return context_layer
+
+    def forward(self, hidden_states,
+                attention_mask,
+                encoder_output=None,
+                inference_params=None,
+                rotary_pos_emb=None):
+        # hidden_states: [sq, b, h]
+
+        # =================================================
+        # Pre-allocate memory for key-values for inference.
+        # =================================================
+        is_first_step = False
+        if inference_params:
+            if self.layer_number not in inference_params.key_value_memory_dict:
+                inf_max_seq_len = inference_params.max_sequence_len
+                inf_max_batch_size = inference_params.max_batch_size
+                inference_key_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
+                inference_value_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
+                inference_params.key_value_memory_dict[self.layer_number] = (
+                    inference_key_memory, inference_value_memory)
+                is_first_step = True
+            else:
+                inference_key_memory, inference_value_memory = \
+                    inference_params.key_value_memory_dict[self.layer_number]
+
+        # =====================
+        # Query, Key, and Value
+        # =====================
+
+        if self.attention_type == AttnType.self_attn:
+            # Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)]
+            mixed_x_layer, _ = self.query_key_value(hidden_states)
+
+        else:
+            assert not self.use_gqa, 'GQA + cross-attn not tested yet'
+
+            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
+            mixed_kv_layer, _ = self.key_value(encoder_output)
+
+            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
+            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
+                (self.num_attention_heads_per_partition,
+                 2 * self.hidden_size_per_attention_head)
+            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
+
+            # Attention head [sq, b, h] --> [sq, b, hp]
+            query_layer, _ = self.query(hidden_states)
+            # [sq, b, hp] --> [sq, b, np, hn]
+            new_tensor_shape = query_layer.size()[:-1] + \
+                (self.num_attention_heads_per_partition,
+                 self.hidden_size_per_attention_head)
+            query_layer = query_layer.view(*new_tensor_shape)
+
+            mixed_x_layer = torch.cat((query_layer, mixed_kv_layer), dim=-1).reshape(query_layer.size()[:-1] + (-1,))
+
+        context_layer = self.compute_attn(mixed_x_layer,
+                                          attention_mask,
+                                          inference_params=inference_params,
+                                          rotary_pos_emb=rotary_pos_emb)
 
         output, bias = self.dense(context_layer)
 
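End to end, the restructured path is: fused QKV projection in `forward`, then everything else inside `compute_attn`, so the sequence-parallel wrapper installed in `__init__` intercepts exactly one packed tensor. A minimal runnable sketch of that control flow with toy stand-in modules (`TinyAttention` and its sizes are hypothetical, not classes from this repo):

import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    def __init__(self, h=64, heads=4):
        super().__init__()
        self.heads, self.hn = heads, h // heads
        self.query_key_value = nn.Linear(h, 3 * h)   # fused QKV projection
        self.dense = nn.Linear(h, h)                 # output projection
        # sequence parallelism would rebind this attribute to a wrapped version
        self.compute_attn = self._compute_attn

    def _compute_attn(self, mixed_x_layer, attention_mask=None):
        sq, b, _ = mixed_x_layer.shape
        qkv = mixed_x_layer.view(sq, b, self.heads, 3, self.hn)
        q, k, v = qkv[..., 0, :], qkv[..., 1, :], qkv[..., 2, :]
        scores = torch.einsum('sbnh,tbnh->bnst', q, k) / self.hn ** 0.5
        ctx = torch.einsum('bnst,tbnh->sbnh', scores.softmax(dim=-1), v)
        return ctx.reshape(sq, b, -1)

    def forward(self, hidden_states, attention_mask=None):
        # forward only projects; everything else lives behind compute_attn,
        # which is the single interception point for the SP wrapper
        mixed_x_layer = self.query_key_value(hidden_states)   # [sq, b, 3h]
        context_layer = self.compute_attn(mixed_x_layer, attention_mask)
        return self.dense(context_layer)

print(TinyAttention()(torch.randn(8, 2, 64)).shape)   # torch.Size([8, 2, 64])

The diff's `__init__` uses the same rebinding: with sequence parallelism active, `self.compute_attn` becomes a lambda routing the packed tensor through `self.compute_attn_sp`, so the all-to-all covers queries, keys, and values together, where the old code passed `query_layer`, `key_layer`, and `value_layer` to `self.dist_attn` as three separate tensors.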