WIP
jjsjann123 committed Dec 16, 2024
1 parent d9f06f3 commit dc2211b
Showing 1 changed file with 28 additions and 23 deletions.
51 changes: 28 additions & 23 deletions benchmarks/python/test_rope.py
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 import pytest
 from nvfuser import FusionDefinition, DataType
-from .core import run_benchmark, with_executor
+from .core import run_benchmark, with_executor, unary_bwd_torch
 import torch
 
 from functools import partial
@@ -245,11 +245,23 @@ def forward(self, qkv, cos, sin):
     configs["llama_2_7b_hf_rope"] = Config(n_head=32, head_size=128, n_query_groups=32, rope_n_elem=128, batches=2, seq_length=4096)
     configs["llama_3_8B_rope"] = Config(n_head=32, head_size=128, n_query_groups=8, rope_n_elem=128, batches=2, seq_length=8192)
 
-    return LitGPTRope(configs[config_str]).cuda().bfloat16()
+    cfg = configs[config_str]
+
+    def inputs():
+        qkv = torch.randn(cfg.batches, cfg.seq_length, cfg.head_size * (cfg.n_head + 2 * cfg.n_query_groups), device='cuda', dtype=torch.bfloat16, requires_grad=True)
+        cos = torch.randn(cfg.seq_length, cfg.rope_n_elem, device='cuda', dtype=torch.bfloat16, requires_grad=False)
+        sin = torch.randn(cfg.seq_length, cfg.rope_n_elem, device='cuda', dtype=torch.bfloat16, requires_grad=False)
+        return qkv, cos, sin
+
+    def grads():
+        grad = torch.randn(cfg.batches, cfg.n_head, cfg.seq_length, cfg.head_size, device='cuda', dtype=torch.bfloat16, requires_grad=False)
+        return grad
+
+    return LitGPTRope(cfg).cuda().bfloat16(), inputs, grads
 
 
 # { 'name_benchmark' : (fn, [[sizes0, optional_strides0, dtype0], [sizes1, dtype1], ...]) }
-rope_configurations = {
+rope_setup = {
     "llama_2_7b_hf_rope": (
         partial(llama_hf_rope, config_str="llama_2_7b_hf_rope"),
         [
@@ -316,21 +328,15 @@ def test_rope_variations_fwd_benchmark(
     if executor == "torchcompile":
         clear_dynamo_cache()
 
-    config = rope_configurations[rope_variation]
+    config = rope_setup[rope_variation]
 
-    inputs = []
-    for entry in config[1]:
-        tensor = torch.testing.make_tensor(entry[0], dtype=entry[-1], device="cuda:0")
-        inputs.append(
-            tensor if len(entry) == 2 else tensor.as_strided(entry[0], entry[1])
-        )
-    model = config[0]()
+    model, inputs, _ = config[0]()
 
     def fwd_call(inp):
         return model(*inp)
 
     benchmark_fn = with_executor(executor, fwd_call)
-    run_benchmark(benchmark, benchmark_fn, inputs)
+    run_benchmark(benchmark, benchmark_fn, inputs())
 
 
 @pytest.mark.parametrize(
@@ -352,22 +358,21 @@ def test_rope_variations_bwd_benchmark(
     if executor == "torchcompile":
         clear_dynamo_cache()
 
-    config = rope_configurations[rope_variation]
-
-    fwd_inputs = []
-    for entry in config[1]:
-        tensor = torch.testing.make_tensor(entry[0], dtype=entry[-1], device="cuda:0")
-        fwd_inputs.append(
-            tensor if len(entry) == 2 else tensor.as_strided(entry[0], entry[1])
-        )
-    model = config[0]()
+    config = rope_setup[rope_variation]
+    model, fwd_inputs, grad = config[0]()
 
     def fwd_call(inp):
         return model(*inp)
 
     # execute the compiled fwd fn
     fwd_fn = with_executor(executor, fwd_call)
-    ouptuts = fwd_fn(fwd_inputs)
+    outputs = fwd_fn(fwd_inputs())
 
+    output = outputs[0]
+    for i in range(1, len(outputs)):
+        output += outputs[i]
+
+    print(f"{output.shape=}")
+    print(f"{grad().shape=}")
     benchmark_fn = with_executor(executor, fwd_call)
-    run_benchmark(benchmark, unary_bwd_torch, [outputs, grads])
+    run_benchmark(benchmark, unary_bwd_torch, [output, grad()])
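For context on how the reworked setup is meant to be consumed: each rope_setup entry now yields a (model, inputs_fn, grads_fn) triple, the forward outputs are folded into a single tensor, and that tensor plus one grad tensor are handed to unary_bwd_torch. The sketch below is illustrative only; unary_bwd_torch lives in .core and is not shown in this diff, so its body here is an assumption, and bwd_step is a hypothetical helper, not code from the commit.

# Illustrative sketch, not part of the commit.
# Assumed stand-in for .core.unary_bwd_torch: backprop a single grad tensor
# through a single output tensor.
def unary_bwd_torch(inputs):
    inputs[0].backward(inputs[1], retain_graph=True)

# Hypothetical helper showing the intended flow of the new (model, inputs_fn, grads_fn) tuple.
def bwd_step(model, inputs_fn, grads_fn):
    outputs = model(*inputs_fn())         # forward pass; may return several tensors
    output = outputs[0]
    for i in range(1, len(outputs)):      # fold outputs into one tensor so a single grad suffices
        output = output + outputs[i]
    unary_bwd_torch([output, grads_fn()])  # this is the call run_benchmark times in the test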
