Skip to content

Commit

Permalink
[Operator] update tiling config
Browse files Browse the repository at this point in the history
  • Loading branch information
StrongSpoon committed Jan 6, 2025
1 parent 5793cd2 commit bf4835e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/flag_gems/ops/nllloss.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

@libentry()
@triton.autotune(
configs=[triton.Config({"BLOCK_N": n}, num_warps=4) for n in [1, 16, 256]],
configs=[triton.Config({"BLOCK_N": n}, num_warps=4) for n in [1, 4, 32, 128]],
key=["N"],
)
@triton.jit(do_not_specialize=["ignore_index"])
Expand Down Expand Up @@ -92,7 +92,7 @@ def nll_loss_2d_bwd_kernel(

@libentry()
@triton.autotune(
configs=[triton.Config({"BLOCK_D": d}, num_warps=4) for d in [1, 4, 16]],
configs=[triton.Config({"BLOCK_D": d}, num_warps=4) for d in [8, 32, 128]],
key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index"])
Expand Down

0 comments on commit bf4835e

Please sign in to comment.