Skip to content

Commit

Permalink
Rename LogitsProcessor to StepProcessor (facebookresearch#139)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu authored Nov 6, 2023
1 parent 21e7a44 commit fa7f9c4
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 14 deletions.
8 changes: 4 additions & 4 deletions src/fairseq2/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@

from fairseq2.generation.beam_search import BeamSearch as BeamSearch
from fairseq2.generation.beam_search import StandardBeamSearch as StandardBeamSearch
from fairseq2.generation.logits_processor import (
BannedSequenceProcessor as BannedSequenceProcessor,
)
from fairseq2.generation.logits_processor import LogitsProcessor as LogitsProcessor
from fairseq2.generation.sequence_generator import Hypothesis as Hypothesis
from fairseq2.generation.sequence_generator import Seq2SeqGenerator as Seq2SeqGenerator
from fairseq2.generation.sequence_generator import (
Expand All @@ -18,6 +14,10 @@
from fairseq2.generation.sequence_generator import (
SequenceGeneratorOutput as SequenceGeneratorOutput,
)
from fairseq2.generation.step_processor import (
BannedSequenceProcessor as BannedSequenceProcessor,
)
from fairseq2.generation.step_processor import StepProcessor as StepProcessor
from fairseq2.generation.text import SequenceToTextGenerator as SequenceToTextGenerator
from fairseq2.generation.text import SequenceToTextOutput as SequenceToTextOutput
from fairseq2.generation.text import TextTranslator as TextTranslator
14 changes: 7 additions & 7 deletions src/fairseq2/generation/sequence_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from fairseq2.data import VocabularyInfo
from fairseq2.generation.beam_search import BeamSearch, StandardBeamSearch
from fairseq2.generation.logits_processor import LogitsProcessor
from fairseq2.generation.step_processor import StepProcessor
from fairseq2.models.encoder_decoder import Seq2SeqDecoder
from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.nn.ops import repeat_interleave
Expand Down Expand Up @@ -57,8 +57,8 @@ class SequenceGeneratorOptions:
search: Optional[BeamSearch] = None
"""The beam search algorithm to use."""

logits_processor: Optional[LogitsProcessor] = None
"""Logits processor called before applying beam search step."""
step_processor: Optional[StepProcessor] = None
"""The processor called at each generation step."""


class Seq2SeqGenerator:
Expand All @@ -73,7 +73,7 @@ class Seq2SeqGenerator:
prefix_seq: Union[int, Tensor]
prefix_seq_len: int
search: BeamSearch
logits_processor: Optional[LogitsProcessor]
step_processor: Optional[StepProcessor]

def __init__(
self,
Expand Down Expand Up @@ -140,7 +140,7 @@ def __init__(
# Set beam search.
self.search = self.opts.search or StandardBeamSearch()

self.logits_processor = self.opts.logits_processor
self.step_processor = self.opts.step_processor

@torch.inference_mode()
def __call__(
Expand Down Expand Up @@ -274,8 +274,8 @@ def __call__(
lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty

# Update `lprobs` in-place if requested.
if self.logits_processor is not None:
self.logits_processor(
if self.step_processor is not None:
self.step_processor(
seqs[:, : step_nr + 1], lprobs.squeeze(1), lprob=True
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from fairseq2.typing import finaloverride


class LogitsProcessor(ABC):
class StepProcessor(ABC):
"""Processes next-step probabilities during sequence generation."""

@abstractmethod
Expand All @@ -35,7 +35,7 @@ def __call__(self, seqs: Tensor, probs: Tensor, lprob: bool = False) -> None:


@final
class BannedSequenceProcessor(LogitsProcessor):
class BannedSequenceProcessor(StepProcessor):
"""Prevents a provided list of banned sequences from being generated."""

_banned_seqs: Optional[Tensor]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_call_works(
banned_seqs = [text_encoder(b) for b in banned_words]

opts = SequenceGeneratorOptions(
logits_processor=BannedSequenceProcessor(banned_seqs)
step_processor=BannedSequenceProcessor(banned_seqs)
)

translator = TextTranslator(
Expand Down

0 comments on commit fa7f9c4

Please sign in to comment.