diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index 5b97f03257502..619263225333d 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -1,6 +1,7 @@ # noqa: UP007 from __future__ import annotations +import copy import json from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any @@ -238,7 +239,6 @@ class XGrammarLogitsProcessor: token_bitmask: torch.Tensor = None # type: ignore[assignment] matchers: list[xgr.GrammarMatcher] = field(default_factory=list) batch_size: int = field(default=1) - prefilled: bool = field(default=False) def __getstate__(self) -> dict[str, Any]: return {'config': self.config} @@ -250,7 +250,6 @@ def __setstate__(self, state: dict[str, Any]): self.matchers = [] self.batch_size = 1 self.token_bitmask = None # type: ignore[assignment] - self.prefilled = False def _ensure_ctx(self): """Lazily initialize the processor in the worker process""" @@ -278,10 +277,7 @@ def __call__(self, input_ids: list[int], self.token_bitmask = xgr.allocate_token_bitmask( self.batch_size, self.config.vocab_size) - if not self.prefilled: - # Have not sampled a token yet - self.prefilled = True - else: + if len(input_ids) > 0: for i, matcher in enumerate(self.matchers): if not matcher.is_terminated(): sampled_token = input_ids[-1] @@ -309,3 +305,7 @@ def __call__(self, input_ids: list[int], scores = scores.to(device_type).squeeze() return scores + + def clone(self) -> XGrammarLogitsProcessor: + """Deepcopy due to per-sequence state in the matchers""" + return copy.deepcopy(self) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index fc77f3ca529b2..605c09b8d7225 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -450,15 +450,16 @@ def all_stop_token_ids(self) -> Set[int]: return self._all_stop_token_ids def clone(self) -> "SamplingParams": - """Deep copy 
excluding LogitsProcessor objects. + """Deep copy, but maybe not the LogitsProcessor objects. - LogitsProcessor objects are excluded because they may contain an - arbitrary, nontrivial amount of data. + LogitsProcessor objects may contain an arbitrary, nontrivial amount of + data that is expensive to copy. However, if not copied, the processor + needs to support parallel decoding for multiple sequences. See https://github.com/vllm-project/vllm/issues/3087 """ logit_processor_refs = None if self.logits_processors is None else { - id(lp): lp + id(lp): lp.clone() if hasattr(lp, 'clone') else lp for lp in self.logits_processors } return copy.deepcopy(self, memo=logit_processor_refs) diff --git a/vllm/sequence.py b/vllm/sequence.py index cc3d96fc93a79..15fe12aacb118 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1366,9 +1366,9 @@ def maybe_assemble_group( class ParallelSampleSequenceGroup(SequenceGroupBase): @staticmethod - def add_request(request_id: str, engine, params, **kwargs): + def add_request(request_id: str, engine, params: Union[SamplingParams, PoolingParams], **kwargs): original_params = params - params = copy.deepcopy(original_params) + params = original_params.clone() params.n = 1 group = ParallelSampleSequenceGroup(request_id) seqs = []