From ee840c43ee130e2bb0a847cb787240bf5279fa97 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Sun, 15 Sep 2024 22:58:19 -0700
Subject: [PATCH] Use default rtols for some other dtypes.

---
 tests/cpp/test_multidevice_transformer.cpp | 26 +++++++++++++++++-----
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp
index 920b649f776..b20dffb389a 100644
--- a/tests/cpp/test_multidevice_transformer.cpp
+++ b/tests/cpp/test_multidevice_transformer.cpp
@@ -74,11 +74,25 @@ void validate(
     ASSERT_EQ(outputs[i].dtype(), expected_outputs[i].dtype())
         << "Output " << i << " has a mismatching data type.";
 
-    // Note: Scaling tolerance up since the error accumulates across ops
-    // BFloat16 error is quite high, but the program has been verified with
-    // double precision to be logically correct.
     const double atol = atols[i];
-    constexpr double kRtol = 1.6e-2;
+    // These default rtols are copied from
+    // https://github.com/pytorch/pytorch/blob/951c21d6790334d57862e94a3f582ac724147a53/torch/testing/_comparison.py#L65-L73.
+    double rtol;
+    switch (outputs[i].scalar_type()) {
+      case at::kBFloat16:
+        rtol = 1.6e-2;
+        break;
+      case at::kHalf:
+        rtol = 1e-3;
+        break;
+      case at::kFloat:
+        rtol = 1.3e-6;
+        break;
+      default:
+        rtol = 0.0;
+        break;
+    }
+
     auto generate_comparison_details = [](at::Tensor expected_out,
                                           at::Tensor out,
                                           double atol,
@@ -101,11 +115,11 @@ void validate(
       return oss.str();
     };
 
-    EXPECT_TRUE(outputs[i].allclose(expected_outputs[i], kRtol, atol))
+    EXPECT_TRUE(outputs[i].allclose(expected_outputs[i], rtol, atol))
         << "Output " << i << " mismatches with atol " << atol << ":"
         << std::endl
         << generate_comparison_details(
-               expected_outputs[i], outputs[i], atol, kRtol);
+               expected_outputs[i], outputs[i], atol, rtol);
   }
 }
 