Bad Performance with Segmentation and Pad Operation? #860

Closed

kevinstephano opened this issue Sep 10, 2023 · 1 comment

kevinstephano (Collaborator) commented Sep 10, 2023:

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id8(fd: FusionDefinition) -> None:
    # Three [16, 25, 512, 64] bfloat16 inputs, contiguous only in the innermost dim.
    T9 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[False, False, False, True], dtype=DataType.BFloat16, is_cpu=False)
    T10 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[False, False, False, True], dtype=DataType.BFloat16, is_cpu=False)
    T11 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[False, False, False, True], dtype=DataType.BFloat16, is_cpu=False)
    # Transpose each input to [16, 512, 25, 64] and flatten the last two
    # dims, giving three [16, 512, 1600] operands.
    T12 = fd.ops.permute(T11, dims=[0, 2, 1, 3])
    T13 = fd.ops.reshape(T12, original_shape=[16, 512, 25, 64], new_shape=[16, 512, 1600])
    T14 = fd.ops.permute(T9, dims=[0, 2, 1, 3])
    T15 = fd.ops.reshape(T14, original_shape=[16, 512, 25, 64], new_shape=[16, 512, 1600])
    T16 = fd.ops.permute(T10, dims=[0, 2, 1, 3])
    T17 = fd.ops.reshape(T16, original_shape=[16, 512, 25, 64], new_shape=[16, 512, 1600])
    # Zero-pad each operand along the last dim into a disjoint 1600-wide slot
    # of a 4800-wide result, then add them: together this is a concatenation.
    S18 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T19 = fd.ops.pad(T13, [3200, 0, 0, 0, 0, 0], S18)
    S20 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T21 = fd.ops.pad(T17, [1600, 1600, 0, 0, 0, 0], S20)
    T22 = fd.ops.cast(T19, dtype=DataType.Float)
    T23 = fd.ops.cast(T21, dtype=DataType.Float)
    T24 = fd.ops.add(T22, T23)
    S25 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T26 = fd.ops.pad(T15, [0, 3200, 0, 0, 0, 0], S25)
    T27 = fd.ops.cast(T26, dtype=DataType.Float)
    T28 = fd.ops.add(T24, T27)
    T29 = fd.ops.cast(T28, dtype=DataType.BFloat16)
    # Also reduce the float concatenation over dims 0 and 1 for a second output.
    T30 = fd.ops.sum(T28, axes=[0, 1], keepdim=False, dtype=DataType.Null)
    T31 = fd.ops.cast(T30, dtype=DataType.BFloat16)
    fd.add_output(T31)
    fd.add_output(T29)

inputs = [
        torch.randn(16, 25, 512, 64, device='cuda', dtype=torch.bfloat16),
        torch.randn(16, 25, 512, 64, device='cuda', dtype=torch.bfloat16),
        torch.randn(16, 25, 512, 64, device='cuda', dtype=torch.bfloat16),
        ]

with FusionDefinition() as fd:
    nvfuser_fusion_id8(fd)

# Execute a few iterations so profiling captures steady-state kernel times.
for _ in range(5):
    out = fd.execute(inputs)
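
Since the three pads place their operands in disjoint 1600-wide slices of the last dimension, the second output is exactly a concatenation. A minimal sanity check against torch.cat (not part of the original repro; the output indexing assumes fd.execute returns tensors in fd.add_output order):

# Hypothetical check: the pad+add above is equivalent to a cat along dim 2.
ref = torch.cat(
    [t.transpose(1, 2).reshape(16, 512, 1600) for t in inputs], dim=2)
outputs = fd.execute(inputs)
# outputs[0] is T31 (the reduction); outputs[1] is T29 (the concatenation).
torch.testing.assert_close(outputs[1], ref)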

This definition is segmented into 4 kernels. The nsys GPU trace below shows the four CudaCodeGen kernels plus an auxiliary PyTorch fill kernel (durations in ns):

 2189792239          32160    1211     2  8192     1   128     1     1       16         0.000         0.000                                                     NVIDIA A100 80GB PCIe (0)    1     7  CudaCodeGen::kernel1(CudaCodeGen::Tensor<CudaCodeGen::__bfloat, (int)4, (int)4>, CudaCodeGen::Tenso…
 2189825519          34784    1222     2  8192     1   128     1     1       16         0.000         0.000                                                     NVIDIA A100 80GB PCIe (0)    1     7  CudaCodeGen::kernel2(CudaCodeGen::Tensor<CudaCodeGen::__bfloat, (int)4, (int)4>, CudaCodeGen::Tenso…
 2189861615          34048    1233     2  8192     1   128     1     1       16         0.000         0.000                                                     NVIDIA A100 80GB PCIe (0)    1     7  CudaCodeGen::kernel3(CudaCodeGen::Tensor<CudaCodeGen::__bfloat, (int)4, (int)4>, CudaCodeGen::Tenso…
 2189896815           1216    1262     1     1     1   128     1     1       16         0.000         0.000                                                     NVIDIA A100 80GB PCIe (0)    1     7  void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<long>, at::detail::A…
 2189899279         288962    1268    38    22     1    64     8     1       64         0.000         0.002                                                     NVIDIA A100 80GB PCIe (0)    1     7  CudaCodeGen::kernel4(CudaCodeGen::Tensor<CudaCodeGen::__bfloat, (int)3, (int)3>, CudaCodeGen::Tenso…

PyTorch eager mode, for comparison:

import torch

def func(t1, t2, t3):
    # Same computation as the fusion's concatenation output: transpose,
    # reshape, then cat along the last dim (cat replaces the pad+add trick).
    t1_t = t1.transpose(1, 2)
    t2_t = t2.transpose(1, 2)
    t3_t = t3.transpose(1, 2)

    shape = t1_t.shape

    t1_r = t1_t.reshape(shape[0], shape[1], shape[2] * shape[3])
    t2_r = t2_t.reshape(shape[0], shape[1], shape[2] * shape[3])
    t3_r = t3_t.reshape(shape[0], shape[1], shape[2] * shape[3])

    return torch.cat((t1_r, t2_r, t3_r), 2)

inputs = [
        torch.randn(16, 25, 512, 64, device='cuda', dtype=torch.bfloat16),
        torch.randn(16, 25, 512, 64, device='cuda', dtype=torch.bfloat16),
        torch.randn(16, 25, 512, 64, device='cuda', dtype=torch.bfloat16),
        ]

for _ in range(5):
    out = func(*inputs)

The eager trace shows three elementwise copy kernels (the transpose+reshape materializations) plus one CatArrayBatchedCopy concatenation kernel:
1324955633          67392     944  25600     1     1   128     1     1       16         0.000         0.000                                                     NVIDIA A100 80GB PCIe (0)    1     7  void at::native::elementwise_kernel<(int)128, (int)4, void at::native::gpu_kernel_impl<at::native::…
1325024337          67297     958  25600     1     1   128     1     1       16         0.000         0.000                                                     NVIDIA A100 80GB PCIe (0)    1     7  void at::native::elementwise_kernel<(int)128, (int)4, void at::native::gpu_kernel_impl<at::native::…
1325092754          67104     972  25600     1     1   128     1     1       16         0.000         0.000                                                     NVIDIA A100 80GB PCIe (0)    1     7  void at::native::elementwise_kernel<(int)128, (int)4, void at::native::gpu_kernel_impl<at::native::…
1325161202         176609     984    216     3     1   512     1     1       26         0.000         0.000                                                     NVIDIA A100 80GB PCIe (0)    1     7  void at::native::<unnamed>::CatArrayBatchedCopy<c10::BFloat16, unsigned int, (int)3, (int)128, (int…

nvFuser: 390 µs vs. PyTorch eager: 378 µs (summing the kernel durations in each trace above).
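
For a wall-clock comparison outside the profiler, a rough sketch using standard CUDA events (warmup and iteration counts here are assumptions, and launch overhead will add to the GPU-trace sums):

import torch

def time_fn(fn, iters=100):
    # Warm up so compilation and one-time initialization are excluded.
    for _ in range(5):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters * 1000  # microseconds per call

# e.g. time_fn(lambda: fd.execute(inputs)) vs. time_fn(lambda: func(*inputs))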

naoyam self-assigned this Sep 11, 2023

kevinstephano (Collaborator, Author) commented:

Closing. Requires vectorization.
