Benchmark sequence parallelism in test_transformer_engine (#3546)
```
$ nvidia-smi -L
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

$ mpirun -np 8 --output-filename /tmp/test_transformer_engine pytest tests/python/test_transformer_engine.py --only-mpi

$ cat /tmp/test_transformer_engine/1/rank.0/stdout

------------------------------------------------------------------------------------------ benchmark: 4 tests ------------------------------------------------------------------------------------------
Name (time in ms)                          Min                Max               Mean             StdDev            Median                IQR            Outliers       OPS            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_transformer_layer[sp-forward]      2.2564 (1.0)      55.7794 (11.73)    13.2931 (3.01)     23.7547 (125.77)   2.6707 (1.05)     14.1577 (88.73)         1;1   75.2268 (0.33)          5           1
test_transformer_layer[tp-forward]      2.3941 (1.06)     18.6497 (3.92)      6.7947 (1.54)      7.0469 (37.31)    2.5476 (1.0)       8.2456 (51.68)         1;0  147.1742 (0.65)          5           1
test_transformer_layer[tp-backward]     4.2568 (1.89)      4.8231 (1.01)      4.4578 (1.01)      0.2570 (1.36)     4.2940 (1.69)      0.4091 (2.56)          1;0  224.3258 (0.99)          5           1
test_transformer_layer[sp-backward]     4.3135 (1.91)      4.7558 (1.0)       4.4221 (1.0)       0.1889 (1.0)      4.3292 (1.70)      0.1596 (1.0)           1;1  226.1393 (1.0)           5           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```

Latency is neutral as expected: the sequence-parallel (sp) and tensor-parallel (tp) medians are within a few percent of each other for both forward and backward.
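
For context on what the `sp` variant changes: with tensor parallelism every rank sees the full sequence, while with sequence parallelism each rank holds only `sequence_length // world_size` tokens, which is why the test computes `local_sequence_length` and shrinks the input and expected output shapes. A minimal sketch of that sharding (`batch_size` and `sequence_length` are illustrative values, not taken from the test; only `hidden_size=12288` and the 8-GPU world size come from this page):

```python
import torch

# Illustrative values: batch_size and sequence_length are assumed here, not
# read from the test. hidden_size matches the GPT-3 hyperparameter in the
# diff, and world_size matches the 8-GPU mpirun above.
batch_size = 1
sequence_length = 2048
hidden_size = 12288
world_size = 8

# Tensor parallelism: every rank receives the full sequence.
tp_input = torch.empty(batch_size, sequence_length, hidden_size)

# Sequence parallelism: each rank receives its shard of the sequence,
# mirroring the local_sequence_length computed in the test.
assert sequence_length % world_size == 0
local_sequence_length = sequence_length // world_size
sp_input = torch.empty(batch_size, local_sequence_length, hidden_size)

print(tp_input.shape)  # torch.Size([1, 2048, 12288])
print(sp_input.shape)  # torch.Size([1, 256, 12288])
```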
wujingyue authored Dec 10, 2024
1 parent 89c47f6 commit 8c82f30
Showing 1 changed file with 25 additions and 3 deletions: tests/python/test_transformer_engine.py
```diff
--- a/tests/python/test_transformer_engine.py
+++ b/tests/python/test_transformer_engine.py
@@ -22,6 +22,13 @@ class ComputeType(Enum):
     BACKWARD = auto()
 
 
+class Parallelism(Enum):
+    # https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#tensor-parallelism
+    TENSOR_PARALLEL = auto()
+    # https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#sequence-parallelism
+    SEQUENCE_PARALLEL = auto()
+
+
 @pytest.fixture(scope="module")
 def setup_process_group(mpi_test) -> None:
     # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51.
@@ -47,7 +54,12 @@ def setup_process_group(mpi_test) -> None:
     [ComputeType.FORWARD, ComputeType.BACKWARD],
     ids=["forward", "backward"],
 )
-def test_transformer_layer(setup_process_group, benchmark, compute_type):
+@pytest.mark.parametrize(
+    "parallelism",
+    [Parallelism.TENSOR_PARALLEL, Parallelism.SEQUENCE_PARALLEL],
+    ids=["tp", "sp"],
+)
+def test_transformer_layer(setup_process_group, benchmark, compute_type, parallelism):
     # Hyperparameters for GPT-3
     hidden_size = 12288
     num_heads = 96
@@ -69,12 +81,20 @@ def test_transformer_layer(setup_process_group, benchmark, compute_type):
         # benchmark fails to execute on H100 with the default format (SBHD).
         attn_input_format="bshd",
         set_parallel_mode=True,
+        sequence_parallel=(parallelism == Parallelism.SEQUENCE_PARALLEL),
         tp_group=dist.group.WORLD,
     )
     transformer_layer.to(dtype).to("cuda")
 
+    match parallelism:
+        case Parallelism.TENSOR_PARALLEL:
+            local_sequence_length = sequence_length
+        case Parallelism.SEQUENCE_PARALLEL:
+            assert sequence_length % size == 0
+            local_sequence_length = sequence_length // size
+
     x = torch.randn(
-        batch_size, sequence_length, hidden_size, dtype=dtype, device="cuda"
+        batch_size, local_sequence_length, hidden_size, dtype=dtype, device="cuda"
     )
 
     match compute_type:
@@ -93,7 +113,9 @@ def benchmark_fn(profile):
 
             # Warmup.
             y = benchmark_fn(False)
-            assert y.size() == torch.Size([batch_size, sequence_length, hidden_size])
+            assert y.size() == torch.Size(
+                [batch_size, local_sequence_length, hidden_size]
+            )
 
             benchmark.pedantic(benchmark_fn, args=(True,), rounds=5)
         case ComputeType.BACKWARD:
```
