attempt at pipelining #78

Open
wants to merge 2 commits into base: main_before_rebase
29 changes: 23 additions & 6 deletions megatron/model/gpt_model.py
@@ -176,15 +176,17 @@ def load_state_dict(self, state_dict, strict=True):
self.language_model.load_state_dict(state_dict, strict=strict)


def CrossEntropy(output, labels):
def CrossEntropy(outputs, labels):
args = get_args()
labels, loss_mask = labels[0], labels[1]

output, moe_loss = outputs[0], outputs[1]
args = get_args()

losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels)
loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
return loss
print(loss, moe_loss*args.moe_loss_coeff)
return loss + moe_loss * args.moe_loss_coeff
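
Note (not part of the diff): the reworked loss adds the MoE auxiliary load-balancing loss, scaled by `args.moe_loss_coeff`, to the masked LM loss. A minimal standalone sketch of that combination, with the model-parallel `mpu.vocab_parallel_cross_entropy` replaced by a precomputed per-token loss tensor:

```python
# Sketch only: combine a masked LM loss with a scaled MoE auxiliary loss,
# mirroring the CrossEntropy change above. Plain torch, no model parallelism.
import torch

def combined_loss(per_token_losses, loss_mask, moe_loss, moe_loss_coeff=0.01):
    # per_token_losses, loss_mask: [batch, seq]; moe_loss: scalar gate loss.
    # moe_loss_coeff stands in for args.moe_loss_coeff.
    loss_mask = loss_mask.view(-1).float()
    lm_loss = torch.sum(per_token_losses.view(-1) * loss_mask) / loss_mask.sum()
    return lm_loss + moe_loss * moe_loss_coeff

# Toy usage with made-up numbers.
per_token = torch.rand(2, 8)
mask = torch.ones(2, 8)
print(combined_loss(per_token, mask, torch.tensor(0.5)))
```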


class GPTModelPipe(PipelineModule,MegatronModule):
@@ -226,15 +228,30 @@ def _to_float16(inputs):
else:
self.specs.append(lambda x: x.transpose(0, 1).contiguous())

num_experts = args.num_experts
assert len(num_experts) == 1 or len(num_experts) == self.num_layers // args.expert_interval, \
'num_experts must be either a single value or a list of the same length as the number of MoE layers'

# Create the list of MoE experts
if len(num_experts) == 1:
num_experts = num_experts * (args.num_layers // args.expert_interval)

for layer_idx in range(args.num_layers):
layer_num = layer_idx + 1
if layer_num % args.expert_interval == 0:
n_e = num_experts[(layer_num-1) // args.expert_interval]
else:
n_e = 1
print(f"Layer number {layer_idx} and Number of experts {n_e}")
self.specs.append(
LayerSpec(ParallelTransformerLayerPipe,
init_method=init_method,
output_layer_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers),
layer_number=layer_idx,
self_attn_mask_type=AttnMaskType.causal))

layer_number=layer_num,
self_attn_mask_type=AttnMaskType.causal,
num_experts=n_e),
)

# Undo data format change
self.specs.append(lambda x: x.transpose(0, 1).contiguous())
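
Note (not part of the diff): the loop above gives every `expert_interval`-th layer an entry from `args.num_experts` (a single value is broadcast to all MoE layers) and leaves the remaining layers dense with `n_e = 1`. A small sketch of that schedule under the same assumptions:

```python
# Sketch only: reproduce the per-layer expert count the loop above computes.
# num_experts has either one entry or one entry per MoE layer.
def experts_per_layer(num_experts, num_layers, expert_interval):
    assert len(num_experts) in (1, num_layers // expert_interval)
    if len(num_experts) == 1:
        num_experts = num_experts * (num_layers // expert_interval)
    schedule = []
    for layer_num in range(1, num_layers + 1):
        if layer_num % expert_interval == 0:
            schedule.append(num_experts[(layer_num - 1) // expert_interval])
        else:
            schedule.append(1)  # dense MLP layer
    return schedule

# e.g. 8 layers, an MoE layer every 2 layers, 64 experts each:
print(experts_per_layer([64], 8, 2))  # [1, 64, 1, 64, 1, 64, 1, 64]
```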
17 changes: 10 additions & 7 deletions megatron/model/transformer.py
@@ -443,7 +443,6 @@ def __init__(self, init_method, output_layer_init_method,
eps=args.layernorm_epsilon)

self.num_experts = num_experts
# MLP
if self.num_experts <= 1:
self.mlp = ParallelMLP(init_method,
output_layer_init_method)
@@ -452,14 +451,15 @@
moe_mp_size = 1
else:
moe_mp_size = dist.get_world_size() // self.num_experts

ep_size = min(args.moe_expert_parallel_size, mpu.get_data_parallel_world_size())

self.mlp = MoE(args.hidden_size,
ParallelMLP(init_method,
output_layer_init_method=output_layer_init_method,
MOE=True,
MoE_mp_size=moe_mp_size),
num_experts=self.num_experts,
ep_size=args.moe_expert_parallel_size,
ep_size=ep_size,
k=args.topk,
use_residual=(args.mlp_type == 'residual'),
capacity_factor=args.moe_train_capacity_factor,
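
Note (not part of the diff): `ep_size` is now clamped to the data-parallel world size before being handed to the MoE layer, presumably because the expert-parallel group is formed inside the data-parallel group. A tiny sketch of that clamp, with illustrative numbers in place of `mpu.get_data_parallel_world_size()`:

```python
# Sketch only: the expert-parallel size clamp introduced above.
def clamp_ep_size(requested_ep_size, data_parallel_world_size):
    # Expert parallelism cannot span more ranks than the data-parallel group.
    return min(requested_ep_size, data_parallel_world_size)

print(clamp_ep_size(8, 4))   # -> 4: only 4 data-parallel ranks available
print(clamp_ep_size(2, 16))  # -> 2: the requested size already fits
```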
@@ -543,7 +543,6 @@ def forward(self, hidden_states, attention_mask,
mlp_output, mlp_bias = self.mlp(layernorm_output)
else:
mlp_output, moe_loss, _ = self.mlp(layernorm_output)

# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
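
Note (not part of the diff): the forward pass has to cope with two return shapes, since a dense `ParallelMLP` returns `(output, bias)` while the MoE wrapper returns `(output, moe_loss, _)`. A minimal sketch of normalising the two cases (the helper name is made up):

```python
# Sketch only: normalise the dense vs. MoE MLP return values, as the forward
# pass above does with its num_experts branch.
import torch

def unpack_mlp_output(mlp_out):
    if len(mlp_out) == 2:                 # dense ParallelMLP: (output, bias)
        mlp_output, mlp_bias = mlp_out
        moe_loss = torch.tensor(0.0)      # dense layers contribute no aux loss
    else:                                 # MoE wrapper: (output, moe_loss, _)
        mlp_output, moe_loss, _ = mlp_out
        mlp_bias = None
    return mlp_output, mlp_bias, moe_loss
```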
@@ -595,13 +594,17 @@ def forward(self, inputs, **kwargs):
hidden_states, attention_mask = inputs, self._args.attn_mask
# HACK: currently MoE model does not support pipeline parallel, so
# here we just ignore the moe_loss returned by forward()
return super().forward(hidden_states, attention_mask, **kwargs)[0]
out, moe_loss = super().forward(hidden_states, attention_mask, **kwargs)
self.stashed_moe_loss = moe_loss * get_args().moe_loss_coeff
return out, moe_loss
elif len(inputs) == 2:
# Attention mask is an activation.
hidden_states, attention_mask = inputs[0], inputs[1]
hidden_states, curr_moe_loss, attention_mask = inputs[0], inputs[1], self._args.attn_mask
# HACK: currently MoE model does not support pipeline parallel, so
# here we just ignore the moe_loss returned by forward()
return super().forward(*inputs, **kwargs)[0], attention_mask
out, moe_loss = super().forward(hidden_states, attention_mask, **kwargs)

return out, moe_loss + curr_moe_loss
else:
raise RuntimeError('Received more inputs than understood.')
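
Note (not part of the diff): the core pipelining change is that each `ParallelTransformerLayerPipe` now forwards its accumulated MoE loss as an extra activation instead of dropping it, so the final `CrossEntropy` can fold the sum into the LM loss. A self-contained sketch of that accumulation pattern (toy layer names, not the Megatron API):

```python
# Sketch only: thread a running MoE auxiliary loss through a stack of layers
# as an extra activation, as the forward() above does across pipeline stages.
import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, inputs):
        hidden_states, running_moe_loss = inputs
        out = torch.relu(self.proj(hidden_states))
        layer_moe_loss = torch.tensor(0.01)  # stand-in for the gate's load-balancing loss
        # Return the running sum instead of discarding it.
        return out, running_moe_loss + layer_moe_loss

stack = nn.Sequential(ToyLayer(8), ToyLayer(8), ToyLayer(8))
hidden, total_moe_loss = stack((torch.randn(2, 8), torch.tensor(0.0)))
print(hidden.shape, total_moe_loss)  # torch.Size([2, 8]) tensor(0.0300)
```

In the diff itself the running loss rides alongside the hidden states, each layer returns `(out, moe_loss + curr_moe_loss)`, and `CrossEntropy` finally scales the sum by `args.moe_loss_coeff`.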
