diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp index 92d78d2357b..920b649f776 100644 --- a/tests/cpp/test_multidevice_transformer.cpp +++ b/tests/cpp/test_multidevice_transformer.cpp @@ -78,17 +78,21 @@ void validate( // BFloat16 error is quite high, but the program has been verified with // double precision to be logically correct. const double atol = atols[i]; + constexpr double kRtol = 1.6e-2; auto generate_comparison_details = [](at::Tensor expected_out, at::Tensor out, - double atol) -> std::string { + double atol, + double rtol) -> std::string { std::ostringstream oss; auto error = (out - expected_out).abs(); - auto max_absolute_error = error.max().item().to(); auto max_relative_error = - max_absolute_error / expected_out.abs().max().item().to(); - auto error_count = at::sum(error >= atol).item(); - indent(oss, 1) << "max absolute error: " << max_absolute_error - << std::endl; + (error.max() / expected_out.abs().max()).item().to(); + auto error_count = + at::sum(error >= atol + expected_out.abs() * rtol).item(); + indent(oss, 1) + << "max absolute error under rtol: " + << (error - expected_out.abs() * rtol).max().item().to() + << std::endl; indent(oss, 1) << "max relative error: " << max_relative_error << std::endl; indent(oss, 1) << "failing elements: " << error_count << ", " @@ -97,10 +101,11 @@ void validate( return oss.str(); }; - EXPECT_TRUE(outputs[i].allclose(expected_outputs[i], /*rtol=*/0.0, atol)) + EXPECT_TRUE(outputs[i].allclose(expected_outputs[i], kRtol, atol)) << "Output " << i << " mismatches with atol " << atol << ":" << std::endl - << generate_comparison_details(expected_outputs[i], outputs[i], atol); + << generate_comparison_details( + expected_outputs[i], outputs[i], atol, kRtol); } } @@ -881,7 +886,7 @@ TEST_P(DistributedTransformerTest, MHA_Backward) { validate( expected_outputs, out, - {1e-5, 0.02, 1e-5, 1e-5, 1e-5, 0.1, 0.1, 0.02, 0.01}); + {1e-5, 0.02, 1e-5, 1e-4, 1e-4, 0.1, 0.1, 0.1, 0.01}); } TEST_P(DistributedTransformerTest, Forward) {