Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce SequenceGeneratorHandler #948

Merged
merged 1 commit into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 63 additions & 40 deletions src/fairseq2/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,56 +6,45 @@

from __future__ import annotations

from fairseq2.generation.beam_search.algo import (
STANDARD_BEAM_SEARCH_ALGO as STANDARD_BEAM_SEARCH_ALGO,
)
from fairseq2.generation.beam_search.algo import (
BeamSearchAlgorithm as BeamSearchAlgorithm,
)
from fairseq2.generation.beam_search.algo import BeamStep as BeamStep
from fairseq2.generation.beam_search.algo import (
StandardBeamSearchAlgorithm as StandardBeamSearchAlgorithm,
BeamSearchAlgorithmHandler as BeamSearchAlgorithmHandler,
)
from fairseq2.generation.beam_search.factory import (
StandardBeamSearchConfig as StandardBeamSearchConfig,
from fairseq2.generation.beam_search.algo import (
BeamSearchAlgorithmNotFoundError as BeamSearchAlgorithmNotFoundError,
)
from fairseq2.generation.beam_search.factory import (
beam_search_factories as beam_search_factories,
from fairseq2.generation.beam_search.algo import BeamStep as BeamStep
from fairseq2.generation.beam_search.algo import (
StandardBeamSearchAlgorithm as StandardBeamSearchAlgorithm,
)
from fairseq2.generation.beam_search.factory import (
beam_search_factory as beam_search_factory,
from fairseq2.generation.beam_search.algo import (
StandardBeamSearchAlgorithmHandler as StandardBeamSearchAlgorithmHandler,
)
from fairseq2.generation.beam_search.generator import (
BeamSearchSeq2SeqGenerator as BeamSearchSeq2SeqGenerator,
)
from fairseq2.generation.beam_search.generator import (
BeamSearchSequenceGenerator as BeamSearchSequenceGenerator,
)
from fairseq2.generation.factory import BeamSearchConfig as BeamSearchConfig
from fairseq2.generation.factory import SamplingConfig as SamplingConfig
from fairseq2.generation.factory import (
create_beam_search_seq2seq_generator as create_beam_search_seq2seq_generator,
)
from fairseq2.generation.factory import (
create_beam_search_seq_generator as create_beam_search_seq_generator,
from fairseq2.generation.beam_search.handler import (
BEAM_SEARCH_GENERATOR as BEAM_SEARCH_GENERATOR,
)
from fairseq2.generation.factory import (
create_sampling_seq2seq_generator as create_sampling_seq2seq_generator,
)
from fairseq2.generation.factory import (
create_sampling_seq_generator as create_sampling_seq_generator,
)
from fairseq2.generation.factory import (
create_seq2seq_generator as create_seq2seq_generator,
from fairseq2.generation.beam_search.handler import AlgorithmSection as AlgorithmSection
from fairseq2.generation.beam_search.handler import (
AlgorithmSectionHandler as AlgorithmSectionHandler,
)
from fairseq2.generation.factory import create_seq_generator as create_seq_generator
from fairseq2.generation.factory import (
seq2seq_generator_factories as seq2seq_generator_factories,
from fairseq2.generation.beam_search.handler import BeamSearchConfig as BeamSearchConfig
from fairseq2.generation.beam_search.handler import (
BeamSearchSeq2SeqGeneratorHandler as BeamSearchSeq2SeqGeneratorHandler,
)
from fairseq2.generation.factory import (
seq2seq_generator_factory as seq2seq_generator_factory,
from fairseq2.generation.beam_search.handler import (
BeamSearchSequenceGeneratorHandler as BeamSearchSequenceGeneratorHandler,
)
from fairseq2.generation.factory import (
seq_generator_factories as seq_generator_factories,
)
from fairseq2.generation.factory import seq_generator_factory as seq_generator_factory
from fairseq2.generation.generator import (
AbstractSeq2SeqGenerator as AbstractSeq2SeqGenerator,
)
Expand All @@ -72,25 +61,59 @@
SequenceGeneratorOutput as SequenceGeneratorOutput,
)
from fairseq2.generation.generator import StepHook as StepHook
from fairseq2.generation.sampling.factory import TopKSamplerConfig as TopKSamplerConfig
from fairseq2.generation.sampling.factory import TopPSamplerConfig as TopPSamplerConfig
from fairseq2.generation.sampling.factory import (
create_top_k_sampler as create_top_k_sampler,
from fairseq2.generation.handler import (
Seq2SeqGeneratorHandler as Seq2SeqGeneratorHandler,
)
from fairseq2.generation.handler import (
Seq2SeqGeneratorNotFoundError as Seq2SeqGeneratorNotFoundError,
)
from fairseq2.generation.handler import (
SequenceGeneratorHandler as SequenceGeneratorHandler,
)
from fairseq2.generation.sampling.factory import (
create_top_p_sampler as create_top_p_sampler,
from fairseq2.generation.handler import (
SequenceGeneratorNotFoundError as SequenceGeneratorNotFoundError,
)
from fairseq2.generation.sampling.factory import sampler_factories as sampler_factories
from fairseq2.generation.sampling.factory import sampler_factory as sampler_factory
from fairseq2.generation.sampling.generator import (
SamplingSeq2SeqGenerator as SamplingSeq2SeqGenerator,
)
from fairseq2.generation.sampling.generator import (
SamplingSequenceGenerator as SamplingSequenceGenerator,
)
from fairseq2.generation.sampling.handler import (
SAMPLING_GENERATOR as SAMPLING_GENERATOR,
)
from fairseq2.generation.sampling.handler import SamplerSection as SamplerSection
from fairseq2.generation.sampling.handler import (
SamplerSectionHandler as SamplerSectionHandler,
)
from fairseq2.generation.sampling.handler import SamplingConfig as SamplingConfig
from fairseq2.generation.sampling.handler import (
SamplingSeq2SeqGeneratorHandler as SamplingSeq2SeqGeneratorHandler,
)
from fairseq2.generation.sampling.handler import (
SamplingSequenceGeneratorHandler as SamplingSequenceGeneratorHandler,
)
from fairseq2.generation.sampling.sampler import TOP_K_SAMPLER as TOP_K_SAMPLER
from fairseq2.generation.sampling.sampler import TOP_P_SAMPLER as TOP_P_SAMPLER
from fairseq2.generation.sampling.sampler import Sampler as Sampler
from fairseq2.generation.sampling.sampler import SamplerHandler as SamplerHandler
from fairseq2.generation.sampling.sampler import (
SamplerNotFoundError as SamplerNotFoundError,
)
from fairseq2.generation.sampling.sampler import TopKSampler as TopKSampler
from fairseq2.generation.sampling.sampler import TopKSamplerConfig as TopKSamplerConfig
from fairseq2.generation.sampling.sampler import (
TopKSamplerHandler as TopKSamplerHandler,
)
from fairseq2.generation.sampling.sampler import TopPSampler as TopPSampler
from fairseq2.generation.sampling.sampler import TopPSamplerConfig as TopPSamplerConfig
from fairseq2.generation.sampling.sampler import (
TopPSamplerHandler as TopPSamplerHandler,
)
from fairseq2.generation.static import (
create_seq2seq_generator as create_seq2seq_generator,
)
from fairseq2.generation.static import create_seq_generator as create_seq_generator
from fairseq2.generation.step_processor import (
BannedSequenceProcessor as BannedSequenceProcessor,
)
Expand Down
53 changes: 49 additions & 4 deletions src/fairseq2/generation/beam_search/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Protocol, final
from types import NoneType
from typing import Final, final

import torch
from torch import Tensor
from typing_extensions import override


class BeamSearchAlgorithm(Protocol):
class BeamSearchAlgorithm(ABC):
"""Represents a beam search algorithm."""

def __call__(self, beam_size: int, lprobs: Tensor, step_scores: Tensor) -> BeamStep:
@abstractmethod
def step(self, beam_size: int, lprobs: Tensor, step_scores: Tensor) -> BeamStep:
"""Take a single step.

A subclass implementation is expected to return the best 2 x `beam_size`
Expand All @@ -41,7 +45,8 @@ def __call__(self, beam_size: int, lprobs: Tensor, step_scores: Tensor) -> BeamS
class StandardBeamSearchAlgorithm(BeamSearchAlgorithm):
"""Represents a standard beam search algorithm."""

def __call__(self, beam_size: int, lprobs: Tensor, step_scores: Tensor) -> BeamStep:
@override
def step(self, beam_size: int, lprobs: Tensor, step_scores: Tensor) -> BeamStep:
vocab_size = lprobs.size(1)

# Make the probabilities contain cumulative scores for each hypothesis.
Expand Down Expand Up @@ -102,3 +107,43 @@ def merge(steps: Sequence[BeamStep]) -> BeamStep:
scores = torch.cat([s.scores for s in steps])

return BeamStep(seq_indices, vocab_indices, scores)


class BeamSearchAlgorithmHandler(ABC):
    """Abstract interface that builds a beam search algorithm from a
    configuration object and exposes the type of that configuration."""

    @abstractmethod
    def create(self, config: object) -> BeamSearchAlgorithm:
        """Return a beam search algorithm built from ``config``.

        :param config:
            The configuration object of the algorithm.
        """

    @property
    @abstractmethod
    def config_kls(self) -> type:
        """The type of the configuration object associated with this handler."""


class BeamSearchAlgorithmNotFoundError(LookupError):
    """Raised to report that a name is not a known beam search algorithm."""

    # The algorithm name that was not recognized.
    name: str

    def __init__(self, name: str) -> None:
        """
        :param name:
            The algorithm name that was not recognized.
        """
        message = f"'{name}' is not a known beam search algorithm."

        super().__init__(message)

        self.name = name


# Name of the standard beam search algorithm (presumably the key under which
# its handler is looked up -- confirm against the registration code).
STANDARD_BEAM_SEARCH_ALGO: Final = "standard"


@final
class StandardBeamSearchAlgorithmHandler(BeamSearchAlgorithmHandler):
    """Handler that creates :class:`StandardBeamSearchAlgorithm` instances.

    The standard algorithm takes no configuration.
    """

    @override
    def create(self, config: object) -> BeamSearchAlgorithm:
        """Return a new :class:`StandardBeamSearchAlgorithm`.

        :param config:
            Must be ``None``.

        :raises ValueError:
            If ``config`` is not ``None``.
        """
        if config is None:
            return StandardBeamSearchAlgorithm()

        raise ValueError(
            "`config` must not be specified for standard beam-search algorithm."
        )

    @property
    @override
    def config_kls(self) -> type:
        """The configuration type, ``NoneType``, since no configuration is accepted."""
        return NoneType
29 changes: 0 additions & 29 deletions src/fairseq2/generation/beam_search/factory.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/fairseq2/generation/beam_search/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ def _search_beam(
# best 2 x `beam_size` candidates and choose the first `beam_size` of
# these which don't predict EOS to continue with.
# (2 x B)
next_step = self._algorithm(
next_step = self._algorithm.step(
self._beam_size, lprobs, step_scores[:, : self._step_nr]
)

Expand Down
Loading
Loading