diff --git a/tests/python/test_transformer_engine.py b/tests/python/test_transformer_engine.py index 00eb4b9eeb7..c8487d5a64e 100644 --- a/tests/python/test_transformer_engine.py +++ b/tests/python/test_transformer_engine.py @@ -77,8 +77,9 @@ def test_transformer_layer(setup_process_group, benchmark, compute_type, paralle hidden_size, ffn_hidden_size, num_heads, - # https://github.com/NVIDIA/TransformerEngine/issues/1350: the - # benchmark fails to execute on H100 with the default format (SBHD). + # According to https://github.com/NVIDIA/TransformerEngine/issues/1350, + # `attn_input_format` has to match the format of `transformer_layer`'s + # input. attn_input_format="bshd", set_parallel_mode=True, sequence_parallel=(parallelism == Parallelism.SEQUENCE_PARALLEL),