From 07a7f332661aa7c4b933c74ffcf80bb8c5c125a1 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 28 Oct 2024 22:21:12 -0700 Subject: [PATCH] compile fwd fn in bwd benchmarks --- benchmarks/python/normalization.py | 5 ++++- .../python/test_dropout_layernorm_bwd.py | 22 ++++++++++++------- benchmarks/python/test_dropout_rmsnorm_bwd.py | 11 +++++++--- benchmarks/python/test_gelu_bwd.py | 9 ++++++-- .../python/test_huggingface_attn_bwd.py | 16 +++++++++----- benchmarks/python/test_layernorm_bwd.py | 19 +++++++++------- benchmarks/python/test_nanogpt_attn_bwd.py | 21 +++++++++++------- benchmarks/python/test_rmsnorm_bwd.py | 14 ++++++++---- benchmarks/python/test_scale_bias_relu_bwd.py | 11 +++++++--- benchmarks/python/test_silu_mul_bwd.py | 11 +++++++--- benchmarks/python/test_softmax_bwd.py | 18 +++++++-------- 11 files changed, 103 insertions(+), 54 deletions(-) diff --git a/benchmarks/python/normalization.py b/benchmarks/python/normalization.py index 8ec19ebf71d..8d1648529ed 100644 --- a/benchmarks/python/normalization.py +++ b/benchmarks/python/normalization.py @@ -489,12 +489,15 @@ def norm_bwd_baseline_benchmark( grads = grads.to(memory_format=torch.channels_last) norm_fwd_fn = batchnorm_fwd_fn if norm == "batch_norm" else instancenorm_fwd_fn + + # Compile the fwd fn for torchcompile + norm_fwd_fn = torch.compile(norm_fwd_fn) if compile else norm_fwd_fn output = norm_fwd_fn([inputs, weight, bias, running_mean, running_var]) # Manually compute IOBytes: See PR #1725 run_benchmark( benchmark, - torch.compile(unary_bwd_torch) if compile else unary_bwd_torch, + unary_bwd_torch, [output, grads], iobytes=norm_bwd_iobytes(size, dtype, norm), ) diff --git a/benchmarks/python/test_dropout_layernorm_bwd.py b/benchmarks/python/test_dropout_layernorm_bwd.py index dcff2abb5ba..36a9b1a2c3a 100644 --- a/benchmarks/python/test_dropout_layernorm_bwd.py +++ b/benchmarks/python/test_dropout_layernorm_bwd.py @@ -207,17 +207,23 @@ def test_dropout_layernorm_bwd_baseline_benchmark( grads = torch.randn(size, device="cuda", dtype=dtype) weights = torch.randn(size[1], device="cuda", dtype=dtype, requires_grad=True) bias = torch.randn(size[1], device="cuda", dtype=dtype, requires_grad=True) - - output = torch.nn.functional.layer_norm( - input2 + torch.nn.functional.dropout(input1, p=dropout_p), - normalized_shape=input1.shape[1:], - weight=weights, - bias=bias, - ) + + def dropout_layernorm_fwd(): + return torch.nn.functional.layer_norm( + input2 + torch.nn.functional.dropout(input1, p=dropout_p), + normalized_shape=input1.shape[1:], + weight=weights, + bias=bias, + ) + + # Compile the fwd fn for torchcompile + fwd_fn = torch.compile(dropout_layernorm_fwd) if compile else dropout_layernorm_fwd + output = fwd_fn() + # Manually compute IOBytes: See PR #1725 run_benchmark( benchmark, - torch.compile(unary_bwd_torch) if compile else unary_bwd_torch, + unary_bwd_torch, [output, grads], iobytes=dropout_layernorm_bwd_iobytes(size, dtype), ) diff --git a/benchmarks/python/test_dropout_rmsnorm_bwd.py b/benchmarks/python/test_dropout_rmsnorm_bwd.py index d103c17dcfa..370d8dfc807 100644 --- a/benchmarks/python/test_dropout_rmsnorm_bwd.py +++ b/benchmarks/python/test_dropout_rmsnorm_bwd.py @@ -186,12 +186,17 @@ def test_dropout_rmsnorm_bwd_baseline_benchmark( grads = torch.randn(size, device="cuda", dtype=dtype) weights = torch.randn(size[1], device="cuda", dtype=dtype) - x = input2 + torch.nn.functional.dropout(input1, p=dropout_p) - output = weights * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5) + def dropout_rmsnorm_fwd(): + x = input2 + torch.nn.functional.dropout(input1, p=dropout_p) + output = weights * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5) + return output + + fwd_fn = torch.compile(dropout_rmsnorm_fwd) else dropout_rmsnorm_fwd + output = fwd_fn() run_benchmark( benchmark, - torch.compile(unary_bwd_torch) if compile else unary_bwd_torch, + unary_bwd_torch, [output, grads], iobytes=dropout_rmsnorm_bwd_iobytes(size, dtype), ) diff --git a/benchmarks/python/test_gelu_bwd.py b/benchmarks/python/test_gelu_bwd.py index a876f00d748..0524a06155e 100644 --- a/benchmarks/python/test_gelu_bwd.py +++ b/benchmarks/python/test_gelu_bwd.py @@ -102,10 +102,15 @@ def test_gelu_bwd_baseline_benchmark( inputs = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True) bias = torch.ones(size[-1], device="cuda", dtype=dtype) grads = torch.randn(size, device="cuda", dtype=dtype) - eager_output = torch.nn.functional.gelu(inputs + bias, approximate="tanh") + + def gelu_fwd(): + return torch.nn.functional.gelu(inputs + bias, approximate="tanh") + fwd_fn = torch.compile(gelu_fwd) if compile else gelu_fwd + eager_output = fwd_fn() + run_benchmark( benchmark, - torch.compile(unary_bwd_torch) if compile else unary_bwd_torch, + unary_bwd_torch, [eager_output, grads], iobytes=gelu_bwd_iobytes(size, dtype), ) diff --git a/benchmarks/python/test_huggingface_attn_bwd.py b/benchmarks/python/test_huggingface_attn_bwd.py index b94c6a471c3..79154d4cdd0 100644 --- a/benchmarks/python/test_huggingface_attn_bwd.py +++ b/benchmarks/python/test_huggingface_attn_bwd.py @@ -126,16 +126,22 @@ def test_huggingface_attn_bwd_baseline_benchmark( attention_mask = torch.zeros( batch_size, nh, seq_len, seq_len, device="cuda", dtype=dtype ) - attn = (inputs + attention_mask).view(batch_size * nh, seq_len, seq_len) - attn = torch.nn.functional.softmax(attn, dim=-1) - output = torch.nn.functional.dropout(attn, p=dropout_p) - + + def huggingface_attn_fwd(): + attn = (inputs + attention_mask).view(batch_size * nh, seq_len, seq_len) + attn = torch.nn.functional.softmax(attn, dim=-1) + output = torch.nn.functional.dropout(attn, p=dropout_p) + return output + + # Compile the fwd fn for torchcompile + fwd_fn = torch.compile(huggingface_attn_fwd) if compile else huggingface_attn_fwd + output = fwd_fn() grads = torch.randn(batch_size * nh, seq_len, seq_len, device="cuda", dtype=dtype) # Manually compute IOBytes: See PR #1725 run_benchmark( benchmark, - torch.compile(unary_bwd_torch) if compile else unary_bwd_torch, + unary_bwd_torch, [output, grads], iobytes=huggingface_attn_bwd_iobytes(size, dtype), ) diff --git a/benchmarks/python/test_layernorm_bwd.py b/benchmarks/python/test_layernorm_bwd.py index 154dc74d8b8..9d95e0eaeba 100644 --- a/benchmarks/python/test_layernorm_bwd.py +++ b/benchmarks/python/test_layernorm_bwd.py @@ -163,17 +163,20 @@ def test_layernorm_bwd_baseline_benchmark( weights = torch.randn(size[1], device="cuda", dtype=dtype, requires_grad=True) bias = torch.randn(size[1], device="cuda", dtype=dtype, requires_grad=True) - output = torch.nn.functional.layer_norm( - inputs, - inputs.shape[1:], - weight=weights, - bias=bias, - ) - + def layernorm_fwd(): + return torch.nn.functional.layer_norm( + inputs, + inputs.shape[1:], + weight=weights, + bias=bias, + ) + fwd_fn = torch.compile(layernorm_fwd) if compile else layernorm_fwd + output = fwd_fn() + # Manually compute IOBytes: See PR #1725 run_benchmark( benchmark, - torch.compile(unary_bwd_torch) if compile else unary_bwd_torch, + unary_bwd_torch, [output, grads], iobytes=layernorm_bwd_iobytes(size, dtype), ) diff --git a/benchmarks/python/test_nanogpt_attn_bwd.py b/benchmarks/python/test_nanogpt_attn_bwd.py index 136429d475e..2efb8e7d58d 100644 --- a/benchmarks/python/test_nanogpt_attn_bwd.py +++ b/benchmarks/python/test_nanogpt_attn_bwd.py @@ -144,19 +144,24 @@ def test_nanogpt_attn_bwd_baseline_benchmark( 1, 1, seq_len, seq_len ) - # Compute output - hs = n_embd // nh - attn = inputs / (hs**0.5) - attn = attn.masked_fill(bias[:, :, :seq_len, :seq_len] == 0, float("-inf")) - attn = torch.nn.functional.softmax(attn, dim=-1) - output = torch.nn.functional.dropout(attn, p=dropout_p) - + def nanogpt_attn_fwd(): + # Compute output + hs = n_embd // nh + attn = inputs / (hs**0.5) + attn = attn.masked_fill(bias[:, :, :seq_len, :seq_len] == 0, float("-inf")) + attn = torch.nn.functional.softmax(attn, dim=-1) + output = torch.nn.functional.dropout(attn, p=dropout_p) + return output + + # Compile the fwd fn for torchcompile + fwd_fn = torch.compile(nanogpt_attn_fwd) if compile else nanogpt_attn_fwd + output = fwd_fn() grads = torch.randn(batch_size, nh, seq_len, seq_len, device="cuda", dtype=dtype) # Manually compute IOBytes: See PR #1725 run_benchmark( benchmark, - torch.compile(unary_bwd_torch) if compile else unary_bwd_torch, + unary_bwd_torch, [output, grads], iobytes=nanogpt_attn_bwd_iobytes(size, dtype), ) diff --git a/benchmarks/python/test_rmsnorm_bwd.py b/benchmarks/python/test_rmsnorm_bwd.py index 3076dd826bb..5216c5223fd 100644 --- a/benchmarks/python/test_rmsnorm_bwd.py +++ b/benchmarks/python/test_rmsnorm_bwd.py @@ -127,14 +127,20 @@ def test_rmsnorm_bwd_baseline_benchmark( grads = torch.randn(size, device="cuda", dtype=dtype) weights = torch.randn(size[1], device="cuda", dtype=dtype, requires_grad=True) - squared_mean = (inputs**2).mean(1, keepdim=True) - rms_eps = torch.sqrt(squared_mean + 1e-5) - output = weights * (inputs / rms_eps) + def rmsnorm_fwd(): + squared_mean = (inputs**2).mean(1, keepdim=True) + rms_eps = torch.sqrt(squared_mean + 1e-5) + output = weights * (inputs / rms_eps) + return output + + # Compile the fwd fn for torchcompile + fwd_fn = torch.compile(rmsnorm_fwd) if compile else rmsnorm_fwd + output = fwd_fn() # Manually compute IOBytes: See PR #1725 run_benchmark( benchmark, - torch.compile(unary_bwd_torch) if compile else unary_bwd_torch, + unary_bwd_torch, [output, grads], iobytes=rmsnorm_bwd_iobytes(size, dtype), ) diff --git a/benchmarks/python/test_scale_bias_relu_bwd.py b/benchmarks/python/test_scale_bias_relu_bwd.py index f2c75ef3971..425ffca3ddf 100644 --- a/benchmarks/python/test_scale_bias_relu_bwd.py +++ b/benchmarks/python/test_scale_bias_relu_bwd.py @@ -94,11 +94,16 @@ def test_sbr_bwd_baseline_benchmark( grads = torch.randn(*size, device="cuda", dtype=dtype) scale = torch.ones(size[-1], device="cuda", dtype=dtype) bias = torch.ones(size[-1], device="cuda", dtype=dtype) - eager_output = torch.nn.functional.relu(inputs * scale + bias) - + + def sbr_fwd(): + return torch.nn.functional.relu(inputs * scale + bias) + # Compile the fwd fn for torchcompile + fwd_fn = torch.compile(sbr_fwd) if compile else sbr_fwd + eager_output = sbr_fwd() + run_benchmark( benchmark, - torch.compile(unary_bwd_torch) if compile else unary_bwd_torch, + unary_bwd_torch, [eager_output, grads], iobytes=sbr_bwd_iobytes(size, dtype), ) diff --git a/benchmarks/python/test_silu_mul_bwd.py b/benchmarks/python/test_silu_mul_bwd.py index 17fc57587cd..5faa4379d0d 100644 --- a/benchmarks/python/test_silu_mul_bwd.py +++ b/benchmarks/python/test_silu_mul_bwd.py @@ -93,11 +93,16 @@ def test_silu_mul_bwd_baseline_benchmark( x = torch.randn(*size, device="cuda", dtype=dtype, requires_grad=True) y = torch.randn(*size, device="cuda", dtype=dtype, requires_grad=True) grads = torch.randn(*size, device="cuda", dtype=dtype) - eager_output = torch.nn.functional.silu(x) * y - + + def silu_mul_fwd(): + return torch.nn.functional.silu(x) * y + # Compile the fwd fn for torchcompile + fwd_fn = torch.compile(silu_mul_fwd) if compile else silu_mul_fwd + eager_output = fwd_fn() + run_benchmark( benchmark, - torch.compile(unary_bwd_torch) if compile else unary_bwd_torch, + unary_bwd_torch, [eager_output, grads], iobytes=silu_mul_bwd_iobytes(size, dtype), ) diff --git a/benchmarks/python/test_softmax_bwd.py b/benchmarks/python/test_softmax_bwd.py index 085268a405d..fd015bcb4cb 100644 --- a/benchmarks/python/test_softmax_bwd.py +++ b/benchmarks/python/test_softmax_bwd.py @@ -4,7 +4,7 @@ import pytest from nvfuser import FusionDefinition, DataType from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype -from .core import run_benchmark, clear_dynamo_cache +from .core import run_benchmark, clear_dynamo_cache, unary_bwd_torch import torch from .global_params import generate_input_sizes, FLOAT_DTYPES import numpy as np @@ -58,11 +58,6 @@ def softmax_bwd_fusion( fd.add_output(T19) -def unary_bwd_torch(inputs: list): # [in_tensor, output, grads] - inputs[1].backward(inputs[2], retain_graph=True) - return inputs[0].grad - - def softmax_bwd_iobytes(size: tuple, dtype: torch.dtype): # Total IO bytes = output + grad_out + grad_input return int(np.prod(size) * dtype.itemsize * 3) @@ -111,10 +106,15 @@ def test_softmax_bwd_baseline_benchmark( clear_dynamo_cache() input = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True) grads = torch.randn(size, device="cuda", dtype=dtype) - output = torch.nn.functional.softmax(input, dim=reduction_axis) + + def softmax_fwd(): + return torch.nn.functional.softmax(input, dim=reduction_axis) + fwd_fn = torch.compile(softmax_fwd) if compile else softmax_fwd + output = fwd_fn() + run_benchmark( benchmark, - torch.compile(unary_bwd_torch) if compile else unary_bwd_torch, - [input, output, grads], + unary_bwd_torch, + [output, grads], iobytes=softmax_bwd_iobytes(size, dtype), )