Releases: awslabs/slapo

v0.0.3

23 Mar 19:05
1e57052

This release mainly includes:

  1. Fixes for several fidelity issues.
  2. Refactored schedule primitives, including the new .fork_rng(), .annotate(), and .replace_all() primitives.
  3. Other bug fixes.

If any of the following cases apply to an existing schedule written for v0.0.2, you must update it as described below to support v0.0.3.

  1. Tagging parameters so that the DeepSpeed pipeline runtime performs an additional all-reduce over the TP group. For example, you may have the following code snippet that tags LayerNorm parameters:
import torch.nn as nn

def tag_layernorm(sch):
    # Tag every LayerNorm parameter so the DeepSpeed pipeline runtime
    # performs an additional all-reduce over the TP group for it.
    for m in sch.mod.modules():
        if isinstance(m, nn.LayerNorm):
            for p in m.parameters(recurse=False):
                p.replicated_param = True

This can be changed to the following in v0.0.3:

import torch.nn as nn
from slapo.op import LinearWithSyncFunc  # import path assumed; see the reference below

def annotate_layernorm_and_bias(sch):
    # Recursively walk the schedule tree and annotate the parameters that
    # the DeepSpeed pipeline runtime should all-reduce over the TP group.
    for sub_sch in sch.child.values():
        if isinstance(sub_sch.mod, nn.LayerNorm):
            for name, _ in sub_sch.mod.named_parameters(recurse=False):
                sub_sch.annotate(name, "replicated_param", True)
        if issubclass(sub_sch.mod.__class__, LinearWithSyncFunc):
            sub_sch.annotate("bias", "replicated_param", True)
        annotate_layernorm_and_bias(sub_sch)

Reference: https://github.com/awslabs/slapo/blob/main/slapo/model_schedule/gpt2.py#L529
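
A minimal usage sketch (model stands for your own nn.Module; slapo.create_schedule is the standard entry point for building a root schedule):

import slapo

sch = slapo.create_schedule(model)
annotate_layernorm_and_bias(sch)  # the annotations are consumed by the DeepSpeed pipeline runtime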

  2. RNG control can now be done easily with the newly introduced schedule primitive .fork_rng(); accordingly, the old slapo.op.AttentionOpWithRNG has been removed. If you have the following code snippet:
new_op = AttentionOpWithRNG(
    sub_sch["module"]["attn_op"].mod.attn_op_name,
    sub_sch["module"]["attn_op"].mod.apply_causal_mask,
    sub_sch["module"]["attn_op"].mod.scale,
)
sub_sch["module"]["attn_op"].replace(new_op)

it has to be changed to:

sub_sch["module"]["attn_op"].fork_rng()
  3. The primitive .trace_for_pipeline() has been renamed to .trace_until(). Since the arguments remain the same, you can simply replace all occurrences, as in the sketch below.
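
For illustration, a minimal before/after sketch (the traced module path "bert.encoder" is hypothetical):

# v0.0.2
sch.trace_for_pipeline("bert.encoder")
# v0.0.3: same arguments, new name
sch.trace_until("bert.encoder")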

  4. If you use slapo.op.FusedMLP with sharding, you need to update your schedule to reflect the changed FusedMLP implementation. For example:

fc_names = ["fc_in", "act", "fc_out"]
sub_sch[fc_names[0]].shard("weight", axis=0)
sub_sch[fc_names[1]].shard("bias", axis=0)
sub_sch[fc_names[2]].shard("weight", axis=1)
sub_sch[fc_names[0]].sync(mode="bwd_post", sync_op_or_fn="all_reduce")
sub_sch[fc_names[2]].sync(mode="fwd_post", sync_op_or_fn="all_reduce")

changes to

fc_names = ["fc_in", "fc_out"]
sub_sch[fc_names[0]].shard("weight", axis=0)
sub_sch[fc_names[0]].shard("bias", axis=0)
sub_sch[fc_names[1]].shard("weight", axis=1)
sub_sch[fc_names[0]].sync(mode="bwd_post", sync_op_or_fn="all_reduce")
sub_sch[fc_names[1]].sync(mode="fwd_post", sync_op_or_fn="all_reduce")

Full Changelog: v0.0.2...v0.0.3

v0.0.2

20 Feb 20:21
5ae9047

This release mainly includes:

  1. More unit tests.
  2. The new .fuse and related primitives.
  3. Improved overall training efficiency of GPT models through sequence parallelism, tied-weight support, etc.
  4. Documentation and tutorials.
  5. Bug fixes.

Full Changelog: v0.0.1...v0.0.2

First release of v0.0.1

25 Jan 18:40
36c0f46

Full Changelog: https://github.com/awslabs/slapo/commits/v0.0.1