Fix reference implementations and improve validation. #2905

Merged: 11 commits into main from wjy/fec on Sep 16, 2024
Conversation

@wujingyue (Collaborator, Author) commented on Sep 4, 2024:

  1. Fix some reference implementations to cast in the same way as nvFuser.
  2. Allow validate to take tolerances.
  3. Refine the tolerances so the tests pass with up to 6 GPUs.

@wujingyue (Collaborator, Author):

!build

@wujingyue requested a review from cowanmeg on September 4, 2024, 21:15
auto gelu = at::gelu(linear0, "tanh");
auto linear1 =
at::matmul(gelu.to(at_dtype), w1).to(at::kFloat) + b1.to(at::kFloat);
auto linear0 = at::matmul(x, w0).to(at::kFloat) + b0;
Collaborator Author:

A no-op change. + in PyTorch promotes data types.
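
As a minimal illustration of that promotion rule (a standalone sketch assuming libtorch; this is not code from the PR):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Accumulator in float, bias in bfloat16: the + promotes the result to
  // float, so an explicit .to(at::kFloat) on the bias is redundant.
  at::Tensor acc = at::randn({2, 3}, at::kFloat);
  at::Tensor bias = at::randn({3}, at::kBFloat16);
  at::Tensor out = acc + bias;
  std::cout << std::boolalpha << (out.scalar_type() == at::kFloat)
            << std::endl;  // true
  return 0;
}
```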

auto linear1 =
at::matmul(gelu.to(at_dtype), w1).to(at::kFloat) + b1.to(at::kFloat);
auto linear0 = at::matmul(x, w0).to(at::kFloat) + b0;
auto gelu = at::gelu(linear0, "tanh").to(at_dtype);
Collaborator Author:

This is necessary for at::allclose to pass, because mlp produces its output in at_dtype.
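
To illustrate why the reference needs the same cast (a standalone sketch assuming libtorch; the tensor sizes and tolerances are made up, not taken from the test):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor ref = at::randn({1024}, at::kFloat);
  // Round-trip through bfloat16, mimicking the precision the fusion loses by
  // producing its output in at_dtype.
  at::Tensor cast_back = ref.to(at::kBFloat16).to(at::kFloat);
  // With tight tolerances this almost certainly fails: bfloat16 keeps only
  // ~8 mantissa bits, so the relative error is on the order of 2^-8.
  std::cout << at::allclose(ref, cast_back, /*rtol=*/1e-5, /*atol=*/1e-8)
            << std::endl;
  // A relative tolerance of 1.6e-2 comfortably absorbs that error.
  std::cout << at::allclose(ref, cast_back, /*rtol=*/1.6e-2, /*atol=*/0.0)
            << std::endl;
  return 0;
}
```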

auto outputs = runtime.runWithInput(inputs);
validate(expected_outputs, outputs);
auto outputs = fec.runFusionWithInputs(inputs);
ASSERT_EQ(outputs.size(), expected_outputs.size());
Collaborator:

Why not just use the existing validate and change the tolerance?

Collaborator Author:

I'll try to do that later. This PR only fixes the ref of MLP, and validate is used by other tests as well.

Collaborator:

You could parameterize it with the lower atol. I imagine that if we want tighter bounds for all the tests, each one will have its own value that is hard to predict without manually trying.

@wujingyue requested a review from cowanmeg on September 5, 2024, 23:00
@wujingyue changed the base branch from main to wjy/more on September 6, 2024, 17:11
@wujingyue marked this pull request as draft on September 6, 2024, 17:11
Base automatically changed from wjy/more to main on September 6, 2024, 20:10
@wujingyue changed the title from "Use FusionExecutorCache for MLP_Layer tests." to "Fix reference implementations and improve validation." on Sep 7, 2024
@wujingyue marked this pull request as ready for review on September 7, 2024, 05:30
Collaborator Author:

I suspected some inaccuracy and made these changes. Although they turned out not to affect accuracy, I left them in the PR as general cleanups.

@wujingyue (Collaborator, Author):

!build

<< "Output " << i << " has a mismatching data type.";

// Note: Scaling tolerance up since the error accumulates across ops
// BFloat16 error is quite high, but the program has been verified with
// double precision to be logically correct.
const double atol = 0.075 * (i + 1);
const double rtol = 1.6e-2;
Collaborator:

Why did you remove rtol? Some of the absolute errors are 1e-5, which is a very small margin of error for bfloat16.

Collaborator:

It would also be beneficial to keep some form of default value; since tests are still getting added, this adds a lot of effort to getting a simple example working.

Collaborator Author:

I removed it because I found I ended up always setting rtol to 0 :)

I'd love to keep some default value to, as you said, minimize the effort to get a simple example working. What would be a good default value? The old atol=0.075*(i+1) is too relaxed after I fixed the reference implementation. The *(i+1) part is also problematic because the output number doesn't necessarily match the layer number.

How about this:

  1. Hardcode rtol to 0.016. I'll try to fine-tune this value a little bit -- 0.016 sounds like a large default rtol to start with.
  2. Still require the caller to provide a list of per-output atols, because it's hard to find a good default and each test seems to require something different. (A rough sketch of this interface follows below.)
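
A hypothetical sketch of such a helper (the name, exact signature, and default value are illustrative and assume libtorch plus GoogleTest; this is not the actual nvFuser validate):

```cpp
#include <vector>

#include <ATen/ATen.h>
#include <gtest/gtest.h>

// Shared default relative tolerance; per-output absolute tolerances are
// supplied by each test.
constexpr double kDefaultRtol = 1.6e-2;

// Hypothetical interface: one atol per output, rtol defaults to kDefaultRtol.
void validateOutputs(
    const std::vector<at::Tensor>& expected_outputs,
    const std::vector<at::Tensor>& outputs,
    const std::vector<double>& atols,
    double rtol = kDefaultRtol) {
  ASSERT_EQ(outputs.size(), expected_outputs.size());
  ASSERT_EQ(outputs.size(), atols.size());
  for (size_t i = 0; i < outputs.size(); i++) {
    // Compare values in double so the check is about closeness, not dtypes.
    EXPECT_TRUE(at::allclose(
        outputs[i].to(at::kDouble),
        expected_outputs[i].to(at::kDouble),
        rtol,
        atols[i]))
        << "Output " << i << " is not within tolerance.";
  }
}
```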


Collaborator Author:

Done -- added default rtols.


auto tanh_inner_sq = mul(tanh_inner, tanh_inner);
auto tanh_derivative = sub(IrBuilder::create<Val>(1.0), tanh_inner_sq);
auto left_derivative = mul(half, right);

Collaborator:

Why is this rewrite necessary? You mentioned it didn't affect accuracy as you had hypothesized, so I'm wondering what this brings us.

Collaborator Author:

It's not necessary and I'm happy to revert it. I kept it for only two reasons:

  1. Code becomes shorter.
  2. Code matches the PyTorch implementation more closely, which I hypothesized would help accuracy. (The tanh-approximation GELU derivative that the quoted block computes is spelled out below for reference.)
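
For reference, this is the standard tanh-approximation GELU and its derivative; it is the textbook formula rather than anything taken from this diff, and it assumes the hunk's right holds 1 + tanh(u) and half holds 0.5:

```latex
% Tanh-approximation GELU (as used by at::gelu(x, "tanh")) and its derivative.
\[
  \operatorname{gelu}(x) = \tfrac{1}{2}\,x\,\bigl(1 + \tanh(u)\bigr),
  \qquad u = \sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)
\]
\[
  \frac{d}{dx}\operatorname{gelu}(x)
  = \underbrace{\tfrac{1}{2}\bigl(1 + \tanh(u)\bigr)}_{\text{left\_derivative}}
  + \tfrac{1}{2}\,x\,
    \underbrace{\bigl(1 - \tanh^{2}(u)\bigr)}_{\text{tanh\_derivative}}
    \,\sqrt{\tfrac{2}{\pi}}\,\bigl(1 + 3 \cdot 0.044715\,x^{2}\bigr)
\]
```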

@wujingyue requested a review from cowanmeg on September 7, 2024, 21:38
@wujingyue added the "bug: Something isn't working" label on Sep 7, 2024
@wujingyue (Collaborator, Author):

!build

@wujingyue merged commit 3fcde3b into main on Sep 16, 2024
34 of 36 checks passed
@wujingyue deleted the wjy/fec branch on September 16, 2024, 14:37
Labels: bug (Something isn't working)
2 participants