
Optimize fmax with NAN #319

Open
naoyam opened this issue May 10, 2023 · 12 comments · Fixed by #329

@naoyam
Collaborator

naoyam commented May 10, 2023

FP max reductions would typically look like this:

for(nvfuser_index_t i154 = 0; i154 < 8; ++i154) {
  int i299;
  i299 = 4 * i154;
  #pragma unroll
  for(nvfuser_index_t i156 = 0; i156 < 4; ++i156) {
    T29[0] = fmax(
        T29[0],
        T24[(i299 + i156)]);
  }
}

Here, fmax is not just fmaxf; it also incurs two more comparisons to handle NaN arguments: https://github.com/NVIDIA/Fuser/blob/main/runtime/helpers.cu#LL102C1-L111C2
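
Roughly, the helper has the following shape (a paraphrased sketch; see the linked lines for the exact code):

__device__ double fmax(double a, double b) {
  // check and propagate NaN before comparing
  if (a != a) {
    return a;            // a is NaN: propagate it
  } else if (b != b) {
    return b;            // b is NaN: propagate it
  } else {
    return a > b ? a : b;
  }
}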

This could be translated as:

bool is_nan = false;
#pragma unroll
for(nvfuser_index_t i154 = 0; i154 < 8; ++i154) {
  int i299;
  i299 = 4 * i154;
  #pragma unroll
  for(nvfuser_index_t i156 = 0; i156 < 4; ++i156) {
#if 0
    T29[0] = fmax(
        T29[0],
        T24[(i299 + i156)]);
#else
    T29[0] = T29[0] > T24[(i299 + i156)] ? T29[0] : T24[(i299 + i156)];
    is_nan = is_nan || isnan(T24[(i299 + i156)]);
#endif
  }
}
if (is_nan) {
  T29[0] = NAN;
}

In the case of cross entropy loss (#278), I observed a 20% speedup on A100.

I think this translation could be done automatically as part of lowering. See the translation for Welford vectorization.

@jacobhinkle
Collaborator

This is a good idea to speed up every call to max (and min). Furthermore, in the particular case of softmax, I think the NaN checks could be avoided entirely: even if the max came out non-NaN, the resulting softmax would still be NaN because of the later sum, so propagation happens automatically. In that case we'd prefer a true fmaxf-style max computation with no NaN check (like torch.fmax instead of torch.maximum). We could provide two different binary ops for each of min and max by adding a propagate_nan=true option to the max and min functions, which we would set to false for softmax and log_softmax.
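
At the device level, the two variants might look like this (illustrative names and sketches, not actual nvFuser helpers):

// propagate_nan=true: like torch.maximum; any NaN input yields NaN
__device__ float max_propagate_nan(float a, float b) {
  if (a != a) {
    return a;            // a is NaN: propagate it
  }
  return a > b ? a : b;  // if b is NaN, a > b is false and b is returned
}

// propagate_nan=false: no NaN check at all; may or may not return NaN,
// which is fine when a downstream op (e.g. the softmax sum) re-propagates it
__device__ float max_no_nan_check(float a, float b) {
  return a > b ? a : b;
}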

@jacobhinkle
Collaborator

We could even determine automatically, without user input, whether NaN checks can be skipped, but it involves a somewhat complicated traversal. We would visit each min or max op or reduction, mark its result as unchecked and its input as propagated, then move downstream from further uses of the input and from the min/max output, propagating the tags. When a pointwise binary or ternary op other than min or max has an unchecked input and a propagated input, it resolves the might-be-unchecked flag and its output becomes properly propagated. If we reach a fusion output or another reduction, we stop; if any output is not marked as propagated, we can't skip the NaN checks in the original op.

@jacobhinkle
Collaborator

jacobhinkle commented May 11, 2023

For the propagation analysis, consider this simple example with no reductions (supposing we had non-reduction maximum/minimum ops that do NaN checks):

auto tv2 = maximum(tv0, IrBuilder::create<Float>(0.0));
auto tv3 = add(tv0, tv1);
auto tv4 = minimum(tv3, IrBuilder::create<Float>(10.0));
auto tv5 = sub(tv4, tv2);
fusion->addOutput(tv5);

Currently, we would do NaN checks when computing both tv2 and tv4, but we could eliminate the check on tv2, since an unpropagated NaN would be resolved in tv5. The analysis is similar but a little more complicated when there are reductions and broadcasts, as in softmax.
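
A toy sketch of the tag bookkeeping on this example (illustrative host code, not the actual IR traversal):

#include <cassert>

// "Unchecked" marks a min/max result computed without NaN checks;
// "Propagated" marks a value guaranteed to still carry any NaN from the
// original input; "None" carries no information.
enum class Tag { None, Unchecked, Propagated };

// Pointwise ops other than min/max propagate NaNs, so an op with one
// Unchecked and one Propagated input resolves the might-be-unchecked flag.
Tag pointwise(Tag a, Tag b) {
  bool unchecked = a == Tag::Unchecked || b == Tag::Unchecked;
  bool propagated = a == Tag::Propagated || b == Tag::Propagated;
  if (unchecked && propagated) return Tag::Propagated;  // resolved
  if (unchecked) return Tag::Unchecked;                 // still at risk
  return propagated ? Tag::Propagated : Tag::None;
}

int main() {
  // tv2 = maximum(tv0, 0) computed without NaN checks, while
  // tv4 = minimum(add(tv0, tv1), 10) keeps its checks and therefore
  // still carries any NaN coming from tv0.
  Tag tv2 = Tag::Unchecked;
  Tag tv4 = Tag::Propagated;
  Tag tv5 = pointwise(tv4, tv2);   // tv5 = sub(tv4, tv2)
  assert(tv5 == Tag::Propagated);  // output propagated: the NaN check
                                   // on tv2 can be skipped
  return 0;
}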

@jacobhinkle
Collaborator

For the immediate task at hand, I believe we can remove one NaN check more easily by removing this branch: https://github.com/NVIDIA/Fuser/blob/main/runtime/helpers.cu#L106-L107. If a == a, then a > b will be false when b != b, so the else branch subsumes the b != b branch. Similarly, for fmin we can remove the a != a branch. Doing this removes one NaN check per reduction element, so it should be equivalent in speed to the code in the description of this issue, but it doesn't necessitate any codegen changes. As mentioned in my other comments, we can probably remove all the NaN checks, but I will leave that for a separate PR.
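
Concretely, the simplified helper would look something like this (a sketch of the idea; #329 has the actual change):

__device__ double fmax(double a, double b) {
  if (a != a) {
    return a;  // a is NaN: propagate it
  }
  // Here a == a, so if b is NaN then a > b is false and the ternary
  // returns b, propagating the NaN without an explicit b != b check.
  return a > b ? a : b;
}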

@naoyam
Collaborator Author

naoyam commented May 11, 2023

even if the max came out non-NaN, the resulting softmax would still be NaN because of the later sum, so propagation happens automatically.

Why? Can you please elaborate?

@jacobhinkle
Collaborator

jacobhinkle commented May 11, 2023

Sure. For log_softmax, we do something like the following:

auto a_max = max(a, {1});
auto a_sub = sub(a, broadcast(a_max, {false, true}));
auto a_exp = exp(a_sub);
auto a_lse = log(sum(a_exp, {1}));
auto log_softmax = sub(a_sub, broadcast(a_lse, {false, true}));

Imagine there is a single NaN in some row of a. Ordinarily, this would mean a_max is NaN, which would cascade down so that the whole row of a_sub is NaN, and ultimately the whole row of log_softmax is NaN.

Now imagine we don't check NaNs for max, so that a_max is not NaN in the row where a contains a NaN. Then a_sub contains just that single NaN, as does a_exp, since sub and exp propagate NaNs. But to compute the logsumexp a_lse, we sum a_exp. That sum propagates NaNs, so a_lse, and hence the whole row of log_softmax, is NaN.
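
A tiny standalone illustration of that last step (plain C++, not generated kernel code):

#include <cmath>
#include <cstdio>

int main() {
  float row[4] = {1.0f, NAN, 3.0f, 2.0f};
  // Max without NaN checks: comparisons against NaN are false, so the
  // NaN is silently skipped and m ends up 3, not NaN.
  float m = row[0];
  for (int i = 1; i < 4; ++i) {
    m = row[i] > m ? row[i] : m;
  }
  // The subtraction and exp keep the single NaN, and the sum then
  // propagates it to the whole logsumexp.
  float s = 0.0f;
  for (int i = 0; i < 4; ++i) {
    s += std::exp(row[i] - m);
  }
  std::printf("max = %f, logsumexp = %f\n", m, std::log(s) + m);  // nan
  return 0;
}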

jacobhinkle added a commit that referenced this issue May 11, 2023
Fixes #319. This implements a minimal fix, in that it does not modify
the reduction as stated in the issue, but it removes one nan-check for
each call to fmin/fmax. We may further want to disable nan-checking
completely, but that will be implemented in another PR.
@naoyam
Collaborator Author

naoyam commented May 11, 2023

Oh, I see, very interesting. I had never heard of such a data-flow analysis for NaNs.

In terms of actual benefits, the transformation I showed above may be enough and is easier to implement. I saw it achieve almost the same performance as fmaxf, likely because the root cause of the overhead is the NaN-check branch, which the transformation almost completely eliminates.

@jacobhinkle
Collaborator

Interesting. Have you tried it without any NaN checking? I am not sure how fmaxf works internally, but it seems like it would need to branch, since it guarantees to return the non-NaN operand when only one input is NaN. On the other hand, the version without NaN checks, i.e. just a > b ? a : b, might return NaN or might not; we don't care!

@naoyam
Collaborator Author

naoyam commented May 11, 2023

Sorry, I meant I compared the performance with just a > b ? a : b.

@jacobhinkle
Collaborator

With #329, each call to fmax will also have only one isnan check, though it comes before a > b instead of after. So shouldn't the #if 0 block now give roughly the same performance as the #else code above?

@naoyam
Collaborator Author

naoyam commented May 11, 2023

I actually tried that as well, and it landed about midway between the two cases, with roughly a 10% improvement.

@jacobhinkle
Collaborator

Interesting! Maybe we should re-open and I'll get to work on the original approach.

jacobhinkle reopened this May 11, 2023