find_bucket: return bmin directly when value <= bmin; otherwise return
min(round_up(value, bstep), next_pow2(value, bmin)). This replaces the
previous max(bmin, min(next_step, next_pow)) clamp, so next_pow2() and
round_up() are no longer called for values at or below bmin.

NOTE(review): line breaks and indentation were reconstructed from the hunk
counts (-257,6 +257,10); 4-space indentation is assumed. Confirm against
the original commit before applying.

diff --git a/vllm_hpu_extension/bucketing.py b/vllm_hpu_extension/bucketing.py
index 8a2fe2d9..5d52bbeb 100644
--- a/vllm_hpu_extension/bucketing.py
+++ b/vllm_hpu_extension/bucketing.py
@@ -257,6 +257,10 @@ def round_up(value: int, k: int) -> int:
 
 def find_bucket(value: int, config: Tuple[int, int, int]) -> int:
     bmin, bstep, _ = config
-    next_step = round_up(value, bstep)
-    next_pow = next_pow2(value, bmin)
-    return max(bmin, min(next_step, next_pow))
+    if value <= bmin:
+        return bmin
+    else:
+        next_step = round_up(value, bstep)
+        next_pow = next_pow2(value, bmin)
+        return min(next_step, next_pow)
+