diff --git a/tests/python/test_transformer_engine.py b/tests/python/test_transformer_engine.py
index de4734e6c90..00eb4b9eeb7 100644
--- a/tests/python/test_transformer_engine.py
+++ b/tests/python/test_transformer_engine.py
@@ -22,6 +22,13 @@ class ComputeType(Enum):
     BACKWARD = auto()
 
 
+class Parallelism(Enum):
+    # https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#tensor-parallelism
+    TENSOR_PARALLEL = auto()
+    # https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#sequence-parallelism
+    SEQUENCE_PARALLEL = auto()
+
+
 @pytest.fixture(scope="module")
 def setup_process_group(mpi_test) -> None:
     # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51.
@@ -47,7 +54,12 @@ def setup_process_group(mpi_test) -> None:
     [ComputeType.FORWARD, ComputeType.BACKWARD],
     ids=["forward", "backward"],
 )
-def test_transformer_layer(setup_process_group, benchmark, compute_type):
+@pytest.mark.parametrize(
+    "parallelism",
+    [Parallelism.TENSOR_PARALLEL, Parallelism.SEQUENCE_PARALLEL],
+    ids=["tp", "sp"],
+)
+def test_transformer_layer(setup_process_group, benchmark, compute_type, parallelism):
     # Hyperparameters for GPT-3
     hidden_size = 12288
     num_heads = 96
@@ -69,12 +81,20 @@ def test_transformer_layer(setup_process_group, benchmark, compute_type):
         # benchmark fails to execute on H100 with the default format (SBHD).
         attn_input_format="bshd",
         set_parallel_mode=True,
+        sequence_parallel=(parallelism == Parallelism.SEQUENCE_PARALLEL),
         tp_group=dist.group.WORLD,
     )
     transformer_layer.to(dtype).to("cuda")
 
+    match parallelism:
+        case Parallelism.TENSOR_PARALLEL:
+            local_sequence_length = sequence_length
+        case Parallelism.SEQUENCE_PARALLEL:
+            assert sequence_length % size == 0
+            local_sequence_length = sequence_length // size
+
     x = torch.randn(
-        batch_size, sequence_length, hidden_size, dtype=dtype, device="cuda"
+        batch_size, local_sequence_length, hidden_size, dtype=dtype, device="cuda"
     )
 
     match compute_type:
@@ -93,7 +113,9 @@ def benchmark_fn(profile):
 
             # Warmup.
             y = benchmark_fn(False)
-            assert y.size() == torch.Size([batch_size, sequence_length, hidden_size])
+            assert y.size() == torch.Size(
+                [batch_size, local_sequence_length, hidden_size]
+            )
 
             benchmark.pedantic(benchmark_fn, args=(True,), rounds=5)
         case ComputeType.BACKWARD:
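
Note for reviewers (not part of the patch): under tensor parallelism every rank processes the full sequence, whereas sequence parallelism shards the sequence dimension across ranks, which is why the test divides sequence_length by the world size and asserts divisibility first. Below is a minimal sketch of that sharding; `size` and `rank` are hypothetical stand-ins for the world size and this process's rank, and the shapes are shrunk for clarity.

# A minimal sketch (not part of the patch) of the sequence sharding that
# sequence parallelism implies. `size` and `rank` are hypothetical stand-ins
# for dist.get_world_size() and dist.get_rank().
import torch

batch_size, sequence_length, hidden_size = 2, 16, 8
size, rank = 4, 0  # hypothetical values for illustration

# Mirrors the patch: the sequence must divide evenly across ranks.
assert sequence_length % size == 0
local_sequence_length = sequence_length // size

# With sequence_parallel=True and attn_input_format="bshd", each rank feeds
# TransformerLayer only its own slice along dim 1 (the sequence dimension)
# and receives an output of the same local shape.
x_global = torch.randn(batch_size, sequence_length, hidden_size)
x_local = x_global.chunk(size, dim=1)[rank]
assert x_local.size() == torch.Size([batch_size, local_sequence_length, hidden_size])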