Add parallelism as a parameter
wujingyue committed Dec 9, 2024
1 parent 68bff16 commit 6a84d85
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions tests/python/test_transformer_engine.py
@@ -22,6 +22,11 @@ class ComputeType(Enum):
     BACKWARD = auto()
 
 
+class Parallelism(Enum):
+    TENSOR_PARALLEL = auto()
+    SEQUENCE_PARALLEL = auto()
+
+
 @pytest.fixture(scope="module")
 def setup_process_group(mpi_test) -> None:
     # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51.
@@ -47,7 +52,12 @@ def setup_process_group(mpi_test) -> None:
     [ComputeType.FORWARD, ComputeType.BACKWARD],
     ids=["forward", "backward"],
 )
-def test_transformer_layer(setup_process_group, benchmark, compute_type):
+@pytest.mark.parametrize(
+    "parallelism",
+    [Parallelism.TENSOR_PARALLEL, Parallelism.SEQUENCE_PARALLEL],
+    ids=["tp", "sp"],
+)
+def test_transformer_layer(setup_process_group, benchmark, compute_type, parallelism):
     # Hyperparameters for GPT-3
     hidden_size = 12288
     num_heads = 96
@@ -69,13 +79,20 @@ def test_transformer_layer(setup_process_group, benchmark, compute_type):
         # benchmark fails to execute on H100 with the default format (SBHD).
         attn_input_format="bshd",
         set_parallel_mode=True,
-        sequence_parallel=True,
+        sequence_parallel=(parallelism == Parallelism.SEQUENCE_PARALLEL),
         tp_group=dist.group.WORLD,
     )
     transformer_layer.to(dtype).to("cuda")
 
+    match parallelism:
+        case Parallelism.TENSOR_PARALLEL:
+            local_sequence_length = sequence_length
+        case Parallelism.SEQUENCE_PARALLEL:
+            assert sequence_length % size == 0
+            local_sequence_length = sequence_length // size
+
     x = torch.randn(
-        batch_size, sequence_length // size, hidden_size, dtype=dtype, device="cuda"
+        batch_size, local_sequence_length, hidden_size, dtype=dtype, device="cuda"
     )
 
     match compute_type:
@@ -95,7 +112,7 @@ def benchmark_fn(profile):
     # Warmup.
     y = benchmark_fn(False)
     assert y.size() == torch.Size(
-        [batch_size, sequence_length // size, hidden_size]
+        [batch_size, local_sequence_length, hidden_size]
     )
 
     benchmark.pedantic(benchmark_fn, args=(True,), rounds=5)
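
For context (not part of the commit): the test now collects as a 2x2 matrix, {forward, backward} x {tp, sp}, and the match statement above is the behavioral core of the change. Under tensor parallelism every rank consumes the full sequence, while under sequence parallelism the input is pre-sharded along the sequence dimension. Below is a minimal standalone sketch of that shape logic; hidden_size comes from the diff, but batch_size, sequence_length, and the world size are assumed values, since they sit outside the visible hunks:

    # Sketch of the per-rank input shape computed by the match statement.
    hidden_size = 12288     # from the diff (GPT-3)
    batch_size = 2          # assumed; defined outside the visible hunks
    sequence_length = 2048  # assumed; defined outside the visible hunks
    size = 8                # assumed world size

    # Tensor parallel: every rank holds the full sequence.
    tp_shape = (batch_size, sequence_length, hidden_size)

    # Sequence parallel: the sequence dimension is sharded evenly across ranks.
    assert sequence_length % size == 0
    sp_shape = (batch_size, sequence_length // size, hidden_size)

    print(tp_shape)  # (2, 2048, 12288)
    print(sp_shape)  # (2, 256, 12288)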
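
Assuming one pytest process per rank under an MPI launcher (the mpi_test fixture suggests this, though the exact invocation is not shown in the commit), a single variant can be selected by keyword, e.g.:

    mpirun -np 8 pytest tests/python/test_transformer_engine.py -k "forward and sp"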
