compile fwd fn in bwd benchmarks
Priya2698 committed Oct 29, 2024 · 1 parent e33316d · commit 07a7f33
Showing 11 changed files with 103 additions and 54 deletions.
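
Every file in this commit follows the same shape: the eager forward computation is wrapped in a small closure, that closure (rather than the backward wrapper) is passed through torch.compile when the compile executor is selected, and run_benchmark now always receives the plain unary_bwd_torch. Below is a minimal consolidated sketch of the new form, modeled on the gelu test further down; the import paths of the suite helpers and the simplified fixture signature are illustrative assumptions, not taken from the diff.

import torch
from .core import run_benchmark, unary_bwd_torch  # suite helpers (import path assumed)
from .test_gelu_bwd import gelu_bwd_iobytes  # hypothetical import location for the IOBytes helper

def test_gelu_bwd_baseline_benchmark(benchmark, size, dtype, compile):
    # Simplified signature; the real tests are pytest-parametrized fixtures.
    inputs = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True)
    bias = torch.ones(size[-1], device="cuda", dtype=dtype)
    grads = torch.randn(size, device="cuda", dtype=dtype)

    def gelu_fwd():
        return torch.nn.functional.gelu(inputs + bias, approximate="tanh")

    # Compile the fwd fn for torchcompile; run it once to build the autograd graph.
    fwd_fn = torch.compile(gelu_fwd) if compile else gelu_fwd
    eager_output = fwd_fn()

    # The backward wrapper is now always passed uncompiled
    # (previously: torch.compile(unary_bwd_torch) if compile else unary_bwd_torch).
    run_benchmark(
        benchmark,
        unary_bwd_torch,
        [eager_output, grads],
        iobytes=gelu_bwd_iobytes(size, dtype),
    )

The likely intent is that, with the forward closure compiled, calling output.backward(grads) inside unary_bwd_torch exercises the backward generated by torch.compile, so the timed region no longer depends on compiling the trivial backward wrapper itself.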
5 changes: 4 additions & 1 deletion benchmarks/python/normalization.py
@@ -489,12 +489,15 @@ def norm_bwd_baseline_benchmark(
grads = grads.to(memory_format=torch.channels_last)

norm_fwd_fn = batchnorm_fwd_fn if norm == "batch_norm" else instancenorm_fwd_fn

# Compile the fwd fn for torchcompile
norm_fwd_fn = torch.compile(norm_fwd_fn) if compile else norm_fwd_fn
output = norm_fwd_fn([inputs, weight, bias, running_mean, running_var])

# Manually compute IOBytes: See PR #1725
run_benchmark(
benchmark,
torch.compile(unary_bwd_torch) if compile else unary_bwd_torch,
unary_bwd_torch,
[output, grads],
iobytes=norm_bwd_iobytes(size, dtype, norm),
)
22 changes: 14 additions & 8 deletions benchmarks/python/test_dropout_layernorm_bwd.py
@@ -207,17 +207,23 @@ def test_dropout_layernorm_bwd_baseline_benchmark(
grads = torch.randn(size, device="cuda", dtype=dtype)
weights = torch.randn(size[1], device="cuda", dtype=dtype, requires_grad=True)
bias = torch.randn(size[1], device="cuda", dtype=dtype, requires_grad=True)

output = torch.nn.functional.layer_norm(
input2 + torch.nn.functional.dropout(input1, p=dropout_p),
normalized_shape=input1.shape[1:],
weight=weights,
bias=bias,
)

def dropout_layernorm_fwd():
return torch.nn.functional.layer_norm(
input2 + torch.nn.functional.dropout(input1, p=dropout_p),
normalized_shape=input1.shape[1:],
weight=weights,
bias=bias,
)

# Compile the fwd fn for torchcompile
fwd_fn = torch.compile(dropout_layernorm_fwd) if compile else dropout_layernorm_fwd
output = fwd_fn()

# Manually compute IOBytes: See PR #1725
run_benchmark(
benchmark,
torch.compile(unary_bwd_torch) if compile else unary_bwd_torch,
unary_bwd_torch,
[output, grads],
iobytes=dropout_layernorm_bwd_iobytes(size, dtype),
)
11 changes: 8 additions & 3 deletions benchmarks/python/test_dropout_rmsnorm_bwd.py
@@ -186,12 +186,17 @@ def test_dropout_rmsnorm_bwd_baseline_benchmark(
grads = torch.randn(size, device="cuda", dtype=dtype)
weights = torch.randn(size[1], device="cuda", dtype=dtype)

x = input2 + torch.nn.functional.dropout(input1, p=dropout_p)
output = weights * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
def dropout_rmsnorm_fwd():
x = input2 + torch.nn.functional.dropout(input1, p=dropout_p)
output = weights * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
return output

fwd_fn = torch.compile(dropout_rmsnorm_fwd) if compile else dropout_rmsnorm_fwd
output = fwd_fn()

run_benchmark(
benchmark,
torch.compile(unary_bwd_torch) if compile else unary_bwd_torch,
unary_bwd_torch,
[output, grads],
iobytes=dropout_rmsnorm_bwd_iobytes(size, dtype),
)
9 changes: 7 additions & 2 deletions benchmarks/python/test_gelu_bwd.py
@@ -102,10 +102,15 @@ def test_gelu_bwd_baseline_benchmark(
inputs = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True)
bias = torch.ones(size[-1], device="cuda", dtype=dtype)
grads = torch.randn(size, device="cuda", dtype=dtype)
eager_output = torch.nn.functional.gelu(inputs + bias, approximate="tanh")

def gelu_fwd():
return torch.nn.functional.gelu(inputs + bias, approximate="tanh")
fwd_fn = torch.compile(gelu_fwd) if compile else gelu_fwd
eager_output = fwd_fn()

run_benchmark(
benchmark,
torch.compile(unary_bwd_torch) if compile else unary_bwd_torch,
unary_bwd_torch,
[eager_output, grads],
iobytes=gelu_bwd_iobytes(size, dtype),
)
16 changes: 11 additions & 5 deletions benchmarks/python/test_huggingface_attn_bwd.py
@@ -126,16 +126,22 @@ def test_huggingface_attn_bwd_baseline_benchmark(
attention_mask = torch.zeros(
batch_size, nh, seq_len, seq_len, device="cuda", dtype=dtype
)
attn = (inputs + attention_mask).view(batch_size * nh, seq_len, seq_len)
attn = torch.nn.functional.softmax(attn, dim=-1)
output = torch.nn.functional.dropout(attn, p=dropout_p)


def huggingface_attn_fwd():
attn = (inputs + attention_mask).view(batch_size * nh, seq_len, seq_len)
attn = torch.nn.functional.softmax(attn, dim=-1)
output = torch.nn.functional.dropout(attn, p=dropout_p)
return output

# Compile the fwd fn for torchcompile
fwd_fn = torch.compile(huggingface_attn_fwd) if compile else huggingface_attn_fwd
output = fwd_fn()
grads = torch.randn(batch_size * nh, seq_len, seq_len, device="cuda", dtype=dtype)

# Manually compute IOBytes: See PR #1725
run_benchmark(
benchmark,
torch.compile(unary_bwd_torch) if compile else unary_bwd_torch,
unary_bwd_torch,
[output, grads],
iobytes=huggingface_attn_bwd_iobytes(size, dtype),
)
19 changes: 11 additions & 8 deletions benchmarks/python/test_layernorm_bwd.py
@@ -163,17 +163,20 @@ def test_layernorm_bwd_baseline_benchmark(
weights = torch.randn(size[1], device="cuda", dtype=dtype, requires_grad=True)
bias = torch.randn(size[1], device="cuda", dtype=dtype, requires_grad=True)

output = torch.nn.functional.layer_norm(
inputs,
inputs.shape[1:],
weight=weights,
bias=bias,
)

def layernorm_fwd():
return torch.nn.functional.layer_norm(
inputs,
inputs.shape[1:],
weight=weights,
bias=bias,
)
fwd_fn = torch.compile(layernorm_fwd) if compile else layernorm_fwd
output = fwd_fn()

# Manually compute IOBytes: See PR #1725
run_benchmark(
benchmark,
torch.compile(unary_bwd_torch) if compile else unary_bwd_torch,
unary_bwd_torch,
[output, grads],
iobytes=layernorm_bwd_iobytes(size, dtype),
)
21 changes: 13 additions & 8 deletions benchmarks/python/test_nanogpt_attn_bwd.py
@@ -144,19 +144,24 @@ def test_nanogpt_attn_bwd_baseline_benchmark(
1, 1, seq_len, seq_len
)

# Compute output
hs = n_embd // nh
attn = inputs / (hs**0.5)
attn = attn.masked_fill(bias[:, :, :seq_len, :seq_len] == 0, float("-inf"))
attn = torch.nn.functional.softmax(attn, dim=-1)
output = torch.nn.functional.dropout(attn, p=dropout_p)

def nanogpt_attn_fwd():
# Compute output
hs = n_embd // nh
attn = inputs / (hs**0.5)
attn = attn.masked_fill(bias[:, :, :seq_len, :seq_len] == 0, float("-inf"))
attn = torch.nn.functional.softmax(attn, dim=-1)
output = torch.nn.functional.dropout(attn, p=dropout_p)
return output

# Compile the fwd fn for torchcompile
fwd_fn = torch.compile(nanogpt_attn_fwd) if compile else nanogpt_attn_fwd
output = fwd_fn()
grads = torch.randn(batch_size, nh, seq_len, seq_len, device="cuda", dtype=dtype)

# Manually compute IOBytes: See PR #1725
run_benchmark(
benchmark,
torch.compile(unary_bwd_torch) if compile else unary_bwd_torch,
unary_bwd_torch,
[output, grads],
iobytes=nanogpt_attn_bwd_iobytes(size, dtype),
)
14 changes: 10 additions & 4 deletions benchmarks/python/test_rmsnorm_bwd.py
@@ -127,14 +127,20 @@ def test_rmsnorm_bwd_baseline_benchmark(
grads = torch.randn(size, device="cuda", dtype=dtype)
weights = torch.randn(size[1], device="cuda", dtype=dtype, requires_grad=True)

squared_mean = (inputs**2).mean(1, keepdim=True)
rms_eps = torch.sqrt(squared_mean + 1e-5)
output = weights * (inputs / rms_eps)
def rmsnorm_fwd():
squared_mean = (inputs**2).mean(1, keepdim=True)
rms_eps = torch.sqrt(squared_mean + 1e-5)
output = weights * (inputs / rms_eps)
return output

# Compile the fwd fn for torchcompile
fwd_fn = torch.compile(rmsnorm_fwd) if compile else rmsnorm_fwd
output = fwd_fn()

# Manually compute IOBytes: See PR #1725
run_benchmark(
benchmark,
torch.compile(unary_bwd_torch) if compile else unary_bwd_torch,
unary_bwd_torch,
[output, grads],
iobytes=rmsnorm_bwd_iobytes(size, dtype),
)
11 changes: 8 additions & 3 deletions benchmarks/python/test_scale_bias_relu_bwd.py
@@ -94,11 +94,16 @@ def test_sbr_bwd_baseline_benchmark(
grads = torch.randn(*size, device="cuda", dtype=dtype)
scale = torch.ones(size[-1], device="cuda", dtype=dtype)
bias = torch.ones(size[-1], device="cuda", dtype=dtype)
eager_output = torch.nn.functional.relu(inputs * scale + bias)


def sbr_fwd():
return torch.nn.functional.relu(inputs * scale + bias)
# Compile the fwd fn for torchcompile
fwd_fn = torch.compile(sbr_fwd) if compile else sbr_fwd
eager_output = fwd_fn()

run_benchmark(
benchmark,
torch.compile(unary_bwd_torch) if compile else unary_bwd_torch,
unary_bwd_torch,
[eager_output, grads],
iobytes=sbr_bwd_iobytes(size, dtype),
)
11 changes: 8 additions & 3 deletions benchmarks/python/test_silu_mul_bwd.py
@@ -93,11 +93,16 @@ def test_silu_mul_bwd_baseline_benchmark(
x = torch.randn(*size, device="cuda", dtype=dtype, requires_grad=True)
y = torch.randn(*size, device="cuda", dtype=dtype, requires_grad=True)
grads = torch.randn(*size, device="cuda", dtype=dtype)
eager_output = torch.nn.functional.silu(x) * y


def silu_mul_fwd():
return torch.nn.functional.silu(x) * y
# Compile the fwd fn for torchcompile
fwd_fn = torch.compile(silu_mul_fwd) if compile else silu_mul_fwd
eager_output = fwd_fn()

run_benchmark(
benchmark,
torch.compile(unary_bwd_torch) if compile else unary_bwd_torch,
unary_bwd_torch,
[eager_output, grads],
iobytes=silu_mul_bwd_iobytes(size, dtype),
)
18 changes: 9 additions & 9 deletions benchmarks/python/test_softmax_bwd.py
@@ -4,7 +4,7 @@
import pytest
from nvfuser import FusionDefinition, DataType
from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype
from .core import run_benchmark, clear_dynamo_cache
from .core import run_benchmark, clear_dynamo_cache, unary_bwd_torch
import torch
from .global_params import generate_input_sizes, FLOAT_DTYPES
import numpy as np
@@ -58,11 +58,6 @@ def softmax_bwd_fusion(
fd.add_output(T19)


def unary_bwd_torch(inputs: list): # [in_tensor, output, grads]
inputs[1].backward(inputs[2], retain_graph=True)
return inputs[0].grad


def softmax_bwd_iobytes(size: tuple, dtype: torch.dtype):
# Total IO bytes = output + grad_out + grad_input
return int(np.prod(size) * dtype.itemsize * 3)
@@ -111,10 +106,15 @@ def test_softmax_bwd_baseline_benchmark(
clear_dynamo_cache()
input = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True)
grads = torch.randn(size, device="cuda", dtype=dtype)
output = torch.nn.functional.softmax(input, dim=reduction_axis)

def softmax_fwd():
return torch.nn.functional.softmax(input, dim=reduction_axis)
fwd_fn = torch.compile(softmax_fwd) if compile else softmax_fwd
output = fwd_fn()

run_benchmark(
benchmark,
torch.compile(unary_bwd_torch) if compile else unary_bwd_torch,
[input, output, grads],
unary_bwd_torch,
[output, grads],
iobytes=softmax_bwd_iobytes(size, dtype),
)
