Optimize fmax with NAN #319
This is a good idea to speed up every call to `fmax`.
We could even determine whether nan checks can be skipped automatically, without user input, but it involves a bit of a complicated traversal. We would visit each min or max op or reduction and mark its result according to whether it might contain nans, based on its inputs.
For the propagation analysis, consider this simple example with no reductions (if we had non-reduction `maximum`/`minimum` ops):

```cpp
auto tv2 = maximum(tv0, IrBuilder::create<Float>(0.0));
auto tv3 = add(tv0, tv1);
auto tv4 = minimum(tv3, IrBuilder::create<Float>(10.0));
auto tv5 = sub(tv4, tv2);
fusion->addOutput(tv5);
```

Currently, we would do nan checks when computing both `tv2` and `tv4`.
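As a rough illustration of the traversal idea (a toy sketch with hypothetical types, not an actual nvFuser pass or API): whether a value may contain nans can be propagated forward from the fusion inputs, and a min/max whose inputs cannot contain nans would not need the nan-check branches.

```cpp
// Toy sketch: a value may contain nan only if one of its producers may.
#include <unordered_map>
#include <vector>

struct Val {
  bool is_fusion_input = false;
  std::vector<const Val*> inputs;   // producers of this value
};

bool mayBeNan(const Val* v, std::unordered_map<const Val*, bool>& memo) {
  auto it = memo.find(v);
  if (it != memo.end()) return it->second;
  bool result = v->is_fusion_input;  // be conservative about fusion inputs
  for (const Val* in : v->inputs) {
    result = result || mayBeNan(in, memo);
  }
  return memo[v] = result;
}
```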
For the immediate task at hand, I believe we can remove one nan check more easily by removing this branch: https://github.com/NVIDIA/Fuser/blob/main/runtime/helpers.cu#L106-L107.
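For context, a nan-propagating `fmax` along these lines needs only one explicit nan check, because any comparison involving nan is false. This is a minimal sketch of the idea, not the actual `helpers.cu` code:

```cpp
// Minimal sketch (hypothetical) of a nan-propagating fmax with one nan check.
__device__ float fmax_propagate_nan(float a, float b) {
  if (a != a) {
    return a;             // a is nan: propagate it directly
  }
  // If b is nan, (a > b) is false, so the nan b is returned here anyway.
  return a > b ? a : b;
}
```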
Why? Can you please elaborate?
Sure. For a log-softmax computation:

```cpp
auto a_max = max(a, {1});
auto a_sub = sub(a, broadcast(a_max, {false, true}));
auto a_exp = exp(a_sub);
auto a_lse = log(sum(a_exp, {1}));
auto log_softmax = sub(a_sub, broadcast(a_lse, {false, true}));
```

Imagine there is a single nan in some row of `a`. With the nan-checking max, `a_max` for that row is nan, so the nan propagates and the whole output row becomes nan. Now imagine we don't check nans for max, so that `a_max` for that row may not be nan, and the nan is no longer guaranteed to propagate in the same way.
Fixes #319. This implements a minimal fix, in that it does not modify the reduction as stated in the issue, but it removes one nan-check for each call to fmin/fmax. We may further want to disable nan-checking completely, but that will be implemented in another PR.
Oh, I see, very interesting. I have never heard of such a data flow analysis for NANs. In terms of actual benefits, the transformation as I showed above may be just enough and easier to implement. I saw it got almost the same performance.
Interesting. Have you tried it without any nan checking? I am not sure how much of the overhead comes from the nan checks themselves.
Sorry, I meant I compared the performance with just `fmaxf`.
With #329, each call to `fmax` should need only a single nan check, so it might be worth measuring that variant as well.
I actually tried that as well, and it came out roughly in the middle of the two cases, with about a 10% improvement.
Interesting! Maybe we should re-open this issue, and I'll get to work on the original approach.
Fp max reductions would typically look like:
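(A rough sketch with hypothetical names `T0`, `T1`, and `N`; the actual generated code may differ.)

```cpp
// Sketch of a generated serial fp max reduction. Here fmax refers to the
// nan-checking runtime helper, not the CUDA fmaxf intrinsic.
__device__ float max_reduce(const float* T0, int N) {
  float T1 = -INFINITY;
  for (int i = 0; i < N; ++i) {
    T1 = fmax(T1, T0[i]);
  }
  return T1;
}
```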
Here, `fmax` is not just `fmaxf`; it also incurs two more comparisons in case the arguments are NAN: https://github.com/NVIDIA/Fuser/blob/main/runtime/helpers.cu#LL102C1-L111C2

This could be translated as:
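One possible form of such a translation, sketched with the same hypothetical names: do the reduction with the plain `fmaxf` intrinsic and hoist the nan handling out of the inner-loop branches (the exact translation may differ):

```cpp
// Sketch of one possible translation: reduce with the fmaxf intrinsic and
// track nans separately so that nan-propagating semantics are preserved.
__device__ float max_reduce_translated(const float* T0, int N) {
  float T1 = -INFINITY;
  bool saw_nan = false;
  for (int i = 0; i < N; ++i) {
    saw_nan |= (T0[i] != T0[i]);  // true iff T0[i] is nan
    T1 = fmaxf(T1, T0[i]);        // fmaxf ignores nan operands
  }
  return saw_nan ? NAN : T1;      // propagate nan if any element was nan
}
```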
In the case of cross entropy loss (#278), I observed a 20% speedup on A100.
I think this translation could be done automatically as part of lowering. See the translation for Welford vectorization.