Skip to content

Commit

Permalink
Use default rtols for some other dtypes.
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue committed Sep 16, 2024
1 parent 402b40d commit ee840c4
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions tests/cpp/test_multidevice_transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,25 @@ void validate(
ASSERT_EQ(outputs[i].dtype(), expected_outputs[i].dtype())
<< "Output " << i << " has a mismatching data type.";

// Note: Scaling tolerance up since the error accumulates across ops
// BFloat16 error is quite high, but the program has been verified with
// double precision to be logically correct.
const double atol = atols[i];
constexpr double kRtol = 1.6e-2;
// These default rtols are copied from
// https://github.com/pytorch/pytorch/blob/951c21d6790334d57862e94a3f582ac724147a53/torch/testing/_comparison.py#L65-L73.
double rtol;
switch (outputs[i].scalar_type()) {
case at::kBFloat16:
rtol = 1.6e-2;
break;
case at::kHalf:
rtol = 1e-3;
break;
case at::kFloat:
rtol = 1.3e-6;
break;
default:
rtol = 0.0;
break;
}

auto generate_comparison_details = [](at::Tensor expected_out,
at::Tensor out,
double atol,
Expand All @@ -101,11 +115,11 @@ void validate(
return oss.str();
};

EXPECT_TRUE(outputs[i].allclose(expected_outputs[i], kRtol, atol))
EXPECT_TRUE(outputs[i].allclose(expected_outputs[i], rtol, atol))
<< "Output " << i << " mismatches with atol " << atol << ":"
<< std::endl
<< generate_comparison_details(
expected_outputs[i], outputs[i], atol, kRtol);
expected_outputs[i], outputs[i], atol, rtol);
}
}

Expand Down

0 comments on commit ee840c4

Please sign in to comment.