Fix reference implementations and improve validation. #2905
Conversation
!build
auto gelu = at::gelu(linear0, "tanh");
auto linear1 =
    at::matmul(gelu.to(at_dtype), w1).to(at::kFloat) + b1.to(at::kFloat);
auto linear0 = at::matmul(x, w0).to(at::kFloat) + b0;
A no-op change. The + operator in PyTorch already promotes data types.
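For illustration, a minimal standalone sketch (not from the PR; tensor names are hypothetical) of the promotion rule being relied on: adding a bfloat16 bias to a float32 matmul result already yields a float32 tensor, so an explicit cast of the bias to at::kFloat is redundant.

```cpp
#include <ATen/ATen.h>

int main() {
  // Stand-ins for the matmul result and the bf16 bias.
  auto matmul_f32 = at::randn({2, 3}, at::kFloat);
  auto bias_bf16 = at::randn({3}, at::kBFloat16);
  // Type promotion: float32 + bfloat16 -> float32, no explicit cast needed.
  auto sum = matmul_f32 + bias_bf16;
  TORCH_CHECK(sum.scalar_type() == at::kFloat);
  return 0;
}
```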
auto linear1 =
    at::matmul(gelu.to(at_dtype), w1).to(at::kFloat) + b1.to(at::kFloat);
auto linear0 = at::matmul(x, w0).to(at::kFloat) + b0;
auto gelu = at::gelu(linear0, "tanh").to(at_dtype);
This is necessary for at::allclose to pass, because mlp generates its output in at_dtype.
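As a rough illustration (function and tensor names are hypothetical, not the PR's actual test code): the reference is rounded to the same dtype the fusion produces before comparing, since otherwise the full-precision float reference differs from the bfloat16 output by the rounding error alone.

```cpp
#include <ATen/ATen.h>

// Sketch: the fusion's output is produced in at_dtype (e.g. bfloat16), so the
// float reference is cast to at_dtype before the comparison.
bool matches_reference(const at::Tensor& fusion_output,
                       const at::Tensor& reference_float,
                       at::ScalarType at_dtype,
                       double rtol,
                       double atol) {
  return at::allclose(fusion_output, reference_float.to(at_dtype), rtol, atol);
}
```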
auto outputs = runtime.runWithInput(inputs);
validate(expected_outputs, outputs);
auto outputs = fec.runFusionWithInputs(inputs);
ASSERT_EQ(outputs.size(), expected_outputs.size());
Why not just use the existing validate and change the tolerance?
I'll try to do that later. This PR only fixes the ref of MLP, and validate is used by other tests as well.
You could parameterize it with the lower atol. I imagine if we want tighter bounds for all the tests, each one will have its own value that is hard to predict without manually trying.
I suspected some inaccuracy and made these changes. Although they turned out not to affect accuracy, I left them in the PR as general cleanups.
!build
<< "Output " << i << " has a mismatching data type."; | ||
|
||
// Note: Scaling tolerance up since the error accumulates across ops | ||
// BFloat16 error is quite high, but the program has been verified with | ||
// double precision to be logically correct. | ||
const double atol = 0.075 * (i + 1); | ||
const double rtol = 1.6e-2; |
Why did you remove rtol? Some of the absolute errors are 1e-5, which is a very small margin of error for bfloat16.
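For context, here is a sketch (not the test code itself) of the element-wise criterion at::allclose applies. With rtol removed, the bound is purely absolute; an atol around 1e-5 is far below bfloat16's roughly 2^-8 relative precision, so values that agree to bf16 precision can still fail once their magnitude is non-trivial.

```cpp
#include <ATen/ATen.h>

// Element-wise check in the spirit of at::allclose / torch.isclose:
//   |actual - expected| <= atol + rtol * |expected|
// With rtol = 0, the tolerance does not scale with the magnitude of the
// expected value.
bool roughly_close(const at::Tensor& actual,
                   const at::Tensor& expected,
                   double rtol,
                   double atol) {
  auto tolerance = atol + rtol * expected.abs();
  return ((actual - expected).abs() <= tolerance).all().item<bool>();
}
```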
It would also be beneficial to keep some form of default value; since tests are still being added, this adds a lot of effort to get a simple example working.
I removed it because I found I ended up always setting rtol to 0 :)
I'd love to keep some default value to, as you said, minimize the effort to get a simple example working. What would be a good default value? The old atol=0.075*(i+1) is too relaxed after I fixed the reference implementation. The *(i+1) part is also problematic because the output number doesn't necessarily match the layer number.
How about this (see the sketch below):
- Hardcode rtol to 0.016. I'll try to finetune this value a little bit -- 0.016 sounds like a large default rtol to start with.
- Still require the caller to provide a list of per-output atols, because it's hard to find a good default and each test seems to require something different.
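A hypothetical sketch of that interface (names and values are placeholders, not the actual change): one hardcoded default rtol shared by all outputs, while the caller supplies one atol per output.

```cpp
#include <vector>

#include <ATen/ATen.h>
#include <gtest/gtest.h>

void validate(
    const std::vector<at::Tensor>& expected_outputs,
    const std::vector<at::Tensor>& outputs,
    const std::vector<double>& atols) {
  // Default rtol shared by all outputs; 1.6e-2 is a placeholder pending tuning.
  constexpr double kDefaultRtol = 1.6e-2;
  ASSERT_EQ(outputs.size(), expected_outputs.size());
  ASSERT_EQ(outputs.size(), atols.size());
  for (size_t i = 0; i < outputs.size(); i++) {
    // Compare in double so the tolerance check itself adds no rounding error.
    EXPECT_TRUE(at::allclose(
        outputs[i].to(at::kDouble),
        expected_outputs[i].to(at::kDouble),
        kDefaultRtol,
        atols[i]))
        << "Output " << i << " mismatches beyond the given tolerance.";
  }
}
```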
It was supposed to have been taken from PyTorch: https://github.com/pytorch/pytorch/blob/042f2f7746a064f1527d95d1f1d712b4f0b34186/test/test_transformers.py#L85
But it should be .0126.
Done -- added default rtols.
auto tanh_inner_sq = mul(tanh_inner, tanh_inner);
auto tanh_derivative = sub(IrBuilder::create<Val>(1.0), tanh_inner_sq);
auto left_derivative = mul(half, right);
Why is this rewrite necessary? You mentioned it didn't affect accuracy as you had hypothesized, so I'm wondering what this brings us.
It's not necessary and I'm happy to revert it. I kept it for only two reasons:
- Code becomes shorter.
- Code matches the PyTorch implementation more closely, which I hypothesized would help accuracy (see the formula sketched below).
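For reference, assuming the snippet implements the backward of the tanh-approximate GELU (as in PyTorch's formulation), the math it appears to encode is the product rule applied to 0.5*x and (1 + tanh u):

$$
\operatorname{gelu}(x) = \tfrac{1}{2}\,x\,\bigl(1 + \tanh u\bigr),
\qquad u = \sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)
$$

$$
\frac{d}{dx}\operatorname{gelu}(x)
= \tfrac{1}{2}\,\bigl(1 + \tanh u\bigr)
+ \tfrac{1}{2}\,x\,\bigl(1 - \tanh^{2} u\bigr)\,
  \sqrt{\tfrac{2}{\pi}}\,\bigl(1 + 3\cdot 0.044715\,x^{2}\bigr)
$$

The (1 - tanh^2 u) factor corresponds to the tanh_derivative term visible in the snippet above.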
!build
validate to take tolerances.