diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp
index 40b199aca49..40767d383f7 100644
--- a/csrc/ops/composite.cpp
+++ b/csrc/ops/composite.cpp
@@ -276,24 +276,23 @@ TensorView* tanh_gelu_backward(TensorView* dy, TensorView* x) {
 
   auto x_sq = mul(x, x);
   auto x_cube = mul(x, x_sq);
+  auto beta = IrBuilder::create<Val>(kBeta);
+  auto kappa = IrBuilder::create<Val>(kKappa);
+  auto one = IrBuilder::create<Val>(1.0);
+  auto half = IrBuilder::create<Val>(0.5);
 
-  auto inner_1 = mul(IrBuilder::create<Val>(kKappa), x_cube);
-  auto inner_2 = add(x, inner_1);
-  auto inner_3 = mul(IrBuilder::create<Val>(kBeta), inner_2);
-  auto tanh_inner = tanh(inner_3);
-
-  auto left = mul(IrBuilder::create<Val>(0.5), x);
-  auto right = add(IrBuilder::create<Val>(1.0), tanh_inner);
+  auto inner = mul(beta, add(x, mul(kappa, x_cube)));
+  auto tanh_inner = tanh(inner);
 
-  auto left_derivative = mul(IrBuilder::create<Val>(0.5), right);
+  auto left = mul(half, x);
+  auto right = add(one, tanh_inner);
 
-  auto tanh_inner_sq = mul(tanh_inner, tanh_inner);
-  auto tanh_derivative = sub(IrBuilder::create<Val>(1.0), tanh_inner_sq);
+  auto left_derivative = mul(half, right);
 
-  auto constant_mul_x_sq =
-      mul(IrBuilder::create<Val>(kBeta * 3 * kKappa), x_sq);
-  auto inner_derivative = add(IrBuilder::create<Val>(kBeta), constant_mul_x_sq);
-  auto right_derivative = mul(left, mul(tanh_derivative, inner_derivative));
+  auto tanh_derivative = sub(one, mul(tanh_inner, tanh_inner));
+  auto inner_derivative =
+      mul(beta, add(one, mul(mul(IrBuilder::create<Val>(3.0), kappa), x_sq)));
+  auto right_derivative = mul(mul(left, tanh_derivative), inner_derivative);
 
   auto dx = mul(dy, add(left_derivative, right_derivative));
   return dx;
diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp
index f54a883a29a..1ad738479c0 100644
--- a/tests/cpp/test_multidevice_transformer.cpp
+++ b/tests/cpp/test_multidevice_transformer.cpp
@@ -7,6 +7,7 @@
 // clang-format on
 #include
+#include <gmock/gmock.h>
 
 #include
 #include
 
@@ -56,33 +57,58 @@ class DistributedTransformerTest
 };
 
 namespace {
+// testValidate doesn't work out of the box due to #2906, so I had to manually
+// specify the absolute tolerances. The atols passed in are tuned for bfloat,
+// the least precise dtype. They can probably be made stricter for other
+// dtypes.
 void validate(
-    std::vector<at::Tensor> expected_out,
-    std::vector<at::Tensor> out) {
-  EXPECT_EQ(expected_out.size(), out.size());
-  for (auto i : c10::irange(out.size())) {
+    const std::vector<at::Tensor>& expected_outputs,
+    const std::vector<at::Tensor>& outputs,
+    const std::vector<double>& atols) {
+  using testing::SizeIs;
+  const auto num_outputs = outputs.size();
+  ASSERT_THAT(expected_outputs, SizeIs(num_outputs));
+  ASSERT_THAT(atols, SizeIs(num_outputs));
+
+  for (const auto i : c10::irange(num_outputs)) {
     // allclose can catch this as well. However, it would throw an exception,
     // not showing which output was problematic.
-    ASSERT_EQ(out[i].dtype(), expected_out[i].dtype())
+    ASSERT_EQ(outputs[i].dtype(), expected_outputs[i].dtype())
         << "Output " << i << " has a mismatching data type.";
 
-    // Note: Scaling tolerance up since the error accumulates across ops
-    // BFloat16 error is quite high, but the program has been verified with
-    // double precision to be logically correct.
-    const double atol = 0.075 * (i + 1);
-    const double rtol = 1.6e-2;
-    auto generate_comparison_details = [](at::Tensor out,
-                                          at::Tensor expected_out,
+    const double atol = atols[i];
+    // These default rtols are copied from
+    // https://github.com/pytorch/pytorch/blob/951c21d6790334d57862e94a3f582ac724147a53/torch/testing/_comparison.py#L65-L73.
+    double rtol;
+    switch (outputs[i].scalar_type()) {
+      case at::kBFloat16:
+        rtol = 1.6e-2;
+        break;
+      case at::kHalf:
+        rtol = 1e-3;
+        break;
+      case at::kFloat:
+        rtol = 1.3e-6;
+        break;
+      default:
+        rtol = 0.0;
+        break;
+    }
+
+    auto generate_comparison_details = [](at::Tensor expected_out,
+                                          at::Tensor out,
                                           double atol,
                                           double rtol) -> std::string {
       std::ostringstream oss;
       auto error = (out - expected_out).abs();
-      auto max_error = error.max().item().to<double>();
       auto max_relative_error =
-          max_error / expected_out.abs().max().item().to<double>();
+          (error.max() / expected_out.abs().max()).item().to<double>();
       auto error_count =
-          at::sum(error >= (atol + expected_out.abs() * rtol)).item();
-      indent(oss, 1) << "max error: " << max_error << std::endl;
+          at::sum(error >= atol + expected_out.abs() * rtol).item();
+      indent(oss, 1)
+          << "max absolute error under rtol: "
+          << (error - expected_out.abs() * rtol).max().item().to<double>()
+          << std::endl;
       indent(oss, 1) << "max relative error: " << max_relative_error
                      << std::endl;
       indent(oss, 1) << "failing elements: " << error_count << ", "
@@ -91,9 +117,11 @@ void validate(
       return oss.str();
     };
 
-    EXPECT_TRUE(out[i].allclose(expected_out[i], rtol, atol))
-        << "Output " << i << " mismatches:" << std::endl
-        << generate_comparison_details(out[i], expected_out[i], atol, rtol);
+    EXPECT_TRUE(outputs[i].allclose(expected_outputs[i], rtol, atol))
+        << "Output " << i << " mismatches with atol " << atol << ":"
+        << std::endl
+        << generate_comparison_details(
+               expected_outputs[i], outputs[i], atol, rtol);
   }
 }
 
@@ -143,7 +171,7 @@ std::vector<at::Tensor> reference_mlp_backwards(
     at::Tensor w1) {
   auto at_dtype = w0.dtype();
   // recompute activations
-  auto linear0 = at::matmul(x, w0).add(b0).to(at::kFloat);
+  auto linear0 = at::matmul(x, w0).to(at::kFloat) + b0;
   auto gelu = at::gelu(linear0, "tanh");
 
   // backwards pass
@@ -183,7 +211,7 @@ std::vector<at::Tensor> reference_mha_backwards(
     at::Tensor w1) {
   auto at_dtype = w0.dtype();
   // recompute up to sdpa
-  auto linear0 = at::matmul(x, w0).add(b0).view({B, S, 3 * E});
+  auto linear0 = (at::matmul(x, w0).to(at::kFloat) + b0).view({B, S, 3 * E});
   auto qkv = linear0.split(E, /*dim=*/-1);
   for (auto i = 0; i < 3; i++) {
     qkv[i] = qkv[i].reshape({B, S, H, E / H}).transpose(1, 2).to(at_dtype);
   }
@@ -428,8 +456,7 @@ std::vector<TensorView*> mlp_backwards(
   // Activation recomputation
   TensorView* matmul0 = matmul(x, w0);
   TensorView* b0_bcast = broadcast(b0, {false, true, false});
-  TensorView* linear0 = add(matmul0, b0_bcast);
-  linear0 = castOp(DataType::Float, linear0);
+  TensorView* linear0 = add(matmul0, b0_bcast); // add generates float.
   TensorView* gelu = tanh_gelu(linear0);
   gelu = castOp(dtype, gelu);
 
@@ -651,7 +678,7 @@ TEST_P(DistributedTransformerTest, MLP_Layer) {
   FusionExecutorCache fec(std::move(fusion));
   at::manual_seed(getATenRandomSeed());
   auto outputs = fec.runFusionWithInputs(inputs);
-  validate(expected_outputs, outputs);
+  validate(expected_outputs, outputs, {0.01, 0.01, 0.01, 0.01});
 }
 
 TEST_P(DistributedTransformerTest, MultiheadAttention) {
@@ -707,8 +734,8 @@ TEST_P(DistributedTransformerTest, MultiheadAttention) {
 
   FusionExecutorCache fec(std::move(fusion));
   at::manual_seed(getATenRandomSeed());
-  auto out = fec.runFusionWithInputs(inputs);
-  validate(expected_outputs, out);
+  auto outputs = fec.runFusionWithInputs(inputs);
+  validate(expected_outputs, outputs, {0.02, 0.01, 0.01, 0.01});
 }
 
 TEST_P(DistributedTransformerTest, MLP_Backward) {
@@ -776,7 +803,7 @@ TEST_P(DistributedTransformerTest, MLP_Backward) {
 
   FusionExecutorCache fec(std::move(fusion));
   auto outputs = fec.runFusionWithInputs(inputs);
-  validate(expected_outputs, outputs);
+  validate(expected_outputs, outputs, {1e-5, 0.2, 1e-5, 0.01, 0.2, 0.01, 0.02});
 }
 
 TEST_P(DistributedTransformerTest, MHA_Backward) {
@@ -872,7 +899,10 @@ TEST_P(DistributedTransformerTest, MHA_Backward) {
   FusionExecutorCache fec(std::move(fusion));
   at::manual_seed(getATenRandomSeed());
   auto out = fec.runFusionWithInputs(inputs);
-  validate(expected_outputs, out);
+  validate(
+      expected_outputs,
+      out,
+      {1e-5, 0.02, 1e-5, 1e-4, 1e-4, 0.1, 0.1, 0.1, 0.01});
 }
 
 TEST_P(DistributedTransformerTest, Forward) {
@@ -981,7 +1011,7 @@ TEST_P(DistributedTransformerTest, Forward) {
   FusionExecutorCache fec(std::move(fusion));
   at::manual_seed(getATenRandomSeed());
   auto outputs = fec.runFusionWithInputs(inputs);
-  validate(expected_outputs, outputs);
+  validate(expected_outputs, outputs, {1e-5, 0.01, 0.01, 0.02, 0.02});
 }
 
 INSTANTIATE_TEST_SUITE_P(
diff --git a/tests/cpp/validator.cpp b/tests/cpp/validator.cpp
index 36bd8555146..6296a0da5e8 100644
--- a/tests/cpp/validator.cpp
+++ b/tests/cpp/validator.cpp
@@ -121,7 +121,7 @@ void testValidate(
           line_number,
           " in file ",
           file_name,
-          ".\n Detected abs error of: ",
+          ".\n Detected max abs error of: ",
           aten_output_tensor.sub(fusion_output_tensor)
               .abs()
               .max()
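
Note on the tanh_gelu_backward hunk in csrc/ops/composite.cpp: the rewrite only names the repeated constants and regroups terms. Old and new code build the same gradient, d/dx [0.5 * x * (1 + tanh(u))] with u = kBeta * (x + kKappa * x^3), namely 0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh(u)^2) * kBeta * (1 + 3 * kKappa * x^2); the old kBeta + kBeta * 3 * kKappa * x_sq is just the factored-out form of the last term. Below is a quick standalone sanity check, not part of the patch, written with plain doubles instead of TensorViews; it assumes kBeta = sqrt(2/pi) and kKappa = 0.044715, the usual tanh-GELU approximation constants.

#include <cmath>
#include <cstdio>
#include <cstdlib>

// Tanh-approximated GELU, in the same form the fusion builds:
//   gelu(x) = 0.5 * x * (1 + tanh(kBeta * (x + kKappa * x^3)))
// kBeta and kKappa below are assumed values, not copied from composite.cpp.
constexpr double kBeta = 0.79788456080286541; // sqrt(2 / pi)
constexpr double kKappa = 0.044715;

double tanh_gelu(double x) {
  return 0.5 * x * (1.0 + std::tanh(kBeta * (x + kKappa * x * x * x)));
}

// Analytic derivative, structured like the rewritten tanh_gelu_backward:
// left_derivative + right_derivative.
double tanh_gelu_grad(double x) {
  const double inner = kBeta * (x + kKappa * x * x * x);
  const double tanh_inner = std::tanh(inner);
  const double left = 0.5 * x;
  const double right = 1.0 + tanh_inner;
  const double left_derivative = 0.5 * right;
  const double tanh_derivative = 1.0 - tanh_inner * tanh_inner;
  const double inner_derivative = kBeta * (1.0 + 3.0 * kKappa * x * x);
  const double right_derivative = left * tanh_derivative * inner_derivative;
  return left_derivative + right_derivative;
}

int main() {
  // Compare the analytic gradient against a central finite difference.
  const double h = 1e-6;
  for (double x = -4.0; x <= 4.0; x += 0.5) {
    const double numeric = (tanh_gelu(x + h) - tanh_gelu(x - h)) / (2.0 * h);
    const double analytic = tanh_gelu_grad(x);
    if (std::abs(numeric - analytic) > 1e-6) {
      std::printf("mismatch at x=%g: %g vs %g\n", x, numeric, analytic);
      return EXIT_FAILURE;
    }
  }
  std::printf("analytic gradient matches finite differences\n");
  return EXIT_SUCCESS;
}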
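Note on the new validate() in tests/cpp/test_multidevice_transformer.cpp: an output passes when |out - expected| <= atol + rtol * |expected| holds elementwise (the at::Tensor::allclose criterion), with atol taken from the per-output list and rtol picked by dtype. The reported "max absolute error under rtol", max(|out - expected| - rtol * |expected|), is then the smallest atol that would make that output pass at the given rtol, which is what makes it useful for tuning the atol lists in the tests. A minimal sketch of that metric, not part of the patch, with std::vector standing in for at::Tensor:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <vector>

// Smallest atol that lets every element satisfy
//   |out - expected| <= atol + rtol * |expected|
// i.e. the "max absolute error under rtol" diagnostic.
double max_abs_error_under_rtol(
    const std::vector<double>& expected,
    const std::vector<double>& out,
    double rtol) {
  double worst = -std::numeric_limits<double>::infinity();
  for (std::size_t i = 0; i < expected.size(); i++) {
    const double error = std::abs(out[i] - expected[i]);
    worst = std::max(worst, error - rtol * std::abs(expected[i]));
  }
  return worst;
}

int main() {
  const std::vector<double> expected = {1.0, 10.0, 100.0};
  const std::vector<double> out = {1.005, 10.2, 100.5};
  const double rtol = 1.6e-2; // the bfloat16 rtol used by validate()
  // Prints 0.04: any atol >= 0.04 makes this output pass at rtol = 1.6e-2.
  std::printf(
      "needed atol: %g\n", max_abs_error_under_rtol(expected, out, rtol));
  return 0;
}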
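Note on the reference_mlp_backwards / reference_mha_backwards change: the fusion's add(matmul0, b0_bcast) now yields a float linear0, so the references upcast the matmul result before adding the bias instead of rounding the sum back to the low-precision dtype first. The sketch below, not part of the patch, shows the rounding difference this removes; the shapes are made up and it assumes a libtorch build whose CPU matmul supports bfloat16.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::manual_seed(0);
  // Made-up shapes; bfloat16 is the dtype the atols were tuned for.
  auto x = at::randn({8, 16}).to(at::kBFloat16);
  auto w0 = at::randn({16, 32}).to(at::kBFloat16);
  auto b0 = at::randn({32}).to(at::kBFloat16);

  // Old reference: the bias add happens in bfloat16, so the sum is rounded
  // back to bfloat16 before the upcast.
  auto old_ref = at::matmul(x, w0).add(b0).to(at::kFloat);
  // New reference: upcast first, add the bias in float. This matches the
  // fusion, where add(matmul0, b0_bcast) now produces a float linear0.
  auto new_ref = at::matmul(x, w0).to(at::kFloat) + b0;

  std::cout << "max difference from the extra bfloat16 rounding: "
            << (old_ref - new_ref).abs().max().item<double>() << std::endl;
  return 0;
}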