
Commit

WIP
jjsjann123 committed Dec 16, 2024
1 parent dc2211b commit 8882d06
Showing 1 changed file with 6 additions and 21 deletions.
27 changes: 6 additions & 21 deletions in benchmarks/python/test_rope.py
@@ -262,22 +262,10 @@ def grads():
 
 # { 'name_benchmark' : (fn, [[sizes0, optional_strides0, dtype0], [sizes1, dtype1], ...]) }
 rope_setup = {
-    "llama_2_7b_hf_rope": (
+    "llama_2_7b_hf_rope":
         partial(llama_hf_rope, config_str="llama_2_7b_hf_rope"),
-        [
-            ((2, 4096, 12288), torch.bfloat16),
-            ((4096, 128), torch.bfloat16),
-            ((4096, 128), torch.bfloat16),
-        ],
-    ),
-    "llama_3_8B_rope": (
+    "llama_3_8B_rope":
         partial(llama_hf_rope, config_str="llama_3_8B_rope"),
-        [
-            ((2, 8192, 6144), torch.bfloat16),
-            ((8192, 128), torch.bfloat16),
-            ((8192, 128), torch.bfloat16),
-        ],
-    ),
     # "hf_qwen2_rope": (
     #     hf_qwen2_rope,
     #     [
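For orientation, a minimal sketch (not part of the commit) of how the reworked rope_setup is presumably meant to be consumed now that each value is a bare setup callable rather than a (fn, input-spec) tuple. It assumes llama_hf_rope(config_str=...), the setup helper already used in this file, returns a (model, inputs, grad_fn) triple, as the unpacking in the benchmark hunks below implies; note that the stored partial still has to be called before unpacking.

    # Hypothetical usage sketch; llama_hf_rope is assumed to return
    # (model, inputs, grad_fn) when called with a config string.
    from functools import partial

    rope_setup = {
        "llama_2_7b_hf_rope": partial(llama_hf_rope, config_str="llama_2_7b_hf_rope"),
        "llama_3_8B_rope": partial(llama_hf_rope, config_str="llama_3_8B_rope"),
    }

    # Indexing only retrieves the partial; it must be called to run the setup.
    model, inputs, grad_fn = rope_setup["llama_2_7b_hf_rope"]()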
@@ -328,9 +316,7 @@ def test_rope_variations_fwd_benchmark(
     if executor == "torchcompile":
         clear_dynamo_cache()
 
-    config = rope_setup[rope_variation]
-
-    model, inputs, _ = config[0]()
+    model, inputs, _ = rope_setup[rope_variation]
 
     def fwd_call(inp):
         return model(*inp)
@@ -358,8 +344,8 @@ def test_rope_variations_bwd_benchmark(
     if executor == "torchcompile":
         clear_dynamo_cache()
 
-    config = rope_setup[rope_variation]
-    model, fwd_inputs, grad = config[0]()
+    # TODO why not just a random like for grad on output instead of returning a grad function
+    model, fwd_inputs, grad = rope_setup[rope_variation]
 
     def fwd_call(inp):
         return model(*inp)
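The TODO added above questions whether the setup functions need to return a dedicated grad() callable at all. A minimal sketch of the alternative it alludes to, synthesizing the incoming gradient directly from the forward output, could look like this (make_grad is a hypothetical name, not a helper from the repo):

    import torch

    def make_grad(output: torch.Tensor) -> torch.Tensor:
        # Random cotangent matching the output's shape, dtype, and device,
        # usable where the benchmark currently passes grad().
        return torch.randn_like(output)

Whether a random cotangent exercises the same backward paths as the precomputed grad() tensors is exactly the question the TODO leaves open.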
@@ -368,11 +354,10 @@ def fwd_call(inp):
     fwd_fn = with_executor(executor, fwd_call)
     outputs = fwd_fn(fwd_inputs())
 
+    # NOTE does this look about right?
     output = outputs[0]
     for i in range(1, len(outputs)):
         output += outputs[i]
 
-    print(f"{output.shape=}")
-    print(f"{grad().shape=}")
     benchmark_fn = with_executor(executor, fwd_call)
     run_benchmark(benchmark, unary_bwd_torch, [output, grad()])
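On the NOTE above: the loop folds every forward output into the first one with in-place +=. An out-of-place sketch, assuming all outputs broadcast to a common shape (the in-place version needs that too), would be:

    import functools
    import torch

    def combine_outputs(outputs):
        # Element-wise sum of all forward outputs without mutating outputs[0].
        return functools.reduce(torch.add, outputs)

Either way, the result is just a single tensor to pair with grad() in the run_benchmark call.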
