Fix reference implementations and improve validation. #2905

Merged
merged 11 commits on Sep 16, 2024
27 changes: 13 additions & 14 deletions csrc/ops/composite.cpp
Collaborator Author

I suspected some inaccuracy and made these changes. Although they turned out not to affect accuracy, I left them in the PR as general cleanups.
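For readers following the diff, the `left`/`right` naming in the code comes from the product-rule split of the tanh-approximated GELU. A sketch of the math the kernel implements (β and κ are `kBeta` and `kKappa`):

```latex
\mathrm{gelu}(x) = \underbrace{\tfrac{1}{2}x}_{\texttt{left}} \cdot
\underbrace{\left(1 + \tanh\!\big(\beta(x + \kappa x^3)\big)\right)}_{\texttt{right}},
\qquad \beta = \sqrt{2/\pi},\; \kappa = 0.044715
```

```latex
\mathrm{gelu}'(x) =
\underbrace{\tfrac{1}{2}\,\texttt{right}}_{\texttt{left\_derivative}}
\;+\;
\underbrace{\texttt{left}\cdot\big(1-\tanh^2(\cdot)\big)\cdot\beta\,(1 + 3\kappa x^2)}_{\texttt{right\_derivative}}
```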

@@ -276,24 +276,23 @@ TensorView* tanh_gelu_backward(TensorView* dy, TensorView* x) {

auto x_sq = mul(x, x);
auto x_cube = mul(x, x_sq);
+ auto beta = IrBuilder::create<Val>(kBeta);
+ auto kappa = IrBuilder::create<Val>(kKappa);
+ auto one = IrBuilder::create<Val>(1.0);
+ auto half = IrBuilder::create<Val>(0.5);

- auto inner_1 = mul(IrBuilder::create<Val>(kKappa), x_cube);
- auto inner_2 = add(x, inner_1);
- auto inner_3 = mul(IrBuilder::create<Val>(kBeta), inner_2);
- auto tanh_inner = tanh(inner_3);

- auto left = mul(IrBuilder::create<Val>(0.5), x);
- auto right = add(IrBuilder::create<Val>(1.0), tanh_inner);
+ auto inner = mul(beta, add(x, mul(kappa, x_cube)));
+ auto tanh_inner = tanh(inner);

- auto left_derivative = mul(IrBuilder::create<Val>(0.5), right);
+ auto left = mul(half, x);
+ auto right = add(one, tanh_inner);

- auto tanh_inner_sq = mul(tanh_inner, tanh_inner);
- auto tanh_derivative = sub(IrBuilder::create<Val>(1.0), tanh_inner_sq);
+ auto left_derivative = mul(half, right);

Collaborator

Why is this rewrite necessary? You mentioned it didn't affect accuracy like you had hypothesized so I'm wondering what this brings us?

Collaborator Author

It's not necessary and I'm happy to revert it. I kept it for only two reasons:

  1. Code becomes shorter.
  2. Code matches the PyTorch implementation more closely, which I hypothesized would help accuracy.

- auto constant_mul_x_sq =
-     mul(IrBuilder::create<Val>(kBeta * 3 * kKappa), x_sq);
- auto inner_derivative = add(IrBuilder::create<Val>(kBeta), constant_mul_x_sq);
- auto right_derivative = mul(left, mul(tanh_derivative, inner_derivative));
+ auto tanh_derivative = sub(one, mul(tanh_inner, tanh_inner));
+ auto inner_derivative =
+     mul(beta, add(one, mul(mul(IrBuilder::create<Val>(3.0), kappa), x_sq)));
+ auto right_derivative = mul(mul(left, tanh_derivative), inner_derivative);

auto dx = mul(dy, add(left_derivative, right_derivative));
return dx;
86 changes: 59 additions & 27 deletions tests/cpp/test_multidevice_transformer.cpp
@@ -7,6 +7,7 @@
// clang-format on
#include <cmath>

+ #include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>

#include <executor.h>
@@ -56,33 +57,43 @@ class DistributedTransformerTest
};

namespace {
+ // testValidate doesn't work out of the box due to #2906, so I had to manually
+ // specify the tolerances.
void validate(
- std::vector<at::Tensor> expected_out,
- std::vector<at::Tensor> out) {
- EXPECT_EQ(expected_out.size(), out.size());
- for (auto i : c10::irange(out.size())) {
+ const std::vector<at::Tensor>& expected_outputs,
+ const std::vector<at::Tensor>& outputs,
+ const std::vector<double>& rtols,
+ const std::vector<double>& atols) {
+ using testing::SizeIs;
+ const auto num_outputs = outputs.size();
+ ASSERT_THAT(expected_outputs, SizeIs(num_outputs));
+ ASSERT_THAT(rtols, SizeIs(num_outputs));
+ ASSERT_THAT(atols, SizeIs(num_outputs));
+
+ for (const auto i : c10::irange(num_outputs)) {
// allclose can catch this as well. However, it would throw an exception,
// not showing which output was problematic.
- ASSERT_EQ(out[i].dtype(), expected_out[i].dtype())
+ ASSERT_EQ(outputs[i].dtype(), expected_outputs[i].dtype())
<< "Output " << i << " has a mismatching data type.";

- // Note: Scaling tolerance up since the error accumulates across ops
- // BFloat16 error is quite high, but the program has been verified with
- // double precision to be logically correct.
- const double atol = 0.075 * (i + 1);
- const double rtol = 1.6e-2;
Collaborator

Why did you remove rtol? Some of the absolute errors are 1e-5, which is a very small margin of error for bfloat16.

Collaborator

It would also be beneficial to keep some form of default value; since tests are still being added, this adds a lot of effort to get a simple example working.

Collaborator Author

I removed it because I found I ended up always setting rtol to 0 :)

I'd love to keep some default value to, as you said, minimize the effort to get a simple example working. What would be a good default value? The old atol=0.075*(i+1) is too relaxed after I fixed the reference implementation. The *(i+1) part is also problematic because the output number doesn't necessarily match the layer number.

How about this:

  1. Hardcode rtol to 0.016. I'll try to finetune this value a little bit -- 0.016 sounds like a large default rtol to start with.
  2. Still require the caller to provide a list of per-output atols, because it's hard to find a good default and each test seems to require something different.

Collaborator Author

Done -- added default rtols.

- auto generate_comparison_details = [](at::Tensor out,
-                                       at::Tensor expected_out,
-                                       double atol,
-                                       double rtol) -> std::string {
+ const double rtol = rtols[i];
+ const double atol = atols[i];
+ auto generate_comparison_details = [](at::Tensor expected_out,
+                                       at::Tensor out,
+                                       double rtol,
+                                       double atol) -> std::string {
std::ostringstream oss;
auto error = (out - expected_out).abs();
- auto max_error = error.max().item().to<double>();
+ auto max_absolute_error = error.max().item().to<double>();
auto max_relative_error =
- max_error / expected_out.abs().max().item().to<double>();
+ max_absolute_error / expected_out.abs().max().item().to<double>();
auto error_count =
at::sum(error >= (atol + expected_out.abs() * rtol)).item();
- indent(oss, 1) << "max error: " << max_error << std::endl;
+ indent(oss, 1) << "max absolute error: " << max_absolute_error
+                << std::endl;
indent(oss, 1) << "max relative error: " << max_relative_error
               << std::endl;
indent(oss, 1) << "failing elements: " << error_count << ", "
@@ -91,9 +102,11 @@ void validate(
return oss.str();
};

- EXPECT_TRUE(out[i].allclose(expected_out[i], rtol, atol))
-     << "Output " << i << " mismatches:" << std::endl
-     << generate_comparison_details(out[i], expected_out[i], atol, rtol);
+ EXPECT_TRUE(outputs[i].allclose(expected_outputs[i], rtol, atol))
+     << "Output " << i << " mismatches with rtol " << rtol << " and atol "
+     << atol << ":" << std::endl
+     << generate_comparison_details(
+           expected_outputs[i], outputs[i], rtol, atol);
}
}

@@ -143,7 +156,7 @@ std::vector<at::Tensor> reference_mlp_backwards(
at::Tensor w1) {
auto at_dtype = w0.dtype();
// recompute activations
- auto linear0 = at::matmul(x, w0).add(b0).to(at::kFloat);
+ auto linear0 = at::matmul(x, w0).to(at::kFloat) + b0;
auto gelu = at::gelu(linear0, "tanh");

// backwards pass
@@ -183,7 +196,7 @@ std::vector<at::Tensor> reference_mha_backwards(
at::Tensor w1) {
auto at_dtype = w0.dtype();
// recompute up to sdpa
- auto linear0 = at::matmul(x, w0).add(b0).view({B, S, 3 * E});
+ auto linear0 = (at::matmul(x, w0).to(at::kFloat) + b0).view({B, S, 3 * E});
auto qkv = linear0.split(E, /*dim=*/-1);
for (auto i = 0; i < 3; i++) {
qkv[i] = qkv[i].reshape({B, S, H, E / H}).transpose(1, 2).to(at_dtype);
@@ -428,8 +441,7 @@ std::vector<TensorView*> mlp_backwards(
// Activation recomputation
TensorView* matmul0 = matmul(x, w0);
TensorView* b0_bcast = broadcast(b0, {false, true, false});
- TensorView* linear0 = add(matmul0, b0_bcast);
- linear0 = castOp(DataType::Float, linear0);
+ TensorView* linear0 = add(matmul0, b0_bcast); // add generates float.
TensorView* gelu = tanh_gelu(linear0);
gelu = castOp(dtype, gelu);

@@ -651,7 +663,11 @@ TEST_P(DistributedTransformerTest, MLP_Layer) {
FusionExecutorCache fec(std::move(fusion));
at::manual_seed(getATenRandomSeed());
auto outputs = fec.runFusionWithInputs(inputs);
- validate(expected_outputs, outputs);
+ validate(
+     expected_outputs,
+     outputs,
+     {0.0, 0.0, 0.0, 0.0},
+     {0.01, 0.01, 0.01, 0.01});
}

TEST_P(DistributedTransformerTest, MultiheadAttention) {
@@ -707,8 +723,12 @@

FusionExecutorCache fec(std::move(fusion));
at::manual_seed(getATenRandomSeed());
- auto out = fec.runFusionWithInputs(inputs);
- validate(expected_outputs, out);
+ auto outputs = fec.runFusionWithInputs(inputs);
+ validate(
+     expected_outputs,
+     outputs,
+     {0.0, 0.0, 0.0, 0.0},
+     {0.02, 0.01, 0.01, 0.01});
}

TEST_P(DistributedTransformerTest, MLP_Backward) {
@@ -776,7 +796,11 @@
FusionExecutorCache fec(std::move(fusion));
auto outputs = fec.runFusionWithInputs(inputs);

- validate(expected_outputs, outputs);
+ validate(
+     expected_outputs,
+     outputs,
+     {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
+     {1e-5, 0.2, 1e-5, 0.01, 0.2, 0.01, 0.02});
}

TEST_P(DistributedTransformerTest, MHA_Backward) {
@@ -872,7 +896,11 @@
FusionExecutorCache fec(std::move(fusion));
at::manual_seed(getATenRandomSeed());
auto out = fec.runFusionWithInputs(inputs);
- validate(expected_outputs, out);
+ validate(
+     expected_outputs,
+     out,
+     {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
+     {1e-5, 0.02, 1e-5, 1e-5, 1e-5, 0.1, 0.1, 0.02, 0.01});
}

TEST_P(DistributedTransformerTest, Forward) {
@@ -981,7 +1009,11 @@
FusionExecutorCache fec(std::move(fusion));
at::manual_seed(getATenRandomSeed());
auto outputs = fec.runFusionWithInputs(inputs);
- validate(expected_outputs, outputs);
+ validate(
+     expected_outputs,
+     outputs,
+     {0.0, 0.0, 0.0, 0.0, 0.0},
+     {1e-5, 0.01, 0.01, 0.02, 0.02});
}

INSTANTIATE_TEST_SUITE_P(
2 changes: 1 addition & 1 deletion tests/cpp/validator.cpp
@@ -121,7 +121,7 @@ void testValidate(
line_number,
" in file ",
file_name,
- ".\n Detected abs error of: ",
+ ".\n Detected max abs error of: ",
aten_output_tensor.sub(fusion_output_tensor)
.abs()
.max()