Skip to content

Commit

Permalink
Add back a default rtol.
Browse files: browse the repository at this point in the history
  • Loading branch information
wujingyue committed Sep 16, 2024
1 parent d00686b commit 402b40d
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions tests/cpp/test_multidevice_transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,21 @@ void validate(
// BFloat16 error is quite high, but the program has been verified with
// double precision to be logically correct.
const double atol = atols[i];
constexpr double kRtol = 1.6e-2;
auto generate_comparison_details = [](at::Tensor expected_out,
at::Tensor out,
double atol) -> std::string {
double atol,
double rtol) -> std::string {
std::ostringstream oss;
auto error = (out - expected_out).abs();
auto max_absolute_error = error.max().item().to<double>();
auto max_relative_error =
max_absolute_error / expected_out.abs().max().item().to<double>();
auto error_count = at::sum(error >= atol).item();
indent(oss, 1) << "max absolute error: " << max_absolute_error
<< std::endl;
(error.max() / expected_out.abs().max()).item().to<double>();
auto error_count =
at::sum(error >= atol + expected_out.abs() * rtol).item();
indent(oss, 1)
<< "max absolute error under rtol: "
<< (error - expected_out.abs() * rtol).max().item().to<double>()
<< std::endl;
indent(oss, 1) << "max relative error: " << max_relative_error
<< std::endl;
indent(oss, 1) << "failing elements: " << error_count << ", "
Expand All @@ -97,10 +101,11 @@ void validate(
return oss.str();
};

EXPECT_TRUE(outputs[i].allclose(expected_outputs[i], /*rtol=*/0.0, atol))
EXPECT_TRUE(outputs[i].allclose(expected_outputs[i], kRtol, atol))
<< "Output " << i << " mismatches with atol " << atol << ":"
<< std::endl
<< generate_comparison_details(expected_outputs[i], outputs[i], atol);
<< generate_comparison_details(
expected_outputs[i], outputs[i], atol, kRtol);
}
}

Expand Down Expand Up @@ -881,7 +886,7 @@ TEST_P(DistributedTransformerTest, MHA_Backward) {
validate(
expected_outputs,
out,
{1e-5, 0.02, 1e-5, 1e-5, 1e-5, 0.1, 0.1, 0.02, 0.01});
{1e-5, 0.02, 1e-5, 1e-4, 1e-4, 0.1, 0.1, 0.1, 0.01});
}

TEST_P(DistributedTransformerTest, Forward) {
Expand Down

0 comments on commit 402b40d

Please sign in to comment.