add knobs control inner dim unroll and outer dim unroll in pointwise scheduler #3275

Merged
merged 27 commits into main from llu/ps_unroll_inner_outer
Nov 2, 2024

liqiangxl
Collaborator

@liqiangxl liqiangxl commented Oct 25, 2024

What's in this PR?
(1) Added two knobs to control unroll in inner dim and outer dim for pointwise scheduler
(2) The original unroll knob, which applied to the outer dim, is removed.
(3) Extended test UnrollOnTopOfVectorize to test 8 different combinations of vectorization, inner unroll, and outer unroll.
(4) Neither inner unroll nor outer unroll is used in the heuristics. Both are always 1 unless vectorization == 1, in which case inner unroll is used.
(5) If inner or outer unroll factor == 1, we won't split out an additional domain with size of 1.

Why?
These two knobs allow more performance optimizations, e.g. unrolling in different dims based on the broadcast dims.
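
Item (5) can be illustrated with a toy model. The sketch below is not nvFuser code: `Domain`, `split`, `unrollAlways`, and `unrollIfNeeded` are hypothetical names, and a domain is modeled as a plain list of extents. It only shows how guarding the split with `factor > 1` avoids creating a trailing size-1 axis.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical toy model: a tensor domain is just a list of axis extents.
using Domain = std::vector<int64_t>;

int64_t ceilDiv(int64_t a, int64_t b) {
  assert(b > 0);
  return (a + b - 1) / b;
}

// Replace axis `axis` with two axes: [ceilDiv(extent, factor), factor].
void split(Domain& d, std::size_t axis, int64_t factor) {
  assert(axis < d.size());
  const int64_t extent = d[axis];
  d[axis] = ceilDiv(extent, factor);
  d.insert(d.begin() + static_cast<std::ptrdiff_t>(axis) + 1, factor);
}

// Behavior described for "before this PR": always split, even by 1.
Domain unrollAlways(Domain d, int64_t factor) {
  split(d, 0, factor);
  return d;
}

// Behavior described for "after this PR": skip the split when factor is 1.
Domain unrollIfNeeded(Domain d, int64_t factor) {
  if (factor > 1) {
    split(d, 0, factor);
  }
  return d;
}
```

With a factor of 1, `unrollAlways` leaves an extra axis of extent 1 behind, while `unrollIfNeeded` returns the domain unchanged. For factors greater than 1 the two agree.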

@liqiangxl
Collaborator Author

!build

@liqiangxl liqiangxl force-pushed the llu/ps_unroll_inner_outer branch from ffd65d1 to 7e04577 Compare October 25, 2024 18:37
@liqiangxl
Collaborator Author

!build --diff-bench --diff

@liqiangxl
Collaborator Author

liqiangxl commented Oct 28, 2024

(1) diffs in nvfuser-ci/jit_codegen_diff_bench_17_5/5 — Failing after 43 minutes https://nv/e2E/118807639
This is due to the change in this PR: if the inner or outer unroll factor == 1, we no longer split out an additional domain of size 1.
Before this PR, we split out an additional domain even when the unroll factor == 1.

before: T4_l_float[ iblockIdx.x41{( ceilDiv(i2, blockDim.x) )}, iblockIdx.y47{( ceilDiv(( ceilDiv(( ceilDiv(i0, 1) ), 1) ), blockDim.y) )}, ithreadIdx.y48{blockDim.y}, iUS46{1}, iS44{1}, ithreadIdx.x42{blockDim.x} ]
After: T4_l_float[ iblockIdx.x35{( ceilDiv(i2, blockDim.x) )}, iblockIdx.y39{( ceilDiv(( ceilDiv(i0, 1) ), blockDim.y) )}, ithreadIdx.y40{blockDim.y}, iUS38{1}, ithreadIdx.x36{blockDim.x} ]
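
The outer extents in the two dumps above differ only by one nested `ceilDiv(…, 1)`, which is an identity. A minimal sketch, assuming `ceilDiv` is ordinary integer ceiling division (the helper names `gridYBefore` and `gridYAfter` are hypothetical), suggests that the block count along y is unchanged and only the loop structure differs:

```cpp
#include <cassert>
#include <cstdint>

int64_t ceilDiv(int64_t a, int64_t b) {
  assert(b > 0);
  return (a + b - 1) / b;
}

// Extent bound to blockIdx.y in the "before" dump:
// two factor-1 splits are nested inside the blockDim.y split.
int64_t gridYBefore(int64_t i0, int64_t block_dim_y) {
  return ceilDiv(ceilDiv(ceilDiv(i0, 1), 1), block_dim_y);
}

// Extent bound to blockIdx.y in the "after" dump:
// only one factor-1 split remains.
int64_t gridYAfter(int64_t i0, int64_t block_dim_y) {
  return ceilDiv(ceilDiv(i0, 1), block_dim_y);
}
```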

The additional domain iS44{1} no longer exists after this PR. This changes the generated code for the C++ benchmark case NvFuserScheduler_Broadcast_Inner_fp32/64/160/manual_time:

+    float T2[1];
+    T2[0]
+       = T5[0];
     float T4[1];
     T4[0] = 0;
     T4[0]
        = T0[i2];
-    float T2[1];
-    T2[0]
-       = T5[0];

(2) diffs in nvfuser-ci/jit_codegen_diff_17_5/7 — Failing after 7 minutes https://nv/e2E/118807632
Same reason as explained in (1)

(3) diffs in nvfuser-ci/jit_codegen_diff_17_6/7 — Failing after 58 minutes https://nv/e2E/118807633
Detected many test changes. This shouldn't happen, since the base kernels are generated from the current top of the main branch, [628a47e3 Pointwise shouldn't check transpose scheduler (#3256)]. Is this related to a recent change to the codediff script? @jacobhinkle

(4) diffs in nvfuser-ci/jit_codegen_diff_17_7/7 — Failing after 58 minutes https://nv/e2E/118807634
Same reason as explained in (3)

@liqiangxl
Collaborator Author

!build --diff-bench --diff

@liqiangxl liqiangxl force-pushed the llu/ps_unroll_inner_outer branch from c739569 to b23cb41 Compare October 28, 2024 00:59
@liqiangxl liqiangxl marked this pull request as ready for review October 29, 2024 19:29
@jacobhinkle
Collaborator

jacobhinkle commented Oct 29, 2024

(3) diffs in nvfuser-ci/jit_codegen_diff_17_6/7 — Failing after 58 minutes https://nv/e2E/118807633
Detected many test changes. This shouldn't happen, since the base kernels are generated from the current top of the main branch, [628a47e3 Pointwise shouldn't check transpose scheduler (#3256)]. Is this related to a recent change to the codediff script? @jacobhinkle

Yes, we have identified that it is a serde issue. @naoyam confirmed a fix in #3283; we just need to turn that into a knob we can use inside CI (see PR #3304). See also #3265. cc @rdspring1

@@ -640,7 +640,8 @@ void defineHeuristicParamBindings(py::module& nvfuser) {
       .PARAM(PointwiseParams, split_grid_y_dim)
       .PARAM(PointwiseParams, flip_grid_binding)
       .PARAM(PointwiseParams, vectorization_factor)
-      .PARAM(PointwiseParams, unroll_factor);
+      .PARAM(PointwiseParams, unroll_factor_inner)
Collaborator

👍🏼

Collaborator

@rdspring1 rdspring1 left a comment

LGTM.

You may want to change unroll_factor to unroll_factor_outer in https://github.com/NVIDIA/Fuser/blob/main/doc/dev/python_scheduling/autotune_pointwise.py#L92, so the script runs as-is?

-      reference_tv->split(0, pparams->unroll_factor);
-      // [o-remainder, Unroll| i-remainder, TIDx, Vect]
+      if (pparams->unroll_factor_inner > 1) {
+        reference_tv->split(1, pparams->unroll_factor_inner);
Collaborator

Are we splitting on dimension 1? That is the TIDx here, right?

Collaborator Author

This is for the 2D scheduler, which starts with [outer dim, inner dim], so dimension 1 here is i-remainder in [0-outer | 1-i-remainder, 2-TIDx, 3-Vect]. i-remainder means what is left in the inner dim after splitting out the other domains, e.g. Vect and TIDx.
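
The axis positions described here can be tracked with a small label-only model. This is a sketch under assumptions, not the scheduler itself: `Axes`, `split`, and `schedule2D` are hypothetical, and the vectorize and TIDx splits are assumed to always be applied so that only the position bookkeeping is shown.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical model: each axis is represented only by its label.
using Axes = std::vector<std::string>;

// Replace the axis at `pos` with two axes [outer_name, inner_name],
// mirroring how a split turns one axis into two adjacent ones.
void split(Axes& axes, std::size_t pos,
           const std::string& outer_name, const std::string& inner_name) {
  assert(pos < axes.size());
  axes[pos] = outer_name;
  axes.insert(axes.begin() + static_cast<std::ptrdiff_t>(pos) + 1, inner_name);
}

Axes schedule2D(int64_t unroll_inner, int64_t unroll_outer) {
  Axes axes = {"outer", "inner"};
  // Vectorize: [outer | i-remainder, Vect]
  split(axes, 1, "i-remainder", "Vect");
  // Threads: [outer | i-remainder, TIDx, Vect]
  split(axes, 1, "i-remainder", "TIDx");
  // Inner unroll splits position 1, which is still i-remainder, not TIDx.
  if (unroll_inner > 1) {
    split(axes, 1, "i-remainder", "i-Unroll");
  }
  // Outer unroll splits position 0, the outer dim.
  if (unroll_outer > 1) {
    split(axes, 0, "o-remainder", "o-Unroll");
  }
  return axes;
}
```

Because every inner split inserts its new axis to the right of position 1, position 1 keeps pointing at the remainder of the inner dim, which is why `split(1, …)` does not touch TIDx.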

Collaborator

So this is a behavior change then.

If we look at the above commented code change, we are doing

-      reference_tv->split(0, pparams->unroll_factor);
-      // [o-remainder, Unroll| i-remainder, TIDx, Vect]
+      if (pparams->unroll_factor_inner > 1) {
+       reference_tv->split(1, pparams->unroll_factor_inner);

Which means the old behavior (outer unroll) is being updated to a default inner unroll instead?

Collaborator Author

Good point. We should assign the unroll to the inner dim only when the scheduler is 1D; for 2D, it should go to the outer dim.

    // for 1D scheduler, unroll the inner dimension
    // since there is no outer dimension.
    if (break_point == 0) {
      params->unroll_factor_inner = total_unroll;
      params->unroll_factor_outer = 1L;
    } else {
      // for 2D scheduler, unroll the outer dimension
      // to prioritize reuse across different rows; will
      // be revised in heuristics tuning, e.g. unroll different
      // dims based on the broadcast dimension.
      params->unroll_factor_inner = 1L;
      params->unroll_factor_outer = total_unroll;
    }
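
The rule in the snippet above can be written as a standalone function for experimenting. `UnrollFactors` and `assignUnroll` are hypothetical names introduced for this sketch; only the branch on `break_point` comes from the discussion.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical container for the two knobs.
struct UnrollFactors {
  int64_t inner;
  int64_t outer;
};

// break_point == 0 is taken to mean the 1D scheduler (no outer dim);
// any other value is taken to mean the 2D scheduler.
UnrollFactors assignUnroll(int64_t break_point, int64_t total_unroll) {
  assert(break_point >= 0 && total_unroll >= 1);
  if (break_point == 0) {
    // 1D: only the inner dim exists, so all unrolling goes there.
    return {total_unroll, 1};
  }
  // 2D: unroll the outer dim.
  return {1, total_unroll};
}
```

In both branches the product of the two factors equals `total_unroll`, so the total amount of unrolling is the same as with the single knob; only the dimension it is applied to changes.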

max_vect_unroll_factor, params->vectorization_factor);
params->unroll_factor_inner = total_unroll;
Collaborator

IIUC, this PR shouldn't impose any functional changes. So I would expect all old uses of params->unroll_factor to be replaced with params->unroll_factor_inner.

Collaborator Author

Yes. Here all the unroll factors go to unroll_factor_inner through params->unroll_factor_inner = total_unroll;

if (pparams->unroll_factor_outer > 1) {
reference_tv->split(0, pparams->unroll_factor_outer);
}
// [o-remainder, o-Unroll| i-remainder, i-Unroll, TIDx, Vect]
Collaborator

I'm a bit lost about the notation here. What's o-Unroll | i-remainder?

Collaborator Author

o represents the outer dim and i represents the inner dim. | separates the outer dim from the inner dim. So here o-Unroll represents the outer unroll, and i-remainder means what is left in the inner dim after splitting out the other domains, e.g. Vect and TIDx.

Collaborator Author

I used o-Unroll and i-Unroll to distinguish between unroll in outer dim and inner dim.

Collaborator

Ah, sorry, I was totally missing the | part here. Now it reads clearly to me.

Collaborator Author

Added some comments for clarity.

      // Here and in the following comments:
      // prefix [i] represents inner dimension
      // prefix [o] represents outer dimension
      // [|] separates the outer and inner dimensions

@@ -822,7 +847,9 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
       // Threads
       reference_tv->split(0, kThreadX);
       // Unroll
-      reference_tv->split(0, pparams->unroll_factor);
+      if (pparams->unroll_factor_inner > 1) {
+        reference_tv->split(0, pparams->unroll_factor_inner);
Collaborator

qq: we are not using unroll_factor_outer in this branch, is that expected?

Collaborator Author

Yes, this else branch is for the 1D scheduler: all IDs are merged into one domain, so there is no outer dim.

Collaborator

@jjsjann123 jjsjann123 left a comment

LGTM.

Since this isn't applying any functional change, should we double check the code diff just to be sure?

@liqiangxl
Collaborator Author

!build --diff-bench --diff

@liqiangxl
Collaborator Author

!test --diff-bench --diff

@liqiangxl
Collaborator Author

Still not sure why the code changed, e.g. test_correctness_var_mean_float64 is a reduction, which shouldn't be changed by this PR. Let me close this PR and redo a new one. @jjsjann123 @jacobhinkle the code diff test in #3311 seems to work fine.

@jacobhinkle
Collaborator

Still not sure why the code changed, e.g. test_correctness_var_mean_float64 is a reduction, which shouldn't be changed by this PR. Let me close this PR and redo a new one. @jjsjann123 @jacobhinkle the code diff test in #3311 seems to work fine.

I'm seeing that in all the PRs. I think there's something going on that is flipping the order of outputs in the generated kernel. It may or may not be related to serde.

@liqiangxl liqiangxl reopened this Nov 2, 2024
@liqiangxl
Collaborator Author

liqiangxl commented Nov 2, 2024

@jjsjann123 I am going to merge this PR after the build test. There are two types of code diffs.
(1) Due to the change in this PR: if the inner or outer unroll factor == 1, we won't split out an additional domain of size 1. This removes an extra for-loop, which leads to a different compute-at position and expr order. Here is a case.
(2) As Jacob said, something is flipping the order of outputs in the generated kernel. Here is an example. I can't reproduce it locally, so it is probably related to the CI scripts. The fusion also doesn't use the pointwise scheduler.

@liqiangxl
Collaborator Author

!build !test

@liqiangxl
Collaborator Author

!tests

@liqiangxl liqiangxl merged commit c02e7ee into main Nov 2, 2024
47 checks passed
@liqiangxl liqiangxl deleted the llu/ps_unroll_inner_outer branch November 2, 2024 13:40
rdspring1 added a commit that referenced this pull request Nov 5, 2024
Fix the `autotune_pointwise` script, which was broken by #3275.
The earlier PR changed the pointwise setting from `unroll_factor` to `unroll_factor_inner`.