From 6b5b7b81341a032e11691aae92b27d48c0390169 Mon Sep 17 00:00:00 2001 From: Meghan Cowan Date: Tue, 19 Nov 2024 18:23:23 -0800 Subject: [PATCH] Skip sequence parallel test when D=1 (#3451) Add gtest skip when for distributed transformer sequence parallel when number of devices=1. --- tests/cpp/test_multidevice_transformer.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp index 2f4bda3f77e..2ef33dcdf8f 100644 --- a/tests/cpp/test_multidevice_transformer.cpp +++ b/tests/cpp/test_multidevice_transformer.cpp @@ -886,6 +886,9 @@ TEST_P(DistributedTransformerTest, MultiheadAttention_SP) { GTEST_SKIP() << "Requires number of devices=" << D << " evenly divide H=" << H; } + if (D == 1) { + GTEST_SKIP() << "Requires >1 devices, D=" << D; + } auto dtype = GetParam(); auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -1125,6 +1128,9 @@ TEST_P(DistributedTransformerTest, Forward_SP) { GTEST_SKIP() << "Requires number of devices=" << D << " evenly divide H=" << H; } + if (D == 1) { + GTEST_SKIP() << "Requires >1 devices, D=" << D; + } auto dtype = GetParam(); at::ScalarType at_dtype = data_type_to_aten(dtype); auto fusion = std::make_unique();