
group norm segmented into pointwise + persistent + pointwise #2375

Closed
liqiangxl opened this issue Jun 10, 2024 · 18 comments
Labels: allocation domain, on hold, Op Support

Comments

@liqiangxl

Group norm is calculated as:

x0 = [N, C, H, W]
x1 = x0.cast(fp32).reshape(N, C, H, W)  --> (N, G, C/G, H, W) 
x2 = normalize(x1) over (C/G, H, W)     # mean/variance normalization
x3 = x2.reshape(N, G, C/G, H, W) --> (N, C, H, W)
x4 = w*x3 + b

Due to the two reshapes, the normalization scheduler rejects the unsegmented fusion, and it is then segmented into three sub-fusions:
(1) pointwise doing cast + reshape
(2) normalization
(3) pointwise doing reshape + scale & bias
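For reference, here is a minimal eager-mode sketch of this decomposition (illustrative only; it uses mean/variance normalization, matching the Thunder trace shown later in this thread, and the variable names mirror the pseudocode above):

import torch

def group_norm_decomposed(x0, w, b, G, eps=1e-5):
    # x0: [N, C, H, W] in fp16; compute in fp32
    N, C, H, W = x0.shape
    x1 = x0.float().reshape(N, G, C // G, H, W)         # 1st reshape: split C -> (G, C/G)
    mean = x1.mean(dim=(2, 3, 4), keepdim=True)
    var = x1.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
    x2 = (x1 - mean) * torch.rsqrt(var + eps)           # normalization (persistent segment)
    x3 = x2.reshape(N, C, H, W)                          # 2nd reshape: merge (G, C/G) -> C
    x4 = w.view(1, C, 1, 1) * x3 + b.view(1, C, 1, 1)   # scale & bias (pointwise segment)
    return x4.to(x0.dtype)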

Reproducer (modified from the Apex group norm implementation and the Apex group norm test):

import torch
import thunder

def torch_group_norm(x, g, w, b, eps, act=""):
    xdtype, wdtype = x.dtype, w.dtype
    if xdtype != wdtype:
        x = x.to(dtype=wdtype)
    y = torch.nn.functional.group_norm(x, g, w, b, eps)
    if act in ["silu", "swish"]:
        y = torch.nn.functional.silu(y)
    if xdtype != wdtype and y.dtype != xdtype:
        y = y.to(dtype=xdtype)
    return y

def verify_group_norm(N=32,
                      C=128,
                      H=256,
                      W=256,
                      G=32,
                      xdtype=torch.float16,
                      wdtype=torch.float32,
                      eps=1e-5,
                      memory_format=torch.channels_last,
                      device='cuda',
                      act=""):
    # create data
    x_shape = (N, C, H, W)
    w_shape = (C,)
    weight = torch.rand(w_shape,
                        dtype=wdtype,
                        device=device,
                        requires_grad=True)
    bias = torch.rand(w_shape,
                      dtype=wdtype,
                      device=device,
                      requires_grad=True)
    x = torch.randn(x_shape, dtype=xdtype, device=device)
    x = x.to(memory_format=memory_format)
    x.requires_grad_(True)
    thunder_group_norm = thunder.jit(torch_group_norm)
    y_torch = torch_group_norm(x, G, weight, bias, eps, act)
    y_thunder = thunder_group_norm(x, G, weight, bias, eps, act)
    # compare
    torch.testing.assert_close(y_thunder, y_torch, atol=4e-2, rtol=0)

# NVFUSER_DUMP=scheduler_params,cuda_to_file,fusion_ir_preseg,python_definition python group_norm.py 2>&1 |tee 1.log
if __name__ == "__main__":
  verify_group_norm(N=2, C=128, H=16, W=16)

@naoyam

naoyam commented Jun 10, 2024

I wonder why (1) and (2) are not fused. I think an easy solution for the second reshape would be to move it to the end of the fusion and turn it into a metadata operation.

@liqiangxl

liqiangxl commented Jun 10, 2024

It's due to the reductionInterferingView check. If I disable this check and remove the last reshape in group norm (no weight & bias), nvFuser uses the Inner Persistent scheduler but fails with the error Merging IterDomains requires that their iteration types match. Outer: iS133{32}, Inner: rS17{i2}.
We can also move the 1st reshape to the beginning of the fusion, so it is just a no-op. What do you think?
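A rough sketch of that rearrangement (illustrative only; weight/bias omitted as in the experiment above): both reshapes then only touch the fusion input and output, so they can become no-op / metadata operations around a single normalization kernel.

import torch

def group_norm_reshapes_at_boundaries(x, G, eps=1e-5):
    # x: [N, C, H, W] in fp16
    N, C, H, W = x.shape
    x_g = x.reshape(N, G, C // G, H, W)        # reshape of the raw input: candidate no-op
    y = x_g.float()
    mean = y.mean(dim=(2, 3, 4), keepdim=True)
    var = y.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
    y = ((y - mean) * torch.rsqrt(var + eps)).to(x.dtype)
    return y.reshape(N, C, H, W)               # reshape right before the output: metadata op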

@naoyam

naoyam commented Jun 10, 2024

It's due to the reductionInterferingView check. If I disable this check and remove the last reshape in group norm (no weight & bias), nvFuser uses the Inner Persistent scheduler but fails with the error Merging IterDomains requires that their iteration types match. Outer: iS133{32}, Inner: rS17{i2}. We can also move the 1st reshape to the beginning of the fusion, so it is just a no-op. What do you think?

There's only one normalization in (1) and (2), so I assume x1.sum is used as the reference for the segment. Which tensor does this conflict come from?

@liqiangxl

Looks like Thunder did some optimizations on the captured graph. Specifically, it reshapes the input before the cast to fp32. After the computation, it adds another reshape before the cast back to fp16. The 2nd reshape caused the error.

def augmented_forward_fn(x):
  # x: "cuda:0 f16[2, 128, 16, 16]"
  t0 = torch.reshape(x, (2, 32, 4, 16, 16))  # t0: "cuda:0 f16[2, 32, 4, 16, 16]"
    # t0 = ltorch.reshape(x, (2, 32, 4, 16, 16))  # t0: "cuda:0 f16[2, 32, 4, 16, 16]"
      # t0 = prims.reshape(x, (2, 32, 4, 16, 16))  # t0: "cuda:0 f16[2, 32, 4, 16, 16]"
  [t16, t5, t9] = nvFusion0(t0)
    # t1 = prims.convert_element_type(t0, dtypes.float32)  # t1: "cuda:0 f32[2, 32, 4, 16, 16]"
    # (t4, t5) = prims.var_mean(t1, (2, 3, 4), correction=0)
    # t6 = prims.broadcast_in_dim(t4, [2, 32, 1, 1, 1], [0, 1])  # t6: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t7 = prims.broadcast_in_dim(t5, [2, 32, 1, 1, 1], [0, 1])  # t7: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t8 = prims.add(t6, 1e-05)  # t8: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t9 = prims.rsqrt(t8)  # t9: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t10 = prims.broadcast_in_dim(t7, (2, 32, 4, 16, 16), (0, 1, 2, 3, 4))  # t10: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t12 = prims.sub(t1, t10)  # t12: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t13 = prims.broadcast_in_dim(t9, (2, 32, 4, 16, 16), (0, 1, 2, 3, 4))  # t13: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t14 = prims.mul(t12, t13)  # t14: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t15 = prims.reshape(t14, (2, 128, 16, 16))  # t15: "cuda:0 f32[2, 128, 16, 16]"
    # t16 = prims.convert_element_type(t15, dtypes.float16)  # t16: "cuda:0 f16[2, 128, 16, 16]"
  return {'output': t16, 'flat_args': [x], 'flat_output': (t16,)}, ((t0, t5, t9), ())

@liqiangxl

liqiangxl commented Jun 13, 2024

Thanks for the suggestion @jjsjann123 and @kevinstephano. nvFuser gets both reshapes with thunder.jit(torch_group_norm, nv_enable_bookend=False).
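In the repro script above, that corresponds to changing the jit call (a sketch; everything else stays the same):

thunder_group_norm = thunder.jit(torch_group_norm, nv_enable_bookend=False)

The resulting augmented forward trace is: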

def augmented_forward_fn(x):
  # x: "cuda:0 f16[2, 128, 16, 16]"
  [t16, t5, t9] = nvFusion0(x)
    # t0 = prims.reshape(x, (2, 32, 4, 16, 16))  # t0: "cuda:0 f16[2, 32, 4, 16, 16]"
    # t1 = prims.convert_element_type(t0, dtypes.float32)  # t1: "cuda:0 f32[2, 32, 4, 16, 16]"
    # (t4, t5) = prims.var_mean(t1, (2, 3, 4), correction=0)
    # t6 = prims.broadcast_in_dim(t4, [2, 32, 1, 1, 1], [0, 1])  # t6: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t7 = prims.broadcast_in_dim(t5, [2, 32, 1, 1, 1], [0, 1])  # t7: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t8 = prims.add(t6, 1e-05)  # t8: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t9 = prims.rsqrt(t8)  # t9: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t10 = prims.broadcast_in_dim(t7, (2, 32, 4, 16, 16), (0, 1, 2, 3, 4))  # t10: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t12 = prims.sub(t1, t10)  # t12: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t13 = prims.broadcast_in_dim(t9, (2, 32, 4, 16, 16), (0, 1, 2, 3, 4))  # t13: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t14 = prims.mul(t12, t13)  # t14: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t15 = prims.reshape(t14, (2, 128, 16, 16))  # t15: "cuda:0 f32[2, 128, 16, 16]"
    # t16 = prims.convert_element_type(t15, dtypes.float16)  # t16: "cuda:0 f16[2, 128, 16, 16]"
  return {'output': t16, 'flat_args': [x], 'flat_output': (t16,)}, ((t5, t9, x), (0,))

If we change these two reshapes to a reshape of the input and a reshape right before the output, and revise MarkAliasesPreparePass, nvFuser goes from 3 kernels to two no-ops plus one normalization kernel. I added a draft PR (#2405) with three test cases:

  1. GroupNormOriginal: segment into 3 kernels
  2. GroupNormReshapeMovedToInputOutputNoWeightBias: one kernel
  3. GroupNormReshapeMovedToInputOutput: one kernel

So the proposed plan is a two-step approach:

@liqiangxl

Adding the original related issue: Lightning-AI/lightning-thunder#468

@wujingyue

Unfortunately, the current alias analysis wasn't built for tracking aliases involving intermediates. When I wrote it, I didn't see a strong case for dealing with intermediates and expected intermediates to be fused into a kernel and become cheap index calculation. However, in this case, it appears that the two reshapes have caused the normalization scheduler to bail out...

#2405 sounds like a reasonable extension -- add segment_set after input reshapes and before output reshapes. Caveat: there are some tricky patterns that we'll need to consider, e.g., if the original input is used by an operation other than the reshape, we probably shouldn't segment out the reshape. But we can leave these important implementation details to PR discussion.

That being said, is the normalization scheduler supposed to handle the two reshapes? That feels like the right solution to me. However, if it would take a long time, I would totally understand the need for a quicker workaround.

@wujingyue

Thanks for the write-up, @liqiangxl! The problem looks quite clear to me even though I was unsure about the solution.

@liqiangxl

That being said, is the normalization scheduler supposed to handle the two reshapes? That feels like the right solution to me. However, if it would take a long time, I would totally understand the need for a quicker workaround.

It should be able to handle the first reshape (which only involves splits of IDs), but the second reshape (which merges an iteration ID and a reduction ID) needs a lot of work. So another option is to extend the current reduction/normalization scheduler to handle some kinds of reshapes, so that the pre-segmentation optimization pass only needs to process specific types of reshape, e.g. a reshape that includes a merge of an iteration ID and a reduction ID.

@liqiangxl

@jjsjann123, @wujingyue, and @naoyam
Thanks for the helpful discussions.
Here is a summary of the approach:

@naoyam

naoyam commented Jun 24, 2024

Looks like Thunder did some optimizations on the captured graph. Specifically, it reshapes the input before the cast to fp32. After the computation, it adds another reshape before the cast back to fp16. The 2nd reshape caused the error.

def augmented_forward_fn(x):
  # x: "cuda:0 f16[2, 128, 16, 16]"
  t0 = torch.reshape(x, (2, 32, 4, 16, 16))  # t0: "cuda:0 f16[2, 32, 4, 16, 16]"
    # t0 = ltorch.reshape(x, (2, 32, 4, 16, 16))  # t0: "cuda:0 f16[2, 32, 4, 16, 16]"
      # t0 = prims.reshape(x, (2, 32, 4, 16, 16))  # t0: "cuda:0 f16[2, 32, 4, 16, 16]"
  [t16, t5, t9] = nvFusion0(t0)
    # t1 = prims.convert_element_type(t0, dtypes.float32)  # t1: "cuda:0 f32[2, 32, 4, 16, 16]"
    # (t4, t5) = prims.var_mean(t1, (2, 3, 4), correction=0)
    # t6 = prims.broadcast_in_dim(t4, [2, 32, 1, 1, 1], [0, 1])  # t6: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t7 = prims.broadcast_in_dim(t5, [2, 32, 1, 1, 1], [0, 1])  # t7: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t8 = prims.add(t6, 1e-05)  # t8: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t9 = prims.rsqrt(t8)  # t9: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t10 = prims.broadcast_in_dim(t7, (2, 32, 4, 16, 16), (0, 1, 2, 3, 4))  # t10: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t12 = prims.sub(t1, t10)  # t12: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t13 = prims.broadcast_in_dim(t9, (2, 32, 4, 16, 16), (0, 1, 2, 3, 4))  # t13: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t14 = prims.mul(t12, t13)  # t14: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t15 = prims.reshape(t14, (2, 128, 16, 16))  # t15: "cuda:0 f32[2, 128, 16, 16]"
    # t16 = prims.convert_element_type(t15, dtypes.float16)  # t16: "cuda:0 f16[2, 128, 16, 16]"
  return {'output': t16, 'flat_args': [x], 'flat_output': (t16,)}, ((t0, t5, t9), ())

Trying to catch up on what's going on here, but I'm still confused about why these two are not fused:

(1) pointwise doing cast + reshape
(2) normalization

Doesn't the second reshape belong to the third segment?

(3) pointwise doing reshape + scale & bias

@wujingyue

Here is a summary of the approach

LGTM. I'll defer to others whether we should fix the scheduler(s) to accept both types of reshapes. That sounds like the right fix to me even though it may take longer.

@liqiangxl

Trying to catch up on what's going on here, but I'm still confused about why these two are not fused:

(1) pointwise doing cast + reshape
(2) normalization

It was rejected by reductionInterferingView(), which divides the IDs of the reduction tv into multiple groups based on whether they are iteration or reduction IDs. It then generates disjoint sets over all the IDs and checks whether any IDs in the same disjoint-set entry belong to two different groups.

For example, consider a tv with [{i0}, {32}, {i2/32}, {i3}, {i4}], where bold represents reduction dims.

  • (1) They are grouped into 2 groups: **{i2/32}, {i3}, {i4}** and {i0}, {32}.
  • (2) **{i2/32}**, {32}, {i2} are in the same entry of the disjoint sets (why?)
  • (3) **{i2/32}** is in a different group from **{32}**.
  • (4) reductionInterferingView() returns false and the fusion is rejected.

I think we can safely skip these checks if the reshape only involves splits, since a split won't cause a merge of an iteration ID and a reduction ID, which is the root limitation of the current reduction scheduler. The fix is simple and avoids the complex checks.
Another approach is to revise the logic in the disjoint-set-based checks. It seems more complex than the current approach in #2437, so that option is not explored.
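To make the grouping above concrete, here is a tiny stand-alone model of the check in plain Python (illustrative only, not nvFuser's actual implementation):

# IDs of the reference reduction tv: [{i0}, {32}, {i2/32}, {i3}, {i4}]
reduction_ids = {"i2/32", "i3", "i4"}   # reduction dims
iteration_ids = {"i0", "32"}            # iteration dims

# Per item (2) above, i2, 32, and i2/32 end up in the same disjoint-set entry
# (presumably via the split i2 -> (32, i2/32) introduced by the reshape).
disjoint_sets = [{"i2", "32", "i2/32"}, {"i0"}, {"i3"}, {"i4"}]

def interferes(disjoint_sets, reduction_ids, iteration_ids):
    # The check fails when one disjoint-set entry contains IDs from both groups.
    return any(s & reduction_ids and s & iteration_ids for s in disjoint_sets)

print(interferes(disjoint_sets, reduction_ids, iteration_ids))  # True -> fusion rejected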

Doesn't the second reshape belong to the third segment?

(3) pointwise doing reshape + scale & bias

Yes.

@wujingyue

thunder will ensure reshape is moved to the front and end of the fusion

I missed this part. I was expecting nvFuser, not Thunder, to do this. It sounds like a specific code motion to work around an nvFuser limitation. Doing this in nvFuser makes nvFuser standalone, not relying on a specific code pattern upstream.

@liqiangxl

thunder will ensure reshape is moved to the front and end of the fusion

I missed this part. I was expecting nvFuser, not Thunder, to do this. It sounds like a specific code motion to work around an nvFuser limitation. Doing this in nvFuser makes nvFuser standalone, not relying on a specific code pattern upstream.

Thanks for the quick feedback. I'll check with other stakeholders and may schedule a short sync meeting.

@naoyam

naoyam commented Jun 24, 2024

Trying to catch up on what's going on here, but I'm still confused about why these two are not fused:
(1) pointwise doing cast + reshape
(2) normalization

It was rejected by reductionInterferingView(), which divides the IDs of the reduction tv into multiple groups based on whether they are iteration or reduction IDs. It then generates disjoint sets over all the IDs and checks whether any IDs in the same disjoint-set entry belong to two different groups.

For example, consider a tv with [{i0}, {32}, {i2/32}, {i3}, {i4}], where bold represents reduction dims.

  • (1) They are grouped into 2 groups: **{i2/32}, {i3}, {i4}** and {i0}, {32}.
  • (2) **{i2/32}**, {32}, {i2} are in the same entry of the disjoint sets (why?)
  • (3) **{i2/32}** is in a different group from **{32}**.
  • (4) reductionInterferingView() returns false and the fusion is rejected.

I think we can safely skip these checks if the reshape only involves splits, since a split won't cause a merge of an iteration ID and a reduction ID, which is the root limitation of the current reduction scheduler. The fix is simple and avoids the complex checks. Another approach is to revise the logic in the disjoint-set-based checks. It seems more complex than the current approach in #2437, so that option is not explored.

Doesn't the second reshape belong to the third segment?
(3) pointwise doing reshape + scale & bias

Yes.

I see. I was only thinking about requiresForwardViewReplay, but it now makes sense.

Have you thought about using IdModel to rewrite reductionInterferingView? I think it'd be relatively straightforward with an ID graph. All we need to see is if any use of the ID groups of the reference tensor could merge multiple groups, which would be a simple graph traversal. It should naturally handle Merge as well.

@jjsjann123

thunder will ensure reshape is moved to the front and end of the fusion

I missed this part. I was expecting nvFuser, not Thunder, to do this. It sounds like a specific code motion to work around an nvFuser limitation. Doing this in nvFuser makes nvFuser standalone, not relying on a specific code pattern upstream.

Thanks for the quick feedback. I'll check with other stakeholders and may schedule a short sync meeting.

Agreed that the proper approach is for nvFuser to handle graph-level optimization and reorder some trivial reshapes to allow better fusion. We can discuss the priority of that.

Meanwhile, the last reshape in thunder might be an easier thing to change with some slightly ugly code: https://github.com/Lightning-AI/lightning-thunder/blob/14e6c9b67eb038ab28a192cd381bc183b77e8f81/thunder/torch/__init__.py#L4406-L4420

liqiangxl added a commit that referenced this issue Jul 19, 2024
… segmentation issue #2375 (#2405)

This PR is step-2 to solve #2375 
Allow an output to alias an intermediate when:
(1) called from the pre-segmentation pass
(2) the `op` that produces `output` interferes with reduction, see
`outputInterferingReduction`

After this PR, the last reshape (which produces the output tensor) in
group norm is changed to a no-op, see the newly added test
`OutputAliasIntermediate`.

Util function `getRepresentativeReductionTv` is added for convenience
and will be used to simplify reduction/normalization schedulers.

---------

Co-authored-by: Jingyue Wu <[email protected]>
liqiangxl added a commit that referenced this issue Jul 20, 2024
…ops has ID merges (#2583)

**Background**
This PR is separated from #2437 by removing all commits using IdModels.
I separated it out so we can merge it into the main branch and move to
step 2 of the group norm issue.
The work to rewrite `reductionInterferingView` using IdModels will
continue in #2437.

**---- PR description---**
**Issue:** #2375
**Fix:** When the reshape transformations from root domains to logical
domains consist only of splits, it is safe for the reduction scheduler to
merge any two reduction domains. We can skip the reductionInterferingView
check.

**Results:**
With this PR, the reshape of input ([N,C,H,W] -> [N, G, C/G, H, W]) can
be fused with normalization into one kernel.
Group norm splits into 2 kernels instead of 3.

Follow-up work:
(1) It also seems safe if the reshape transformations from root domains
to logical domains contain a merge, but all of the merged IDs are
mapped to reduction dims or to iteration dims. Continued in #2437
wujingyue added a commit that referenced this issue Jul 21, 2024
wujingyue added a commit that referenced this issue Jul 21, 2024
wujingyue added a commit that referenced this issue Jul 26, 2024
wujingyue added a commit that referenced this issue Jul 27, 2024
@liqiangxl added the allocation domain and enhancement labels on Oct 30, 2024
@liqiangxl

liqiangxl commented Oct 30, 2024

The remaining issue is the channels-last format: the current reduction/normalization scheduler doesn't consider allocation domains.
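With the repro script at the top of this issue, the channels-last path is already the default; the contiguous path can be run for comparison with the same helper:

verify_group_norm(N=2, C=128, H=16, W=16, memory_format=torch.channels_last)      # channels-last (default above)
verify_group_norm(N=2, C=128, H=16, W=16, memory_format=torch.contiguous_format)  # contiguous, for comparison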

@kevinstephano added the Triage and Op Support labels and removed the enhancement label on Oct 30, 2024
@liqiangxl added the on hold label and removed the Triage label on Nov 5, 2024
@liqiangxl closed this as not planned on Nov 5, 2024