diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp index 2f4bda3f77e..2ef33dcdf8f 100644 --- a/tests/cpp/test_multidevice_transformer.cpp +++ b/tests/cpp/test_multidevice_transformer.cpp @@ -886,6 +886,9 @@ TEST_P(DistributedTransformerTest, MultiheadAttention_SP) { GTEST_SKIP() << "Requires number of devices=" << D << " evenly divide H=" << H; } + if (D == 1) { + GTEST_SKIP() << "Requires >1 devices, D=" << D; + } auto dtype = GetParam(); auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -1125,6 +1128,9 @@ TEST_P(DistributedTransformerTest, Forward_SP) { GTEST_SKIP() << "Requires number of devices=" << D << " evenly divide H=" << H; } + if (D == 1) { + GTEST_SKIP() << "Requires >1 devices, D=" << D; + } auto dtype = GetParam(); at::ScalarType at_dtype = data_type_to_aten(dtype); auto fusion = std::make_unique();