validate checks dtypes. (#2917)
This made validation stricter and indeed caught some precision
mismatches. Fixing those has probably brought the reference
implementations closer to the actual implementations.
wujingyue authored Sep 7, 2024
1 parent 872aeb1 commit 7c9876a
Showing 1 changed file with 38 additions and 33 deletions: tests/cpp/test_multidevice_transformer.cpp
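For context on what the stricter check buys: at::allclose refuses to compare tensors of different dtypes by throwing, so an upfront ASSERT_EQ on dtypes turns an opaque exception into a message naming the offending output. A minimal standalone sketch of the two failure modes (assuming a libtorch build; nothing below is code from the commit itself):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto expected = torch::randn({4}, torch::kFloat); // fp32 reference output
  auto actual = expected.to(torch::kBFloat16);      // bf16 "actual" output

  // Explicit check: fails with a message naming the mismatch.
  if (actual.dtype() != expected.dtype()) {
    std::cerr << "dtype mismatch: " << actual.dtype() << " vs "
              << expected.dtype() << std::endl;
  }

  // Implicit check: torch::allclose(actual, expected) would throw a
  // c10::Error here without saying which output was at fault.
  return 0;
}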
@@ -61,34 +61,39 @@ void validate(
     std::vector<at::Tensor> out) {
   EXPECT_EQ(expected_out.size(), out.size());
   for (auto i : c10::irange(out.size())) {
+    // allclose can catch this as well. However, it would throw an exception,
+    // not showing which output was problematic.
+    ASSERT_EQ(out[i].dtype(), expected_out[i].dtype())
+        << "Output " << i << " has a mismatching data type.";
+
     // Note: Scaling tolerance up since the error accumulates across ops
     // BFloat16 error is quite high, but the program has been verified with
     // double precision to be logically correct.
-    double atol = 0.075 * (i + 1);
-    double rtol = 1.6e-2;
-    auto all_close = out[i]
-                         .to(expected_out[i].dtype())
-                         .allclose(
-                             expected_out[i],
-                             rtol,
-                             atol,
-                             /*equal_nan=*/true);
-
-    if (!all_close) {
-      auto error = (out[i].to(expected_out[i].dtype()) - expected_out[i]).abs();
+    const double atol = 0.075 * (i + 1);
+    const double rtol = 1.6e-2;
+    auto generate_comparison_details = [](at::Tensor out,
+                                          at::Tensor expected_out,
+                                          double atol,
+                                          double rtol) -> std::string {
+      std::ostringstream oss;
+      auto error = (out - expected_out).abs();
       auto max_error = error.max().item().to<double>();
       auto max_relative_error =
-          (max_error / expected_out[i].abs().max()).item();
+          max_error / expected_out.abs().max().item().to<double>();
       auto error_count =
-          at::sum(error >= (atol + expected_out[i].abs() * rtol)).item();
-      std::cout << "output[" << i << "] max error: " << max_error << std::endl;
-      std::cout << "  max relative error: " << max_relative_error
-                << std::endl;
-      std::cout << "  failing elements: " << error_count << ", "
-                << error_count.to<float>() / at::numel(out[i]) * 100.0
-                << "\% of tensor" << std::endl;
-    }
-    EXPECT_TRUE(all_close);
+          at::sum(error >= (atol + expected_out.abs() * rtol)).item();
+      indent(oss, 1) << "max error: " << max_error << std::endl;
+      indent(oss, 1) << "max relative error: " << max_relative_error
+                     << std::endl;
+      indent(oss, 1) << "failing elements: " << error_count << ", "
+                     << error_count.to<float>() / at::numel(out) * 100.0
+                     << "\% of tensor";
+      return oss.str();
+    };
+
+    EXPECT_TRUE(out[i].allclose(expected_out[i], rtol, atol))
+        << "Output " << i << " mismatches:" << std::endl
+        << generate_comparison_details(out[i], expected_out[i], atol, rtol);
   }
 }
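For reference, the failing-elements count in generate_comparison_details is the elementwise negation of allclose's criterion |out − expected| ≤ atol + rtol·|expected|, with atol = 0.075·(i + 1) growing per output (0.075, 0.15, 0.225, ...) to absorb error that accumulates across ops. A rough standalone sketch of the same bookkeeping (libtorch assumed; count_failing is an illustrative name, not from the test):

#include <torch/torch.h>
#include <iostream>

// Illustrative re-implementation of the counting logic: an element "fails"
// when |out - expected| >= atol + rtol * |expected|, i.e. the complement of
// the allclose criterion.
int64_t count_failing(
    const torch::Tensor& out,
    const torch::Tensor& expected,
    double atol,
    double rtol) {
  auto error = (out - expected).abs();
  return (error >= (atol + expected.abs() * rtol)).sum().item<int64_t>();
}

int main() {
  auto expected = torch::randn({1024});
  auto out = expected + 0.05 * torch::randn({1024});
  const double rtol = 1.6e-2;
  // atol scales with the output index, mirroring atol = 0.075 * (i + 1):
  // 0.075 for output 0, 0.15 for output 1, 0.225 for output 2, ...
  for (int i = 0; i < 3; ++i) {
    const double atol = 0.075 * (i + 1);
    std::cout << "i=" << i
              << " failing=" << count_failing(out, expected, atol, rtol)
              << " allclose=" << torch::allclose(out, expected, rtol, atol)
              << std::endl;
  }
  return 0;
}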

@@ -100,8 +105,8 @@ std::vector<at::Tensor> reference_mlp(
     at::Tensor b1) {
   auto at_dtype = w0.dtype();
   auto linear0 = at::matmul(x, w0).to(at::kFloat) + b0;
-  auto gelu = at::gelu(linear0, "tanh");
-  auto linear1 = at::matmul(gelu.to(at_dtype), w1).to(at::kFloat) + b1;
+  auto gelu = at::gelu(linear0, "tanh").to(at_dtype);
+  auto linear1 = at::matmul(gelu, w1).to(at::kFloat) + b1;
   auto dropout = at::dropout(linear1, kDropoutProb, true);
   return {linear0, gelu, linear1, dropout};
 }
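The reference_mlp change follows a common mixed-precision convention: matmuls run in the parameter dtype (e.g. bfloat16), bias adds accumulate in fp32, and the intermediate activation is cast back to the parameter dtype before feeding the next matmul — which also makes the returned gelu's dtype match what the fused kernel produces, satisfying the new dtype assertion. A hedged sketch of that pattern in isolation (libtorch assumed; linear_fp32_accum is a made-up helper name):

#include <torch/torch.h>

// Sketch of the convention the reference now mirrors: multiply in the
// low-precision parameter dtype, then upcast before the bias add so the
// accumulation happens in fp32.
torch::Tensor linear_fp32_accum(
    const torch::Tensor& x,   // low-precision input
    const torch::Tensor& w,   // low-precision weight
    const torch::Tensor& b) { // fp32 bias
  return torch::matmul(x, w).to(torch::kFloat) + b;
}

int main() {
  auto x = torch::randn({8, 16}, torch::kBFloat16);
  auto w0 = torch::randn({16, 32}, torch::kBFloat16);
  auto b0 = torch::randn({32}, torch::kFloat);
  auto linear0 = linear_fp32_accum(x, w0, b0); // fp32
  // The fix above: cast the activation back to bf16 so its dtype matches
  // the fused kernel's output, then feed the next matmul in bf16.
  auto gelu = torch::gelu(linear0, "tanh").to(torch::kBFloat16);
  return 0;
}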
@@ -113,7 +118,7 @@ std::vector<at::Tensor> reference_mha(
     at::Tensor w1,
     at::Tensor b1) {
   auto at_dtype = w0.dtype();
-  auto m = at::matmul(x, w0).add(b0).view({B, S, 3 * E});
+  auto m = (at::matmul(x, w0).to(at::kFloat) + b0).view({B, S, 3 * E});
   auto qkv_vec = m.split(E, 2);
   for (auto i = 0; i < 3; i++) {
     qkv_vec[i] =
@@ -124,7 +129,7 @@
   auto sdpa = std::get<0>(sdpa_out);
   // Reassemble heads (B, H, S, E/H) to (B, S, H, E/H) to (B, S, E)
   auto y = sdpa.transpose(1, 2).reshape({B * S, E});
-  auto y_proj = at::matmul(y, w1).add(b1);
+  auto y_proj = at::matmul(y, w1).to(at::kFloat) + b1;
   auto y_dropout = at::dropout(y_proj.to(at::kFloat), kDropoutProb, true);
   return {m, sdpa, y_proj, y_dropout};
 }
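The reference_mha hunks apply the same fix as reference_mlp: the matmul result is explicitly upcast to fp32 before the bias add, rather than relying on .add(b0)/.add(b1). One side effect worth noting from the context lines: y_proj is now already fp32, so the y_proj.to(at::kFloat) in the dropout line becomes a retained but redundant cast.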
@@ -649,7 +654,7 @@ TEST_P(DistributedTransformerTest, MLP_Layer) {
   validate(expected_outputs, outputs);
 }
 
-TEST_P(DistributedTransformerTest, Multiheaded_Attention) {
+TEST_P(DistributedTransformerTest, MultiheadAttention) {
   auto dtype = GetParam();
   auto fusion = std::make_unique<Fusion>();
   FusionGuard fg(fusion.get());
@@ -942,21 +947,21 @@ TEST_P(DistributedTransformerTest, Forward) {
   at::manual_seed(getATenRandomSeed());
   auto ln_1_ = at::native_layer_norm(
       x_, norm_shape, /*weight=*/std::nullopt, /*bias=*/std::nullopt, kEps);
-  auto ln_1_out_ = std::get<0>(ln_1_).to(at_dtype);
+  auto ln_1_out_ = std::get<0>(ln_1_);
 
-  auto mha_out_ =
-      reference_mha(ln_1_out_, mha_w0_, mha_b0_, mha_w1_, mha_b1_)[3];
+  auto mha_out_ = reference_mha(
+      ln_1_out_.to(at_dtype), mha_w0_, mha_b0_, mha_w1_, mha_b1_)[3];
   auto resid1_ = mha_out_ + x_;
   auto ln_2_ = at::native_layer_norm(
       resid1_,
       norm_shape,
       /*weight=*/std::nullopt,
       /*bias=*/std::nullopt,
       kEps);
-  auto ln_2_out_ = std::get<0>(ln_2_).to(at_dtype);
+  auto ln_2_out_ = std::get<0>(ln_2_);
 
-  auto mlp_out_ =
-      reference_mlp(ln_2_out_, mlp_w0_, mlp_b0_, mlp_w1_, mlp_b1_)[3];
+  auto mlp_out_ = reference_mlp(
+      ln_2_out_.to(at_dtype), mlp_w0_, mlp_b0_, mlp_w1_, mlp_b1_)[3];
   auto at_out = mha_out_ + mlp_out_;
 
   std::vector<c10::IValue> inputs = {
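The Forward-test change mirrors the new reference signatures: the layer-norm outputs now stay in fp32, and the cast to at_dtype happens only at the reference_mha/reference_mlp call sites. That keeps the residual-stream arithmetic (mha_out_ + x_ and mha_out_ + mlp_out_) in full precision, which is presumably part of what brought the references closer to the actual implementation, per the commit message.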
