From 069808be62966ab10c61d81873c3d0e1d6825901 Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 28 May 2024 18:06:55 +0000
Subject: [PATCH 1/3] Return logits from single step

---
 vllm/engine/output_processor/single_step.py |  6 ++--
 vllm/model_executor/layers/sampler.py       | 33 ++++++++-------------
 vllm/outputs.py                             | 20 ++++++++-----
 vllm/sequence.py                            | 21 +++++++++++--
 4 files changed, 48 insertions(+), 32 deletions(-)

diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py
index 07b140584bbe2..d29a6880ff20e 100644
--- a/vllm/engine/output_processor/single_step.py
+++ b/vllm/engine/output_processor/single_step.py
@@ -102,14 +102,16 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
             new_child_seq_id: int = next(self.seq_counter)
             child = parent.fork(new_child_seq_id)
             child.append_token_id(child_sample.output_token,
-                                  child_sample.logprobs)
+                                  child_sample.logprobs,
+                                  child_sample.output_logits)
             child_seqs.append((child, parent))
         # Continue the parent sequence for the last child sample.
         # We reuse the parent sequence here to reduce redundant memory
         # copies, especially when using non-beam search sampling methods.
         last_child_sample = child_samples[-1]
         parent.append_token_id(last_child_sample.output_token,
-                               last_child_sample.logprobs)
+                               last_child_sample.logprobs,
+                               last_child_sample.output_logits)
         child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 1f19d2053d996..912d212a44349 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -110,7 +110,8 @@ def forward(
         # Get the logprobs query results.
         prompt_logprobs, sample_logprobs = _get_logprobs(
             logprobs, sampling_metadata, sample_results)
-        return _build_sampler_output(sample_results,
+        return _build_sampler_output(logits,
+                                     sample_results,
                                      sampling_metadata,
                                      prompt_logprobs,
                                      sample_logprobs,
@@ -970,6 +971,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
 
 
 def _build_sampler_output(
+    logits,
     sample_results: SampleResultType,
     sampling_metadata: SamplingMetadata,
     prompt_logprobs: List[Optional[PromptLogprobs]],
@@ -986,29 +988,20 @@ def _build_sampler_output(
         speculative decoding rejection sampling.
     """
 
+    # If not specified, store None values in SamplerOutput.
+    if on_device_tensors is not None:
+        (sampled_token_probs, logprobs_tensor, sampled_token_ids) = on_device_tensors
+    else:
+        sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, None)
+
     sampler_output = []
-    for (seq_group, sample_result, group_prompt_logprobs,
-         group_sample_logprobs) in zip(sampling_metadata.seq_groups,
-                                       sample_results, prompt_logprobs,
-                                       sample_logprobs):
+    for (seq_group, sample_result, group_prompt_logprobs, group_sample_logprobs) in zip(sampling_metadata.seq_groups, sample_results, prompt_logprobs, sample_logprobs):
         seq_ids = seq_group.seq_ids
         next_token_ids, parent_ids = sample_result
         seq_outputs = []
-        for parent_id, next_token_id, logprobs in zip(parent_ids,
-                                                      next_token_ids,
-                                                      group_sample_logprobs):
-            seq_outputs.append(
-                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
-        sampler_output.append(
-            SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
-
-    # If not specified, store None values in SamplerOutput.
-    if on_device_tensors is not None:
-        (sampled_token_probs, logprobs_tensor,
-         sampled_token_ids) = on_device_tensors
-    else:
-        sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
-                                                                   None)
+        for parent_id, next_token_id, logprobs, sample_idx in zip(parent_ids, next_token_ids, group_sample_logprobs, seq_group.sample_indices):
+            seq_outputs.append(SequenceOutput(seq_ids[parent_id], next_token_id, logprobs, logits[sample_idx]))
+        sampler_output.append(SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
 
     return SamplerOutput(
         outputs=sampler_output,
diff --git a/vllm/outputs.py b/vllm/outputs.py
index d01be0eb0efd2..725d90cc72222 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -30,6 +30,7 @@ def __init__(
         text: str,
         token_ids: List[int],
         cumulative_logprob: float,
+        logits: List[float],
         logprobs: Optional[SampleLogprobs],
         finish_reason: Optional[str] = None,
         stop_reason: Union[int, str, None] = None,
@@ -43,6 +44,7 @@ def __init__(
         self.finish_reason = finish_reason
         self.stop_reason = stop_reason
         self.lora_request = lora_request
+        self.logits = logits
 
     def finished(self) -> bool:
         return self.finish_reason is not None
@@ -52,6 +54,7 @@ def __repr__(self) -> str:
                 f"text={self.text!r}, "
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
+                f"logits={self.logits[0][:3]} ... {self.logits[0][-3:]}, "
                 f"logprobs={self.logprobs}, "
                 f"finish_reason={self.finish_reason}, "
                 f"stop_reason={self.stop_reason})")
@@ -114,13 +117,16 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
         include_logprobs = seq_group.sampling_params.logprobs is not None
         text_buffer_length = seq_group.sampling_params.output_text_buffer_length
         outputs = [
-            CompletionOutput(seqs.index(seq),
-                             seq.get_output_text_to_return(text_buffer_length),
-                             seq.get_output_token_ids(),
-                             seq.get_cumulative_logprob(),
-                             seq.output_logprobs if include_logprobs else None,
-                             SequenceStatus.get_finished_reason(seq.status),
-                             seq.stop_reason) for seq in top_n_seqs
+            CompletionOutput(
+                seqs.index(seq),
+                seq.get_output_text_to_return(text_buffer_length),
+                seq.get_output_token_ids(),
+                seq.get_cumulative_logprob(),
+                seq.get_output_logits(),
+                seq.output_logprobs if include_logprobs else None,
+                SequenceStatus.get_finished_reason(seq.status),
+                seq.stop_reason,
+            ) for seq in top_n_seqs
         ]
 
         # Every sequence in the sequence group should have the same prompt.
diff --git a/vllm/sequence.py b/vllm/sequence.py
index f2939eff7959b..2ffe0ffe4f1f3 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -113,6 +113,7 @@ class SequenceData:
     def __init__(
         self,
         prompt_token_ids: List[int],
+        output_logits: Optional[List[float]] = None,
         output_token_ids: Optional[List[int]] = None,
     ) -> None:
         if output_token_ids is None:
@@ -120,13 +121,15 @@ def __init__(
 
         self.prompt_token_ids = prompt_token_ids
         self.output_token_ids = output_token_ids
+        self.output_logits = output_logits or []
         self.cumulative_logprob = 0.0
         # The number of tokens that are computed (that run against the model).
         self._num_computed_tokens = 0
         self._stage: SequenceStage = SequenceStage.PREFILL
 
-    def append_token_id(self, token_id: int, logprob: float) -> None:
+    def append_token_id(self, token_id: int, logprob: float, logits: List[float]) -> None:
         self.output_token_ids.append(token_id)
+        self.output_logits.append(logits)
         self.cumulative_logprob += logprob
 
     def get_len(self) -> int:
@@ -135,6 +138,9 @@ def get_len(self) -> int:
     def get_prompt_len(self) -> int:
         return len(self.prompt_token_ids)
 
+    def get_output_logits(self) -> List[float]:
+        return self.output_logits
+
     def get_output_len(self) -> int:
         return len(self.output_token_ids)
 
@@ -219,6 +225,7 @@ def __init__(
         self.lora_request = lora_request
 
         self.data: SequenceData = SequenceData(prompt_token_ids)
+        self.output_logits = []
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
 
@@ -288,11 +295,13 @@ def append_token_id(
         self,
         token_id: int,
         logprobs: Dict[int, Logprob],
+        logits: List[float],
     ) -> None:
         assert token_id in logprobs
         self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.data.append_token_id(token_id, logprobs[token_id].logprob)
+        self.output_logits.append(logits)
+        self.data.append_token_id(token_id, logprobs[token_id].logprob, logits)
 
     def get_len(self) -> int:
         return self.data.get_len()
@@ -303,6 +312,9 @@ def get_prompt_len(self) -> int:
     def get_output_len(self) -> int:
         return self.data.get_output_len()
 
+    def get_output_logits(self) -> List[float]:
+        return self.data.get_output_logits()
+
     def get_token_ids(self) -> List[int]:
         return self.data.get_token_ids()
 
@@ -644,13 +656,16 @@ def __init__(
         parent_seq_id: int,
         output_token: int,
         logprobs: Dict[int, Logprob],
+        logits
    ) -> None:
         self.parent_seq_id = parent_seq_id
+        self.output_logits = logits.tolist()
         self.output_token = output_token
         self.logprobs = logprobs
 
     def __repr__(self) -> str:
         return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
+                f"output_logits_len={len(self.output_logits)}, "
                 f"output_token={self.output_token}, "
                 f"logprobs={self.logprobs})")
 
@@ -677,7 +692,7 @@ def __init__(
 
     def __repr__(self) -> str:
         return (f"SequenceGroupOutput(samples={self.samples}, "
-                f"prompt_logprobs={self.prompt_logprobs})")
+                f"prompt_logprobs={self.prompt_logprobs})")
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, SequenceGroupOutput):

From 6883da539977d0290ceefc7816736584f420d1d9 Mon Sep 17 00:00:00 2001
From: Joe G
Date: Tue, 28 May 2024 20:18:37 -0700
Subject: [PATCH 2/3] Call classification head inside sampler

---
 vllm/engine/output_processor/single_step.py |  4 +--
 vllm/model_executor/layers/sampler.py       | 13 ++++++---
 vllm/outputs.py                             |  8 +++---
 vllm/sequence.py                            | 30 ++++++++++-----------
 4 files changed, 31 insertions(+), 24 deletions(-)

diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py
index d29a6880ff20e..e8b08d8c2ee88 100644
--- a/vllm/engine/output_processor/single_step.py
+++ b/vllm/engine/output_processor/single_step.py
@@ -103,7 +103,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
             child = parent.fork(new_child_seq_id)
             child.append_token_id(child_sample.output_token,
                                   child_sample.logprobs,
-                                  child_sample.output_logits)
+                                  child_sample.output_classification_probs)
             child_seqs.append((child, parent))
         # Continue the parent sequence for the last child sample.
         # We reuse the parent sequence here to reduce redundant memory
@@ -111,7 +111,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
         last_child_sample = child_samples[-1]
         parent.append_token_id(last_child_sample.output_token,
                                last_child_sample.logprobs,
-                               last_child_sample.output_logits)
+                               last_child_sample.output_classification_probs)
         child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 912d212a44349..c5e1ae88ccb18 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -46,6 +46,9 @@ def __init__(self):
         # speculative decoding.
         self.include_gpu_probs_tensor = False
 
+        self.classification_head = torch.nn.Linear(1, 1, bias=False).to("cuda")
+        self.classification_head.weight.data = torch.load("classification_head.pth", map_location="cuda").bfloat16()
+
     def forward(
         self,
         logits: torch.Tensor,
@@ -61,6 +64,10 @@ def forward(
 
         logits = _apply_min_tokens_penalty(logits, sampling_metadata)
 
+        classification_probs = torch.nn.functional.sigmoid(
+            self.classification_head(logits)
+        ).flatten().tolist()
+
         # Prepare sampling tensors with pinned memory to avoid blocking.
         (sampling_tensors, do_penalties, do_top_p_top_k,
          do_min_p) = SamplingTensors.from_sampling_metadata(
@@ -110,7 +117,7 @@ def forward(
         # Get the logprobs query results.
         prompt_logprobs, sample_logprobs = _get_logprobs(
             logprobs, sampling_metadata, sample_results)
-        return _build_sampler_output(logits,
+        return _build_sampler_output(classification_probs,
                                      sample_results,
                                      sampling_metadata,
                                      prompt_logprobs,
@@ -971,7 +978,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
 
 
 def _build_sampler_output(
-    logits,
+    classification_probs,
     sample_results: SampleResultType,
     sampling_metadata: SamplingMetadata,
     prompt_logprobs: List[Optional[PromptLogprobs]],
@@ -1000,7 +1007,7 @@ def _build_sampler_output(
         next_token_ids, parent_ids = sample_result
         seq_outputs = []
         for parent_id, next_token_id, logprobs, sample_idx in zip(parent_ids, next_token_ids, group_sample_logprobs, seq_group.sample_indices):
-            seq_outputs.append(SequenceOutput(seq_ids[parent_id], next_token_id, logprobs, logits[sample_idx]))
+            seq_outputs.append(SequenceOutput(seq_ids[parent_id], next_token_id, logprobs, classification_probs[sample_idx]))
         sampler_output.append(SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
 
     return SamplerOutput(
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 725d90cc72222..2878abec68f3a 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -30,7 +30,7 @@ def __init__(
         text: str,
         token_ids: List[int],
         cumulative_logprob: float,
-        logits: List[float],
+        classification_probs: List[float],
         logprobs: Optional[SampleLogprobs],
         finish_reason: Optional[str] = None,
         stop_reason: Union[int, str, None] = None,
@@ -44,7 +44,7 @@ def __init__(
         self.finish_reason = finish_reason
         self.stop_reason = stop_reason
         self.lora_request = lora_request
-        self.logits = logits
+        self.classification_probs = classification_probs
 
     def finished(self) -> bool:
         return self.finish_reason is not None
@@ -54,7 +54,7 @@ def __repr__(self) -> str:
                 f"text={self.text!r}, "
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
-                f"logits={self.logits[0][:3]} ... {self.logits[0][-3:]}, "
+                f"classification_probs={self.classification_probs}, "
                 f"logprobs={self.logprobs}, "
                 f"finish_reason={self.finish_reason}, "
                 f"stop_reason={self.stop_reason})")
@@ -122,7 +122,7 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
                 seq.get_output_text_to_return(text_buffer_length),
                 seq.get_output_token_ids(),
                 seq.get_cumulative_logprob(),
-                seq.get_output_logits(),
+                seq.get_output_classification_probs(),
                 seq.output_logprobs if include_logprobs else None,
                 SequenceStatus.get_finished_reason(seq.status),
                 seq.stop_reason,
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 2ffe0ffe4f1f3..5b74461bcc300 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -113,7 +113,7 @@ class SequenceData:
     def __init__(
         self,
         prompt_token_ids: List[int],
-        output_logits: Optional[List[float]] = None,
+        output_classification_probs: Optional[List[float]] = None,
         output_token_ids: Optional[List[int]] = None,
     ) -> None:
         if output_token_ids is None:
@@ -121,15 +121,15 @@ def __init__(
 
         self.prompt_token_ids = prompt_token_ids
         self.output_token_ids = output_token_ids
-        self.output_logits = output_logits or []
+        self.output_classification_probs = output_classification_probs or []
         self.cumulative_logprob = 0.0
         # The number of tokens that are computed (that run against the model).
         self._num_computed_tokens = 0
         self._stage: SequenceStage = SequenceStage.PREFILL
 
-    def append_token_id(self, token_id: int, logprob: float, logits: List[float]) -> None:
+    def append_token_id(self, token_id: int, logprob: float, classification_probs: List[float]) -> None:
         self.output_token_ids.append(token_id)
-        self.output_logits.append(logits)
+        self.output_classification_probs.append(classification_probs)
         self.cumulative_logprob += logprob
 
     def get_len(self) -> int:
@@ -138,8 +138,8 @@ def get_len(self) -> int:
     def get_prompt_len(self) -> int:
         return len(self.prompt_token_ids)
 
-    def get_output_logits(self) -> List[float]:
-        return self.output_logits
+    def get_output_classification_probs(self) -> List[float]:
+        return self.output_classification_probs
 
     def get_output_len(self) -> int:
         return len(self.output_token_ids)
@@ -225,7 +225,7 @@ def __init__(
         self.lora_request = lora_request
 
         self.data: SequenceData = SequenceData(prompt_token_ids)
-        self.output_logits = []
+        self.output_classification_probs = []
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
 
@@ -295,13 +295,13 @@ def append_token_id(
         self,
         token_id: int,
         logprobs: Dict[int, Logprob],
-        logits: List[float],
+        classification_probs: List[float],
     ) -> None:
         assert token_id in logprobs
         self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.output_logits.append(logits)
-        self.data.append_token_id(token_id, logprobs[token_id].logprob, logits)
+        self.output_classification_probs.append(classification_probs)
+        self.data.append_token_id(token_id, logprobs[token_id].logprob, classification_probs)
 
     def get_len(self) -> int:
         return self.data.get_len()
@@ -312,8 +312,8 @@ def get_prompt_len(self) -> int:
     def get_output_len(self) -> int:
         return self.data.get_output_len()
 
-    def get_output_logits(self) -> List[float]:
-        return self.data.get_output_logits()
+    def get_output_classification_probs(self) -> List[float]:
+        return self.data.get_output_classification_probs()
 
     def get_token_ids(self) -> List[int]:
         return self.data.get_token_ids()
 
@@ -656,16 +656,16 @@ def __init__(
         parent_seq_id: int,
         output_token: int,
         logprobs: Dict[int, Logprob],
-        logits
+        classification_probs: List[float]
    ) -> None:
         self.parent_seq_id = parent_seq_id
-        self.output_logits = logits.tolist()
+        self.output_classification_probs = classification_probs
         self.output_token = output_token
         self.logprobs = logprobs
 
     def __repr__(self) -> str:
         return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
-                f"output_logits_len={len(self.output_logits)}, "
+                f"output_classification_probs={self.output_classification_probs}, "
                 f"output_token={self.output_token}, "
                 f"logprobs={self.logprobs})")

From 9138c32683c20ba10790b66996ce52b7009b90c1 Mon Sep 17 00:00:00 2001
From: Joe G
Date: Thu, 30 May 2024 12:01:04 -0700
Subject: [PATCH 3/3] Fix formatting

---
 vllm/engine/output_processor/single_step.py |  6 ++--
 vllm/model_executor/layers/sampler.py       | 25 +++++++++++-----
 vllm/sequence.py                            | 33 ++++++++++-----------
 3 files changed, 36 insertions(+), 28 deletions(-)

diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py
index adb39eedf72e4..779d8609ec9fc 100644
--- a/vllm/engine/output_processor/single_step.py
+++ b/vllm/engine/output_processor/single_step.py
@@ -109,9 +109,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
         # We reuse the parent sequence here to reduce redundant memory
         # copies, especially when using non-beam search sampling methods.
         last_child_sample = child_samples[-1]
-        parent.append_token_id(last_child_sample.output_token,
-                               last_child_sample.logprobs,
-                               last_child_sample.output_classification_probs)
+        parent.append_token_id(
+            last_child_sample.output_token, last_child_sample.logprobs,
+            last_child_sample.output_classification_probs)
         child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 1d1f8de6276aa..c5f99e314371e 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -48,7 +48,8 @@ def __init__(self):
         self.include_gpu_probs_tensor = False
 
         self.classification_head = torch.nn.Linear(1, 1, bias=False).to("cuda")
-        self.classification_head.weight.data = torch.load("classification_head.pth", map_location="cuda").bfloat16()
+        self.classification_head.weight.data = torch.load(
+            "classification_head.pth", map_location="cuda").bfloat16()
 
     def forward(
         self,
@@ -66,8 +67,7 @@ def forward(
         logits = _apply_min_tokens_penalty(logits, sampling_metadata)
 
         classification_probs = torch.nn.functional.sigmoid(
-            self.classification_head(logits)
-        ).flatten().tolist()
+            self.classification_head(logits)).flatten().tolist()
 
         # Prepare sampling tensors with pinned memory to avoid blocking.
         (sampling_tensors, do_penalties, do_top_p_top_k,
@@ -1018,20 +1018,29 @@ def _build_sampler_output(
     """
 
     sampler_output = []
-    for (seq_group, sample_result, group_prompt_logprobs, group_sample_logprobs) in zip(sampling_metadata.seq_groups, sample_results, prompt_logprobs, sample_logprobs):
+    for (seq_group, sample_result, group_prompt_logprobs,
+         group_sample_logprobs) in zip(sampling_metadata.seq_groups,
+                                       sample_results, prompt_logprobs,
+                                       sample_logprobs):
         seq_ids = seq_group.seq_ids
         next_token_ids, parent_ids = sample_result
         seq_outputs = []
-        for parent_id, next_token_id, logprobs, sample_idx in zip(parent_ids, next_token_ids, group_sample_logprobs, seq_group.sample_indices):
-            seq_outputs.append(SequenceOutput(seq_ids[parent_id], next_token_id, logprobs, classification_probs[sample_idx]))
-        sampler_output.append(CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs))
+        for parent_id, next_token_id, logprobs, sample_idx in zip(
+                parent_ids, next_token_ids, group_sample_logprobs,
+                seq_group.sample_indices):
+            seq_outputs.append(
+                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs,
+                               classification_probs[sample_idx]))
+        sampler_output.append(
+            CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs))
 
     # If not specified, store None values in SamplerOutput.
     if on_device_tensors is not None:
         (sampled_token_probs, logprobs_tensor,
          sampled_token_ids) = on_device_tensors
     else:
-        sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, None)
+        sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
+                                                                   None)
 
     return SamplerOutput(
         outputs=sampler_output,
         sampled_token_probs=sampled_token_probs,
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 4925adecde51f..a8a3a87413107 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -131,9 +131,10 @@ def __init__(
         self._num_computed_tokens = 0
         self._stage: SequenceStage = SequenceStage.PREFILL
 
-    def append_token_id(self, token_id: int, logprob: float, classification_probs: List[float]) -> None:
+    def append_token_id(self, token_id: int, logprob: float,
+                        classification_prob: float) -> None:
         self.output_token_ids.append(token_id)
-        self.output_classification_probs.append(classification_probs)
+        self.output_classification_probs.append(classification_prob)
         self.cumulative_logprob += logprob
 
     def get_len(self) -> int:
@@ -237,7 +238,7 @@ def __init__(
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
 
         self.data = SequenceData(self.prompt_token_ids)
-        self.output_classification_probs = []
+        self.output_classification_probs: List[float] = []
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
@@ -319,13 +320,14 @@ def append_token_id(
         self,
         token_id: int,
         logprobs: Dict[int, Logprob],
-        classification_probs: List[float],
+        classification_prob: float,
     ) -> None:
         assert token_id in logprobs
         self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.output_classification_probs.append(classification_probs)
-        self.data.append_token_id(token_id, logprobs[token_id].logprob, classification_probs)
+        self.output_classification_probs.append(classification_prob)
+        self.data.append_token_id(token_id, logprobs[token_id].logprob,
+                                  classification_prob)
 
     def get_len(self) -> int:
         return self.data.get_len()
@@ -719,23 +721,20 @@ class SequenceOutput:
         (Token id -> logP(x_i+1 | x_0, ..., x_i))
     """
 
-    def __init__(
-        self,
-        parent_seq_id: int,
-        output_token: int,
-        logprobs: Dict[int, Logprob],
-        classification_probs: List[float]
-    ) -> None:
+    def __init__(self, parent_seq_id: int, output_token: int,
+                 logprobs: Dict[int, Logprob],
+                 classification_probs: List[float]) -> None:
         self.parent_seq_id = parent_seq_id
         self.output_classification_probs = classification_probs
         self.output_token = output_token
         self.logprobs = logprobs
 
     def __repr__(self) -> str:
-        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
-                f"output_classification_probs={self.output_classification_probs}, "
-                f"output_token={self.output_token}, "
-                f"logprobs={self.logprobs})")
+        return (
+            f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
+            f"output_classification_probs={self.output_classification_probs}, "
+            f"output_token={self.output_token}, "
+            f"logprobs={self.logprobs})")
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, SequenceOutput):
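
Usage note (reviewer aid, not part of the patches above): a minimal sketch of how the new per-token field surfaces through vLLM's public API once the series is applied. The model name and prompt are placeholders; `classification_probs` on `CompletionOutput` is the field added by these patches.

    # Assumes a vLLM build with this series applied and a
    # "classification_head.pth" file in the working directory, since the
    # sampler loads that file unconditionally when it is constructed.
    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # placeholder model
    params = SamplingParams(temperature=0.8, max_tokens=16)

    for request_output in llm.generate(["The capital of France is"], params):
        completion = request_output.outputs[0]
        # One sigmoid probability per generated token, computed by the
        # classification head the sampler now runs over the logits.
        for token_id, prob in zip(completion.token_ids,
                                  completion.classification_probs):
            print(token_id, prob)

The head weight file itself is not produced by this series. A compatible placeholder can be written as below; the shape assumption ([1, vocab_size], 32000 here as an example) follows from `classification_head(logits)` operating on [num_tokens, vocab_size] logits, and the `.bfloat16()` cast implies the model must also run in bfloat16 for the matmul dtypes to agree.

    import torch

    # Placeholder weight: [out_features=1, in_features=vocab_size].
    # 32000 matches a Llama-style vocabulary and is only an example.
    weight = torch.randn(1, 32000) * 0.01
    torch.save(weight, "classification_head.pth")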