Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
  • Loading branch information
Varun Sundar Rabindranath committed Jan 2, 2025
1 parent 5b8bdf1 commit 132615c
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions benchmarks/kernels/benchmark_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,9 @@ def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
})
return {'kwargs_list': kwargs_list}

def bench_fn_kwargs(self, op_type: OpType, add_inputs: Optional[bool] = None) -> Dict[str, Any]:
def bench_fn_kwargs(self,
op_type: OpType,
add_inputs: Optional[bool] = None) -> Dict[str, Any]:
if op_type.is_shrink_fn():
assert add_inputs is None
else:
Expand Down Expand Up @@ -577,16 +579,21 @@ def bench_optype(ctx: BenchmarkContext,
bt.sanity_check()

# BenchmarkTensors -> Dict (kwargs)
kwargs_list = [bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs) for bt in bench_tensors]
kwargs_list = [
bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs)
for bt in bench_tensors
]

# Merge into a single kwargs and quality arguments as ArgPool
kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
for _kwargs in kwargs_list:
for k, v in _kwargs.items():
kwargs[k].values.append(v)

describe_args = f"add_inputs={expand_fn_add_inputs}" if expand_fn_add_inputs is not None else ""
description = f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})"
describe_args = (f"add_inputs={expand_fn_add_inputs}"
if expand_fn_add_inputs is not None else "")
description = (
f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})")
cuda_graph_params = CudaGraphBenchParams(
num_ops_in_cuda_graph=arg_pool_size) if with_cuda_graph else None
with Bench(cuda_graph_params,
Expand Down Expand Up @@ -666,12 +673,13 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):
args.with_cuda_graph))

# Benchmark bench_op
expand_fn_add_inputs = [None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs
expand_fn_add_inputs = [
None
] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs
for add_input_arg in expand_fn_add_inputs:
seq_len_timers.append(
bench_optype(_ctx, args.arg_pool_size, bench_op,
args.with_cuda_graph,
add_input_arg))
args.with_cuda_graph, add_input_arg))

print_timers(seq_len_timers)
timers.extend(seq_len_timers)
Expand Down

0 comments on commit 132615c

Please sign in to comment.