Fix reference implementations and improve validation. #2905
@@ -276,24 +276,23 @@ TensorView* tanh_gelu_backward(TensorView* dy, TensorView* x) {
   auto x_sq = mul(x, x);
   auto x_cube = mul(x, x_sq);
+  auto beta = IrBuilder::create<Val>(kBeta);
+  auto kappa = IrBuilder::create<Val>(kKappa);
+  auto one = IrBuilder::create<Val>(1.0);
+  auto half = IrBuilder::create<Val>(0.5);

-  auto inner_1 = mul(IrBuilder::create<Val>(kKappa), x_cube);
-  auto inner_2 = add(x, inner_1);
-  auto inner_3 = mul(IrBuilder::create<Val>(kBeta), inner_2);
-  auto tanh_inner = tanh(inner_3);
-
-  auto left = mul(IrBuilder::create<Val>(0.5), x);
-  auto right = add(IrBuilder::create<Val>(1.0), tanh_inner);
+  auto inner = mul(beta, add(x, mul(kappa, x_cube)));
+  auto tanh_inner = tanh(inner);

-  auto left_derivative = mul(IrBuilder::create<Val>(0.5), right);
+  auto left = mul(half, x);
+  auto right = add(one, tanh_inner);

-  auto tanh_inner_sq = mul(tanh_inner, tanh_inner);
-  auto tanh_derivative = sub(IrBuilder::create<Val>(1.0), tanh_inner_sq);
+  auto left_derivative = mul(half, right);
Reviewer comment: Why is this rewrite necessary? You mentioned it didn't affect accuracy like you had hypothesized, so I'm wondering what this brings us.

Author reply: It's not necessary and I'm happy to revert it. I kept it for only two reasons:
-  auto constant_mul_x_sq =
-      mul(IrBuilder::create<Val>(kBeta * 3 * kKappa), x_sq);
-  auto inner_derivative = add(IrBuilder::create<Val>(kBeta), constant_mul_x_sq);
-  auto right_derivative = mul(left, mul(tanh_derivative, inner_derivative));
+  auto tanh_derivative = sub(one, mul(tanh_inner, tanh_inner));
+  auto inner_derivative =
+      mul(beta, add(one, mul(mul(IrBuilder::create<Val>(3.0), kappa), x_sq)));
+  auto right_derivative = mul(mul(left, tanh_derivative), inner_derivative);

   auto dx = mul(dy, add(left_derivative, right_derivative));
   return dx;
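For reference, the rewrite above is a pure refactor: both versions assemble the same expression, the derivative of the tanh approximation of GELU. A sketch of the math the kernel builds, assuming kBeta and kKappa are the usual tanh-GELU constants sqrt(2/pi) and 0.044715 (their values are not shown in this diff):

\[
\mathrm{gelu}(x) \approx \tfrac{1}{2}\,x\,\bigl(1 + \tanh u\bigr), \qquad u = \beta\,(x + \kappa x^{3}),
\]
\[
\frac{d}{dx}\,\mathrm{gelu}(x)
  = \underbrace{\tfrac{1}{2}\bigl(1 + \tanh u\bigr)}_{\texttt{left\_derivative}}
  + \underbrace{\tfrac{1}{2}\,x\,\bigl(1 - \tanh^{2} u\bigr)\,\beta\,\bigl(1 + 3\kappa x^{2}\bigr)}_{\texttt{right\_derivative}},
\]

and the function returns dx = dy * (left_derivative + right_derivative), exactly as in the last two lines of the hunk.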
@@ -7,6 +7,7 @@
 // clang-format on
 #include <cmath>

+#include <gmock/gmock-matchers.h>
 #include <gtest/gtest.h>

 #include <executor.h>
@@ -56,33 +57,43 @@ class DistributedTransformerTest
 };

 namespace {
+// testValidate doesn't work out of the box due to #2906, so I had to manually
+// specify the tolerances.
 void validate(
-    std::vector<at::Tensor> expected_out,
-    std::vector<at::Tensor> out) {
-  EXPECT_EQ(expected_out.size(), out.size());
-  for (auto i : c10::irange(out.size())) {
+    const std::vector<at::Tensor>& expected_outputs,
+    const std::vector<at::Tensor>& outputs,
+    const std::vector<double>& rtols,
+    const std::vector<double>& atols) {
+  using testing::SizeIs;
+  const auto num_outputs = outputs.size();
+  ASSERT_THAT(expected_outputs, SizeIs(num_outputs));
+  ASSERT_THAT(rtols, SizeIs(num_outputs));
+  ASSERT_THAT(atols, SizeIs(num_outputs));
+
+  for (const auto i : c10::irange(num_outputs)) {
     // allclose can catch this as well. However, it would throw an exception,
     // not showing which output was problematic.
-    ASSERT_EQ(out[i].dtype(), expected_out[i].dtype())
+    ASSERT_EQ(outputs[i].dtype(), expected_outputs[i].dtype())
         << "Output " << i << " has a mismatching data type.";

-    // Note: Scaling tolerance up since the error accumulates across ops
-    // BFloat16 error is quite high, but the program has been verified with
-    // double precision to be logically correct.
-    const double atol = 0.075 * (i + 1);
-    const double rtol = 1.6e-2;
Reviewer comment: Why did you remove rtol? Some of the absolute errors are 1e-5, which is a very small margin of error for bfloat16.

Reviewer comment: It would also be beneficial to keep some form of default value; since tests are still getting added, this adds a lot of effort to get a simple example working.

Author reply: I removed it because I found I ended up always setting rtol to 0 :) I'd love to keep some default value to, as you said, minimize the effort to get a simple example working. What would be a good default value? The old

How about this:
Reviewer comment: It was supposed to have been taken from PyTorch: https://github.com/pytorch/pytorch/blob/042f2f7746a064f1527d95d1f1d712b4f0b34186/test/test_transformers.py#L85

Author reply: Done -- added default rtols.
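Background for the tolerance discussion (not part of the diff): both at::allclose and the failing-element count in generate_comparison_details use the standard mixed criterion, so setting rtol to zero reduces the check to a purely absolute one:

\[
\lvert \text{out}_j - \text{expected}_j \rvert \;\le\; \text{atol} + \text{rtol}\,\lvert \text{expected}_j \rvert .
\]

For scale, bfloat16 carries an 8-bit significand, so its relative precision is roughly \(2^{-8} \approx 3.9\times10^{-3}\); an atol of 1e-5 is therefore realistic only for outputs that are very small in magnitude or are produced in higher precision.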
-    auto generate_comparison_details = [](at::Tensor out,
-                                          at::Tensor expected_out,
-                                          double atol,
-                                          double rtol) -> std::string {
+    const double rtol = rtols[i];
+    const double atol = atols[i];
+    auto generate_comparison_details = [](at::Tensor expected_out,
+                                          at::Tensor out,
+                                          double rtol,
+                                          double atol) -> std::string {
       std::ostringstream oss;
       auto error = (out - expected_out).abs();
-      auto max_error = error.max().item().to<double>();
+      auto max_absolute_error = error.max().item().to<double>();
       auto max_relative_error =
-          max_error / expected_out.abs().max().item().to<double>();
+          max_absolute_error / expected_out.abs().max().item().to<double>();
       auto error_count =
           at::sum(error >= (atol + expected_out.abs() * rtol)).item();
-      indent(oss, 1) << "max error: " << max_error << std::endl;
+      indent(oss, 1) << "max absolute error: " << max_absolute_error
+                     << std::endl;
       indent(oss, 1) << "max relative error: " << max_relative_error
                      << std::endl;
       indent(oss, 1) << "failing elements: " << error_count << ", "

@@ -91,9 +102,11 @@ void validate(
       return oss.str();
     };

-    EXPECT_TRUE(out[i].allclose(expected_out[i], rtol, atol))
-        << "Output " << i << " mismatches:" << std::endl
-        << generate_comparison_details(out[i], expected_out[i], atol, rtol);
+    EXPECT_TRUE(outputs[i].allclose(expected_outputs[i], rtol, atol))
+        << "Output " << i << " mismatches with rtol " << rtol << " and atol "
+        << atol << ":" << std::endl
+        << generate_comparison_details(
+               expected_outputs[i], outputs[i], rtol, atol);
   }
 }
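Related to the default-value discussion above, a minimal sketch of what a convenience overload could look like. This is hypothetical and not part of this diff; it simply forwards to the validate() defined above with every rtol defaulted to zero, which is how all the updated tests call it:

void validate(
    const std::vector<at::Tensor>& expected_outputs,
    const std::vector<at::Tensor>& outputs,
    const std::vector<double>& atols) {
  // Default every per-output rtol to 0.0 so callers only supply atols.
  validate(
      expected_outputs,
      outputs,
      std::vector<double>(atols.size(), 0.0),
      atols);
}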
@@ -143,7 +156,7 @@ std::vector<at::Tensor> reference_mlp_backwards(
     at::Tensor w1) {
   auto at_dtype = w0.dtype();
   // recompute activations
-  auto linear0 = at::matmul(x, w0).add(b0).to(at::kFloat);
+  auto linear0 = at::matmul(x, w0).to(at::kFloat) + b0;
   auto gelu = at::gelu(linear0, "tanh");

   // backwards pass
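The change above (and the matching one in reference_mha_backwards below) casts the matmul result to float before adding the bias, so the bias add is no longer rounded through the reduced-precision dtype. A minimal ATen sketch of the difference, with illustrative shapes and names and assuming a PyTorch build where bfloat16 matmul is available:

#include <ATen/ATen.h>

// Illustrative only: shows why casting to float before the bias add matters.
void compare_bias_add_order() {
  at::Tensor x = at::randn({4, 8}, at::kBFloat16);
  at::Tensor w = at::randn({8, 8}, at::kBFloat16);
  at::Tensor b = at::randn({8}, at::kBFloat16);

  // Old order: matmul + bias is rounded to bfloat16, then cast to float.
  at::Tensor rounded_first = at::matmul(x, w).add(b).to(at::kFloat);
  // New order: cast first, so the bias add happens in float32 (b is promoted).
  at::Tensor cast_first = at::matmul(x, w).to(at::kFloat) + b;

  // The two typically differ by roughly bfloat16's rounding error.
  at::Tensor max_diff = (rounded_first - cast_first).abs().max();
  (void)max_diff;
}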
@@ -183,7 +196,7 @@ std::vector<at::Tensor> reference_mha_backwards(
     at::Tensor w1) {
   auto at_dtype = w0.dtype();
   // recompute up to sdpa
-  auto linear0 = at::matmul(x, w0).add(b0).view({B, S, 3 * E});
+  auto linear0 = (at::matmul(x, w0).to(at::kFloat) + b0).view({B, S, 3 * E});
   auto qkv = linear0.split(E, /*dim=*/-1);
   for (auto i = 0; i < 3; i++) {
     qkv[i] = qkv[i].reshape({B, S, H, E / H}).transpose(1, 2).to(at_dtype);
@@ -428,8 +441,7 @@ std::vector<TensorView*> mlp_backwards(
   // Activation recomputation
   TensorView* matmul0 = matmul(x, w0);
   TensorView* b0_bcast = broadcast(b0, {false, true, false});
-  TensorView* linear0 = add(matmul0, b0_bcast);
-  linear0 = castOp(DataType::Float, linear0);
+  TensorView* linear0 = add(matmul0, b0_bcast); // add generates float.
   TensorView* gelu = tanh_gelu(linear0);
   gelu = castOp(dtype, gelu);
@@ -651,7 +663,11 @@ TEST_P(DistributedTransformerTest, MLP_Layer) {
   FusionExecutorCache fec(std::move(fusion));
   at::manual_seed(getATenRandomSeed());
   auto outputs = fec.runFusionWithInputs(inputs);
-  validate(expected_outputs, outputs);
+  validate(
+      expected_outputs,
+      outputs,
+      {0.0, 0.0, 0.0, 0.0},
+      {0.01, 0.01, 0.01, 0.01});
 }

 TEST_P(DistributedTransformerTest, MultiheadAttention) {
@@ -707,8 +723,12 @@ TEST_P(DistributedTransformerTest, MultiheadAttention) {
   FusionExecutorCache fec(std::move(fusion));
   at::manual_seed(getATenRandomSeed());
-  auto out = fec.runFusionWithInputs(inputs);
-  validate(expected_outputs, out);
+  auto outputs = fec.runFusionWithInputs(inputs);
+  validate(
+      expected_outputs,
+      outputs,
+      {0.0, 0.0, 0.0, 0.0},
+      {0.02, 0.01, 0.01, 0.01});
 }

 TEST_P(DistributedTransformerTest, MLP_Backward) {
@@ -776,7 +796,11 @@ TEST_P(DistributedTransformerTest, MLP_Backward) {
   FusionExecutorCache fec(std::move(fusion));
   auto outputs = fec.runFusionWithInputs(inputs);

-  validate(expected_outputs, outputs);
+  validate(
+      expected_outputs,
+      outputs,
+      {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
+      {1e-5, 0.2, 1e-5, 0.01, 0.2, 0.01, 0.02});
 }

 TEST_P(DistributedTransformerTest, MHA_Backward) {
@@ -872,7 +896,11 @@ TEST_P(DistributedTransformerTest, MHA_Backward) {
   FusionExecutorCache fec(std::move(fusion));
   at::manual_seed(getATenRandomSeed());
   auto out = fec.runFusionWithInputs(inputs);
-  validate(expected_outputs, out);
+  validate(
+      expected_outputs,
+      out,
+      {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
+      {1e-5, 0.02, 1e-5, 1e-5, 1e-5, 0.1, 0.1, 0.02, 0.01});
 }

 TEST_P(DistributedTransformerTest, Forward) {
@@ -981,7 +1009,11 @@ TEST_P(DistributedTransformerTest, Forward) {
   FusionExecutorCache fec(std::move(fusion));
   at::manual_seed(getATenRandomSeed());
   auto outputs = fec.runFusionWithInputs(inputs);
-  validate(expected_outputs, outputs);
+  validate(
+      expected_outputs,
+      outputs,
+      {0.0, 0.0, 0.0, 0.0, 0.0},
+      {1e-5, 0.01, 0.01, 0.02, 0.02});
 }

 INSTANTIATE_TEST_SUITE_P(
Author comment: I suspected some inaccuracy and made these changes. Although they turned out not to affect accuracy, I left them in the PR as general cleanups.