From 8882d066b1f3e5304b4f6ac9a72f1c2940421ac4 Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Sun, 15 Dec 2024 20:18:46 -0800
Subject: [PATCH] WIP

---
 benchmarks/python/test_rope.py | 29 +++++++----------------------
 1 file changed, 7 insertions(+), 22 deletions(-)

diff --git a/benchmarks/python/test_rope.py b/benchmarks/python/test_rope.py
index 8dcb4ac3585..70294dd8144 100644
--- a/benchmarks/python/test_rope.py
+++ b/benchmarks/python/test_rope.py
@@ -262,22 +262,10 @@ def grads():
 
-# { 'name_benchmark' : (fn, [[sizes0, optional_strides0, dtype0], [sizes1, dtype1], ...]) }
+# { 'name_benchmark' : setup_fn }, where setup_fn() returns (model, inputs_fn, grad_fn)
 rope_setup = {
-    "llama_2_7b_hf_rope": (
+    "llama_2_7b_hf_rope": partial(llama_hf_rope, config_str="llama_2_7b_hf_rope"),
-        [
-            ((2, 4096, 12288), torch.bfloat16),
-            ((4096, 128), torch.bfloat16),
-            ((4096, 128), torch.bfloat16),
-        ],
-    ),
-    "llama_3_8B_rope": (
+    "llama_3_8B_rope": partial(llama_hf_rope, config_str="llama_3_8B_rope"),
-        [
-            ((2, 8192, 6144), torch.bfloat16),
-            ((8192, 128), torch.bfloat16),
-            ((8192, 128), torch.bfloat16),
-        ],
-    ),
     # "hf_qwen2_rope": (
     #     hf_qwen2_rope,
     #     [
@@ -328,9 +316,7 @@ def test_rope_variations_fwd_benchmark(
     if executor == "torchcompile":
         clear_dynamo_cache()
 
-    config = rope_setup[rope_variation]
-
-    model, inputs, _ = config[0]()
+    model, inputs, _ = rope_setup[rope_variation]()
 
     def fwd_call(inp):
         return model(*inp)
@@ -358,8 +344,8 @@ def test_rope_variations_bwd_benchmark(
     if executor == "torchcompile":
         clear_dynamo_cache()
 
-    config = rope_setup[rope_variation]
-    model, fwd_inputs, grad = config[0]()
+    # TODO: why not just torch.randn_like on the output for the grad instead of returning a grad function?
+    model, fwd_inputs, grad = rope_setup[rope_variation]()
 
     def fwd_call(inp):
         return model(*inp)
@@ -368,11 +354,10 @@ def fwd_call(inp):
     fwd_fn = with_executor(executor, fwd_call)
    outputs = fwd_fn(fwd_inputs())
 
+    # NOTE: is summing all outputs into a single tensor the right reduction here?
     output = outputs[0]
     for i in range(1, len(outputs)):
         output += outputs[i]
-    print(f"{output.shape=}")
-    print(f"{grad().shape=}")
 
     benchmark_fn = with_executor(executor, fwd_call)
     run_benchmark(benchmark, unary_bwd_torch, [output, grad()])
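
Review note: for anyone skimming this WIP, below is a minimal, self-contained sketch
of the registry shape the patch moves rope_setup to. It assumes `from functools
import partial` is already in scope in test_rope.py and that `llama_hf_rope` returns
a (model, inputs_fn, grad_fn) triple, which is how the fwd/bwd tests unpack it; every
`toy_*` name below is hypothetical and only mirrors that structure, it is not code
from the repository.

    from functools import partial

    import torch


    def toy_rope_setup(config_str):
        # Hypothetical stand-in for llama_hf_rope: returns a (model, inputs_fn,
        # grad_fn) triple, matching how the benchmarks unpack rope_setup entries
        # after this patch.
        hidden = {"tiny": 8, "small": 16}[config_str]
        model = torch.nn.Linear(hidden, hidden)

        def inputs_fn():
            # Inputs are built lazily, so nothing is allocated at module import time.
            return (torch.randn(2, hidden),)

        def grad_fn():
            # Seed gradient for backward; the TODO in the patch asks whether
            # torch.randn_like(output) could replace this closure.
            return torch.randn(2, hidden)

        return model, inputs_fn, grad_fn


    # Name -> zero-argument setup callable, mirroring the new rope_setup layout.
    toy_setup = {
        "tiny_rope": partial(toy_rope_setup, config_str="tiny"),
        "small_rope": partial(toy_rope_setup, config_str="small"),
    }

    # Consumption mirrors the tests: note the trailing () that invokes the partial
    # before unpacking -- a bare toy_setup[name] is a partial object and cannot be
    # unpacked into three values.
    model, inputs_fn, grad_fn = toy_setup["tiny_rope"]()
    output = model(*inputs_fn())
    output.backward(grad_fn())  # per the TODO: output.backward(torch.randn_like(output))

Deferring tensor construction into the setup callable keeps the module-level dict
cheap to import, and is presumably why the per-entry shape/dtype lists could be
dropped from rope_setup: that configuration now lives entirely inside the setup
function selected by config_str.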