diff --git a/src/fairseq2/generation/__init__.py b/src/fairseq2/generation/__init__.py index ec5a9faba..16b422343 100644 --- a/src/fairseq2/generation/__init__.py +++ b/src/fairseq2/generation/__init__.py @@ -30,9 +30,11 @@ from fairseq2.generation.sampling import Sampler as Sampler from fairseq2.generation.sampling import ( SamplingSeq2SeqGenerator as SamplingSeq2SeqGenerator, + SpeculativeSamplingSeq2SeqGenerator as SpeculativeSamplingSeq2SeqGenerator, ) from fairseq2.generation.sampling import ( SamplingSequenceGenerator as SamplingSequenceGenerator, + SpeculativeSamplingSequenceGenerator as SpeculativeSamplingSequenceGenerator, ) from fairseq2.generation.sampling import TopKSampler as TopKSampler from fairseq2.generation.sampling import TopPSampler as TopPSampler diff --git a/src/fairseq2/generation/sampling.py b/src/fairseq2/generation/sampling.py index 2081e8a89..39f2389aa 100644 --- a/src/fairseq2/generation/sampling.py +++ b/src/fairseq2/generation/sampling.py @@ -8,10 +8,12 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional, Sequence, Tuple, Union, final +import time import torch from torch import Tensor from torch.nn.functional import softmax +import numpy as np from fairseq2.data import VocabularyInfo from fairseq2.generation.generator import ( @@ -32,7 +34,6 @@ from fairseq2.typing import finaloverride, override -@final class SamplingSequenceGenerator(SequenceGenerator): """Represents a sequence generator based on sampling.""" @@ -160,7 +161,54 @@ def __call__( return SequenceGeneratorOutput(hypotheses) -@final +class SpeculativeSamplingSequenceGenerator(SamplingSequenceGenerator): + """Represents a speculative sequence generator based on sampling.""" + + def __init__( + self, + model: DecoderModel, + model_draft: DecoderModel, + *args, + **kwargs, + ) -> None: + """ + :param model: + The decoder model to use for generation. + :param model_draft: + The decoder model to use to generate draft tokens. 
+ """ + super().__init__(model, *args, **kwargs) + self.model_draft = model_draft + + @torch.inference_mode() + def __call__( + self, prompt_seqs: Tensor, prompt_padding_mask: Optional[PaddingMask] + ) -> SequenceGeneratorOutput: + op = _SpeculativeSamplingSequenceGeneratorOp( + self.model, + self.model_draft, + prompt_seqs, + prompt_padding_mask, + self.sampler, + self.num_gens, + self.min_gen_len, + self.max_gen_len, + self.max_seq_len, + self.echo_prompt, + self.compute_scores, + self.normalize_scores, + self.temperature, + self.unk_penalty, + self.len_penalty, + self.step_processors, + self._step_hooks, + ) + + hypotheses = op() + + return SequenceGeneratorOutput(hypotheses) + + class SamplingSeq2SeqGenerator(Seq2SeqGenerator): """Represents a sequence-to-sequence generator based on sampling.""" @@ -259,10 +307,18 @@ def __call__( prompt_seqs: Tensor, prompt_padding_mask: Optional[PaddingMask], ) -> Seq2SeqGeneratorOutput: + seq_len = dict() + timer_result = dict() + # (P, S) - encoder_output, encoder_padding_mask = self.model.encode( + torch.cuda.synchronize() + start_time = time.time() + (encoder_output, encoder_padding_mask, gpu_util), src_seq_len = self.model.encode( source_seqs, source_padding_mask ) + torch.cuda.synchronize() + timer_result["Encoder"] = (time.time()-start_time)*1000 + seq_len["Encoder"] = [src_seq_len, encoder_output.shape[1], 1] if source_padding_mask is None: max_source_len = source_seqs.size(1) @@ -285,6 +341,8 @@ def __call__( f"`min_gen_len` must be less than or equal to `max_gen_len` ({max_gen_len}), but is {self.min_gen_len} instead. Adjust your `max_gen_len` argument." ) + torch.cuda.synchronize() + start_time = time.time() op = _SamplingSeq2SeqGeneratorOp( self.model, encoder_output, @@ -306,9 +364,113 @@ def __call__( self._step_hooks, ) - hypotheses = op() + hypotheses, decoding_step, gpu_util2 = op() + torch.cuda.synchronize() + timer_result["Decoder"] = (time.time()-start_time)*1000 + + seq_len["Decoder"] = [op.min_prompt_len-1] + seq_len["Decoder"] += [np.average([h[0].seq.shape[0] for h in hypotheses]), decoding_step] + return Seq2SeqGeneratorOutput(hypotheses, encoder_output, encoder_padding_mask), timer_result, seq_len, np.average(gpu_util+gpu_util2) - return Seq2SeqGeneratorOutput(hypotheses, encoder_output, encoder_padding_mask) + +class SpeculativeSamplingSeq2SeqGenerator(SamplingSeq2SeqGenerator): + """Represents a sequence-to-sequence generator based on sampling.""" + + model_draft: EncoderDecoderModel + + def __init__( + self, + model: EncoderDecoderModel, + model_draft: EncoderDecoderModel, + k_speculate: int, + *args, + **kwargs, + ) -> None: + """ + :param model: + The encoder-decoder model to use for generation. + :param model_draft: + The encoder-decoder draft model to generate draft tokens. 
+ """ + super().__init__(model, *args, **kwargs) + self.model_draft = model_draft + self.k_speculate = k_speculate + + @finaloverride + @torch.inference_mode() + def __call__( + self, + source_seqs: Tensor, + source_padding_mask: Optional[PaddingMask], + prompt_seqs: Tensor, + prompt_padding_mask: Optional[PaddingMask], + ) -> Seq2SeqGeneratorOutput: + seq_len = dict() + timer_result = dict() + + # (P, S) + torch.cuda.synchronize() + start_time = time.time() + (encoder_output, encoder_padding_mask, gpu_util), src_seq_len = self.model.encode( + source_seqs, source_padding_mask + ) + torch.cuda.synchronize() + timer_result["Encoder"] = (time.time()-start_time)*1000 + seq_len["Encoder"] = [src_seq_len, encoder_output.shape[1], 1] + + if source_padding_mask is None: + max_source_len = source_seqs.size(1) + else: + max_source_len = int(source_padding_mask.seq_lens.max()) + + a_term, b_term = self.max_gen_len + + # In seq2seq generation, the maximum generation length is relative to + # the source sequence length. + max_gen_len = int(a_term * max_source_len + b_term) + + if max_gen_len < 1: + raise ValueError( + f"`max_gen_len` must be greater than or equal to 1, but is {max_gen_len} instead. Adjust your `max_gen_len` argument." + ) + + if self.min_gen_len > max_gen_len: + raise ValueError( + f"`min_gen_len` must be less than or equal to `max_gen_len` ({max_gen_len}), but is {self.min_gen_len} instead. Adjust your `max_gen_len` argument." + ) + + torch.cuda.synchronize() + start_time = time.time() + op = _SpeculativeSamplingSeq2SeqGeneratorOp( + self.model, + self.model_draft, + self.k_speculate, + encoder_output, + encoder_padding_mask, + prompt_seqs, + prompt_padding_mask, + self.sampler, + self.num_gens, + self.min_gen_len, + max_gen_len, + self.max_seq_len, + self.echo_prompt, + self.compute_scores, + self.normalize_scores, + self.temperature, + self.unk_penalty, + self.len_penalty, + self.step_processors, + self._step_hooks, + ) + + hypotheses, decoding_step, gpu_util2 = op() + torch.cuda.synchronize() + timer_result["Decoder"] = (time.time()-start_time)*1000 + + seq_len["Decoder"] = [op.min_prompt_len-1] + seq_len["Decoder"] += [np.average([h[0].seq.shape[0] for h in hypotheses]), decoding_step] + return Seq2SeqGeneratorOutput(hypotheses, encoder_output, encoder_padding_mask), timer_result, seq_len, np.average(gpu_util+gpu_util2) class Sampler(ABC): @@ -543,10 +705,16 @@ def __init__( self.output = [[] for _ in range(num_prompts)] def __call__(self) -> List[List[Hypothesis]]: - self._prepare_state() + gpu_utils = [] + gpu_util = self._prepare_state() + gpu_utils.append(gpu_util) + decoding_step = 0 for self.step_nr in range(self.min_prompt_len, self.max_seq_len): - if not self._step(): + output, gpu_util = self._step() + gpu_utils.append(gpu_util) + decoding_step+=1 + if not output: break if self.compute_scores: @@ -554,12 +722,13 @@ def __call__(self) -> List[List[Hypothesis]]: for hypotheses in self.output: hypotheses.sort(key=lambda h: h.score, reverse=True) # type: ignore[arg-type, return-value] - return self.output + return self.output, decoding_step, gpu_utils def _prepare_state(self) -> None: # Fast-forward to the first step that needs to be generated. + gpu_util = [] if self.min_prompt_len > 1: - self._prefill() + gpu_util = self._prefill() # Fan out the state to `num_prompts` x `num_gens`. 
if self.num_gens > 1: @@ -573,10 +742,12 @@ def _prepare_state(self) -> None: self._reorder_state(fan_out) + return gpu_util + def _prefill(self) -> None: prefill_len = self.min_prompt_len - model_output = self._decode(self.seqs[:, : prefill_len - 1]) + model_output, gpu_util = self._decode(self.seqs[:, : prefill_len - 1]) self.state_bag.increment_step_nr(prefill_len - 1) @@ -610,9 +781,11 @@ def _prefill(self) -> None: for hook in self.step_hooks.values(): hook(self.prompt_indices, seqs, step_scores, prefill=True) + return gpu_util + def _step(self) -> bool: # Generate the next step output. - model_output = self._decode(self.seqs[:, self.step_nr - 1 : self.step_nr]) + model_output, gpu_util = self._decode(self.seqs[:, self.step_nr - 1 : self.step_nr]) self.state_bag.increment_step_nr() @@ -713,13 +886,13 @@ def _step(self) -> bool: # No sequence left, we can return. if len(active_seq_indices) == 0: - return False + return False, gpu_util # Otherwise, remove the sequences that have reached EOS from the # state and continue generating the remaining ones. self._reorder_state(active_seq_indices) - return True + return True, gpu_util @abstractmethod def _decode(self, seqs: Tensor) -> SequenceModelOutput: @@ -842,6 +1015,142 @@ def _decode(self, seqs: Tensor) -> SequenceModelOutput: return self.model.project(decoder_output, decoder_padding_mask) +class _SpeculativeSamplingSequenceGeneratorOp(_SamplingSequenceGeneratorOpBase): + model: DecoderModel + model_draft: DecoderModel + + def __init__( + self, + model: DecoderModel, + model_draft: DecoderModel, + prompt_seqs: Tensor, + prompt_padding_mask: Optional[PaddingMask], + sampler: Sampler, + num_gens: int, + min_gen_len: int, + max_gen_len: int, + max_seq_len: int, + echo_prompt: bool, + compute_scores: bool, + normalize_scores: bool, + temperature: float, + unk_penalty: float, + len_penalty: float, + step_processors: Sequence[StepProcessor], + step_hooks: Dict[int, StepHook], + ) -> None: + super().__init__( + prompt_seqs, + prompt_padding_mask, + sampler, + model.vocab_info, + num_gens, + min_gen_len, + max_gen_len, + max_seq_len, + echo_prompt, + compute_scores, + normalize_scores, + temperature, + unk_penalty, + len_penalty, + step_processors, + step_hooks, + ) + + self.model = model + self.model_draft = model_draft + + def __call__(self) -> List[List[Hypothesis]]: + gpu_utils = [] + gpu_util = self._prepare_state() + gpu_util = self._prepare_state_draft() + gpu_utils.append(gpu_util) + + decoding_step = 0 + for self.step_nr in range(self.min_prompt_len, self.max_seq_len): + output, gpu_util = self._step() + gpu_utils.append(gpu_util) + decoding_step+=1 + if not output: + break + + if self.compute_scores: + # Sort the hypotheses by their scores before returning. + for hypotheses in self.output: + hypotheses.sort(key=lambda h: h.score, reverse=True) # type: ignore[arg-type, return-value] + + return self.output, decoding_step, gpu_utils + + def _prepare_state_draft(self) -> None: + # Fast-forward to the first step that needs to be generated. + gpu_util = [] + if self.min_prompt_len > 1: + gpu_util = self._prefill_draft() + + # Fan out the state to `num_prompts` x `num_gens`. 
+ if self.num_gens > 1: + num_prompts = self.seqs.size(0) + + # (P) + fan_out = torch.arange(num_prompts, device=self.seqs.device) + + # (P) -> (P x G) + fan_out = repeat_interleave(fan_out, dim=0, repeat=self.num_gens) + + self._reorder_state(fan_out) + + return gpu_util + + def _prefill_draft(self) -> None: + prefill_len = self.min_prompt_len + + model_output, gpu_util = self._decode_draft(self.seqs[:, : prefill_len - 1]) + + self.state_bag_draft.increment_step_nr(prefill_len - 1) + + if self.step_scores is not None: + logits = model_output.logits + + if self.temperature != 1.0: + logits /= self.temperature + + # (P, S_prm - 1, V) + probs = softmax(logits, dim=-1, dtype=torch.float32) + + # Fetch the scores of the next prompt step. + # (P, S_prm - 1, 1) + prompt_scores = torch.gather( + probs, dim=-1, index=self.seqs[:, 1:prefill_len].unsqueeze(-1) + ) + + # Bootstrap the step scores. + # (P, S_prm - 1) + self.step_scores[:, 1:prefill_len] = prompt_scores.squeeze(-1) + + if self.step_hooks: + seqs = self.seqs[:, :prefill_len] + + if self.step_scores is None: + step_scores = None + else: + step_scores = self.step_scores[:, :prefill_len] + + for hook in self.step_hooks.values(): + hook(self.prompt_indices, seqs, step_scores, prefill=True) + + return gpu_util + + def _decode_draft(self, seqs: Tensor) -> SequenceModelOutput: + decoder_output, decoder_padding_mask = self.model_draft.decode( + seqs, + None, # We never use PAD in incremental decoding. + state_bag=self.state_bag, + ) + + return self.model_draft.project(decoder_output, decoder_padding_mask) + + class _SamplingSeq2SeqGeneratorOp(_SamplingSequenceGeneratorOpBase): model: EncoderDecoderModel @@ -894,7 +1203,7 @@ def __init__( @override def _decode(self, seqs: Tensor) -> SequenceModelOutput: - decoder_output, decoder_padding_mask = self.model.decode( + decoder_output, decoder_padding_mask, gpu_util = self.model.decode( seqs, None, # We never use PAD in incremental decoding. 
self.encoder_output, @@ -902,7 +1211,7 @@ def _decode(self, seqs: Tensor) -> SequenceModelOutput: state_bag=self.state_bag, ) - return self.model.project(decoder_output, decoder_padding_mask) + return self.model.project(decoder_output, decoder_padding_mask), np.average(gpu_util) @override def _reorder_state(self, new_order: Tensor) -> None: @@ -919,3 +1228,322 @@ def _reorder_state(self, new_order: Tensor) -> None: self.encoder_padding_mask = PaddingMask( encoder_seq_lens, batch_seq_len=self.encoder_output.size(1) ) + +class _SpeculativeSamplingSeq2SeqGeneratorOp(_SamplingSeq2SeqGeneratorOp): + model_draft: EncoderDecoderModel + k_speculate: int + state_bag_draft: IncrementalStateBag + + def __init__( + self, + model: EncoderDecoderModel, + model_draft: EncoderDecoderModel, + k_speculate: int, + encoder_output: Tensor, + encoder_padding_mask: Optional[PaddingMask], + *args, + **kwargs, + ) -> None: + _SamplingSeq2SeqGeneratorOp.__init__(self, model, encoder_output, encoder_padding_mask, *args, **kwargs) + self.model_draft = model_draft + self.k_speculate = k_speculate + self.state_bag_draft = IncrementalStateBag(self.max_seq_len) + self.step_nr_draft = None + + def __call__(self) -> List[List[Hypothesis]]: + gpu_utils = [] + gpu_util = self._prepare_state() + gpu_util = self._prepare_state_draft() + gpu_utils.append(gpu_util) + + decoding_step = 0 + self.step_nr = self.min_prompt_len + bs = self.seqs.shape[0] + vocab_indices_draft = torch.zeros(bs, self.k_speculate, device=self.seqs.device, dtype=self.seqs.dtype) + probs_draft = torch.zeros(bs, self.k_speculate, self.model.final_proj.output_dim, device=self.seqs.device, dtype=self.encoder_output.dtype) + output_draft = [False] * self.k_speculate + while self.step_nr < self.max_seq_len: + # Run draft model to generate draft tokens + for idx_draft, self.step_nr_draft in enumerate(range(self.step_nr, min(self.step_nr + self.k_speculate, self.max_seq_len))): + output_draft[idx_draft], probs_draft[:, idx_draft ,:], vocab_indices_draft[:, idx_draft], gpu_util_draft = self._step_draft() + gpu_utils.append(gpu_util_draft) + decoding_step+=1 + # TODO: Move this to post-acceptance logic + # if not output_draft: + # break + + num_draft_tokens = self.step_nr_draft - self.step_nr + 1 + + # Run draft tokens by main model + # TODO: receive probabilities or tokens from main + probs_main, vocab_indices_main, gpu_util_main = self._verify_main() + gpu_utils.append(gpu_util_main) + + # Count how many draft tokens are accepted by main model + # TODO: run verification rather than hard coding number of accepted tokens + # q: target prob, p: draft prob + # q >= p: always accept draft token + # q < p: q/p prob to accept draft token + p = probs_draft[:, torch.arange(0, num_draft_tokens, device=probs_draft.device), :] + p = torch.gather(p, dim=2, index=vocab_indices_draft[:, :num_draft_tokens].unsqueeze(-1)) + q = probs_main[:, torch.arange(0, num_draft_tokens, device=probs_main.device), :] + q = torch.gather(q, dim=2, index=vocab_indices_main[:, :num_draft_tokens].unsqueeze(-1)) + accept_draft_prob = torch.minimum(torch.ones(()), q[:, :, :num_draft_tokens]/ p).squeeze(dim=-1) + is_accepted = torch.rand_like(accept_draft_prob) > accept_draft_prob + + rejected_locations = torch.argmax((~is_accepted).int(), dim=-1) + no_false = torch.all(is_accepted, dim=1) + rejected_locations[no_false] = -1 + min_rejected_locations = torch.min(rejected_locations) + + if min_rejected_locations == -1: # All draft tokens have been accepted + # TODO: Remove Hack + min_rejected_locations 
= 0 + + if True: + accept_length = min_rejected_locations + p = probs_draft[:, accept_length, :] + q = probs_main[:, accept_length, :] + probs_new = q - p + probs_new = torch.where(probs_new > 0, probs_new, 0.0) + probs_new = probs_new / probs_new.sum() + # (N) + vocab_indices = self.sampler(probs_new) + self.seqs[:, self.step_nr + accept_length] = vocab_indices + + self.step_nr += accept_length+1 + + # TODO: move post decoding processing here (e.g., self.step_scores, hooks, etc.) + self.state_bag.increment_step_nr(- (num_draft_tokens - accept_length - 1)) + self.state_bag_draft.increment_step_nr(- (num_draft_tokens - accept_length - 1)) + + if self.compute_scores: + # Sort the hypotheses by their scores before returning. + for hypotheses in self.output: + hypotheses.sort(key=lambda h: h.score, reverse=True) # type: ignore[arg-type, return-value] + + return self.output, decoding_step, gpu_utils + + def _prepare_state_draft(self) -> None: + # Fast-forward to the first step that needs to be generated. + gpu_util = [] + if self.min_prompt_len > 1: + gpu_util = self._prefill_draft() + + # Fan out the state to `num_prompts` x `num_gens`. + if self.num_gens > 1: + num_prompts = self.seqs.size(0) + + # (P) + fan_out = torch.arange(num_prompts, device=self.seqs.device) + + # (P) -> (P x G) + fan_out = repeat_interleave(fan_out, dim=0, repeat=self.num_gens) + + self._reorder_state(fan_out) + + return gpu_util + + def _verify_main(self) -> None: + model_output, gpu_util = self._decode(self.seqs[:, self.step_nr - 1 : self.step_nr_draft]) + self.state_bag.increment_step_nr(self.step_nr_draft - (self.step_nr - 1)) + + logits = model_output.logits + + if self.temperature != 1.0: + logits /= self.temperature + + # TODO: return probs regardless of self.step_scores? + # (P, S_prm - 1, V) + probs = softmax(logits, dim=-1, dtype=torch.float32) + + # Fetch the scores of the next prompt step. + # (P, S_prm - 1, 1) + prompt_scores = torch.gather( + probs, dim=-1, index=self.seqs[:, self.step_nr+1: self.step_nr_draft+1].unsqueeze(-1) + ) + + # Copied from self._draft_step() + # If we are generating the last possible step, force it to be EOS + # regardless of its score. + # TODO: move this to after verification? + # Process `probs` in-place if requested. + for processor in self.step_processors: + processor(self.seqs[:, : self.step_nr_draft], probs) + + # Apply UNK penalty. + if self.unk_idx is not None: + probs[:, self.unk_idx] -= self.unk_penalty + + # Never allow PAD. + if self.pad_idx is not None: + probs[:, self.pad_idx] = 0 + + # Do not allow EOS till we reach the minimum sequence length. + if self.step_nr_draft < self.min_seq_len - 1: + probs[:, self.eos_idx] = 0 + + # (N) + vocab_indices = self.sampler(probs) + + # EOS mask of the current step. + # (N) + eos_mask = vocab_indices == self.eos_idx + + # Ignore the generated indices for the prompt sequences. + if self.step_nr_draft < self.max_prompt_len: + assert self.prompt_mask is not None + + # (N) + mask = self.prompt_mask[:, self.step_nr_draft] + + # Override the generated indices. + vocab_indices[mask] = self.seqs[mask, self.step_nr_draft] + + # Ignore EOS in the prompt sequences. + eos_mask[mask] = False + else: + self.prompt_mask = None # Not needed anymore, release. + + # TODO: commenting for now. Move to accept step. + ## Record the current step. 
+ # self.seqs[:, self.step_nr - 1 : self.step_nr_draft] = vocab_indices + + return probs, vocab_indices, gpu_util + + def _prefill_draft(self) -> None: + prefill_len = self.min_prompt_len + + model_output, gpu_util = self._decode_draft(self.seqs[:, : prefill_len - 1]) + + self.state_bag_draft.increment_step_nr(prefill_len - 1) + + return gpu_util + + def _decode_draft(self, seqs: Tensor) -> SequenceModelOutput: + decoder_output, decoder_padding_mask, gpu_util = self.model_draft.decode( + seqs, + None, # We never use PAD in incremental decoding. + self.encoder_output, + self.encoder_padding_mask, + state_bag=self.state_bag_draft, + ) + + return self.model_draft.project(decoder_output, decoder_padding_mask), np.average(gpu_util) + + def _step_draft(self) -> bool: + # Generate the next step output. + model_output, gpu_util = self._decode_draft(self.seqs[:, self.step_nr_draft - 1 : self.step_nr_draft]) + + self.state_bag_draft.increment_step_nr() + + logits = model_output.logits + + if self.temperature != 1.0: + logits /= self.temperature + + # (N, 1, V) + probs = softmax(logits, dim=-1, dtype=torch.float32) + + # (N, 1, V) -> (N, V) + probs.squeeze_(1) + + # If we are generating the last possible step, force it to be EOS + # regardless of its score. + if self.step_nr_draft == self.max_seq_len - 1: + batch_size = self.seqs.size(0) + + # (N) + vocab_indices = self.seqs.new_full((batch_size,), self.eos_idx) + else: + # Process `probs` in-place if requested. + for processor in self.step_processors: + processor(self.seqs[:, : self.step_nr_draft], probs) + + # Apply UNK penalty. + if self.unk_idx is not None: + probs[:, self.unk_idx] -= self.unk_penalty + + # Never allow PAD. + if self.pad_idx is not None: + probs[:, self.pad_idx] = 0 + + # Do not allow EOS till we reach the minimum sequence length. + if self.step_nr_draft < self.min_seq_len - 1: + probs[:, self.eos_idx] = 0 + + # (N) + vocab_indices = self.sampler(probs) + + # EOS mask of the current step. + # (N) + eos_mask = vocab_indices == self.eos_idx + + # Ignore the generated indices for the prompt sequences. + if self.step_nr_draft < self.max_prompt_len: + assert self.prompt_mask is not None + + # (N) + mask = self.prompt_mask[:, self.step_nr_draft] + + # Override the generated indices. + vocab_indices[mask] = self.seqs[mask, self.step_nr_draft] + + # Ignore EOS in the prompt sequences. + eos_mask[mask] = False + else: + self.prompt_mask = None # Not needed anymore, release. + + # Record the current step. + self.seqs[:, self.step_nr_draft] = vocab_indices + + # TODO: move to apply on accepted tokens? + if self.step_scores is not None: + # (N, 1) + scores = torch.gather(probs, dim=-1, index=vocab_indices[:, None]) + + # Record the scores of the current step. + self.step_scores[:, self.step_nr_draft] = scores.squeeze(1) + + # TODO: move to apply on accepted tokens? + if self.step_hooks: + seqs = self.seqs[:, : self.step_nr_draft + 1] + + if self.step_scores is None: + step_scores = None + else: + step_scores = self.step_scores[:, : self.step_nr_draft + 1] + + for hook in self.step_hooks.values(): + hook(self.prompt_indices, seqs, step_scores, prefill=False) + + # TODO: move to apply on accepted tokens? + # Retrieve the indices of the sequences that have reached EOS. + # (F, 1) + eos_seq_indices = eos_mask.nonzero() + + # TODO: move to apply on accepted tokens? + # If one or more sequences have reached EOS, move them to the output and + # continue generating the remaining sequences. 
+ if len(eos_seq_indices) > 0: + # Move the sequences that have reached EOS to the output. + for seq_idx in eos_seq_indices: + self._finish_sequence(int(seq_idx)) + + # (N) + active_seq_mask = ~eos_mask + + # (N - F, 1) -> (N - F) + active_seq_indices = active_seq_mask.nonzero().squeeze(-1) + + # No sequence left, we can return. + if len(active_seq_indices) == 0: + return False, probs, vocab_indices, gpu_util + + # Otherwise, remove the sequences that have reached EOS from the + # state and continue generating the remaining ones. + ## TODO: move that on main model only? + # self._reorder_state(active_seq_indices) + # self.state_bag_draft.reorder(active_seq_indices) + + return True, probs, vocab_indices, gpu_util diff --git a/src/fairseq2/nn/incremental_state.py b/src/fairseq2/nn/incremental_state.py index 935219d74..13f336553 100644 --- a/src/fairseq2/nn/incremental_state.py +++ b/src/fairseq2/nn/incremental_state.py @@ -74,6 +74,11 @@ def increment_step_nr(self, value: int = 1) -> None: self.step_nr = step_nr + # Update seq_len of attention modules + for module in self._module_states.values(): + if hasattr(module, "seq_len"): + module.seq_len = step_nr + def get_state(self, m: Module, kls: Type[T]) -> Optional[T]: """Get the state of ``m`` if present in the bag. diff --git a/src/fairseq2/nn/module_list.py b/src/fairseq2/nn/module_list.py index ea0ff6a6b..76672bd7e 100644 --- a/src/fairseq2/nn/module_list.py +++ b/src/fairseq2/nn/module_list.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterable, Iterator, Optional, final +from typing import Iterable, Iterator, List, Optional, Union, final import torch from torch.nn import Module @@ -39,10 +39,10 @@ class ModuleList(ModuleListBase): ... x = layer(x) """ - drop_p: float + _drop_p: List[float] def __init__( - self, modules: Optional[Iterable[Module]] = None, *, drop_p: float = 0.0 + self, modules: Optional[Iterable[Module]] = None, *, drop_p: Union[float, Iterable[float]] = 0.0 ) -> None: """ :param modules: @@ -56,20 +56,36 @@ def __init__( def drop_iter(self) -> Iterator[Module]: """Return an iterator that drops a random set of submodules.""" - if self.drop_p > 0.0 and self.training: + if any(drop_p > 0.0 for drop_p in self.drop_p) and self.training: prob_dist = torch.rand(len(self), device="cpu", dtype=torch.float32) else: prob_dist = None for idx, m in enumerate(super().__iter__()): - if prob_dist is None or prob_dist[idx] > self.drop_p: + if prob_dist is None or prob_dist[idx] > self.drop_p[idx]: yield m def extra_repr(self) -> str: """:meta private:""" s = super().extra_repr() - if self.drop_p > 0.0: + if any(drop_p > 0.0 for drop_p in self.drop_p): s = f"{s}, drop_p={self.drop_p}" return s + + @property + def drop_p(self): + """Get probability of dropping each layer.""" + return self._drop_p + + @drop_p.setter + def drop_p(self, drop_p: Union[float, Iterable[float]]): + """Set probability of dropping layers using either a single value or a list of values.""" + if isinstance(drop_p, Iterable): + assert len(drop_p) == len(self) + self._drop_p = drop_p + elif isinstance(drop_p, float): + self._drop_p = [drop_p] * len(self) + else: + raise ValueError(f"Unsupported type for drop rate {drop_p}. 
Expecting either float or list of floats.") diff --git a/src/fairseq2/nn/transformer/decoder.py b/src/fairseq2/nn/transformer/decoder.py index 8677bfc3e..ba4baba3f 100644 --- a/src/fairseq2/nn/transformer/decoder.py +++ b/src/fairseq2/nn/transformer/decoder.py @@ -215,7 +215,7 @@ def forward( *, state_bag: Optional[IncrementalStateBag] = None, ) -> Tuple[Tensor, Optional[PaddingMask]]: - if self._layer_output_hooks and self.layers.drop_p > 0.0: + if self._layer_output_hooks and any(drop_p > 0.0 for drop_p in self.layers.drop_p): raise RuntimeError( "The layer output hooks cannot be run when LayerDrop is enabled." ) @@ -239,11 +239,16 @@ def forward( state_bag=state_bag, ) gpu_util.append(torch.cuda.utilization(torch.cuda.current_device())) - + + early_exit = False for hook in self._layer_output_hooks.values(): if not hook(layer_idx, seqs, padding_mask, num_layers): + early_exit = True break + if early_exit: + break + if self.layer_norm is not None: seqs = self.layer_norm(seqs) diff --git a/src/fairseq2/nn/transformer/encoder.py b/src/fairseq2/nn/transformer/encoder.py index 7814120d0..d6fba2f8c 100644 --- a/src/fairseq2/nn/transformer/encoder.py +++ b/src/fairseq2/nn/transformer/encoder.py @@ -179,7 +179,7 @@ def __init__( def forward( self, seqs: Tensor, padding_mask: Optional[PaddingMask] ) -> Tuple[Tensor, Optional[PaddingMask]]: - if self._layer_output_hooks and self.layers.drop_p > 0.0: + if self._layer_output_hooks and any(drop_p > 0.0 for drop_p in self.layers.drop_p): raise RuntimeError( "The layer output hooks cannot be run when LayerDrop is enabled." ) @@ -197,10 +197,15 @@ def forward( seqs, padding_mask = layer(seqs, padding_mask, self_attn_mask) gpu_util.append(torch.cuda.utilization(torch.cuda.current_device())) + early_exit = False for hook in self._layer_output_hooks.values(): if not hook(layer_idx, seqs, padding_mask, num_layers): + early_exit = True break + if early_exit: + break + if self.layer_norm is not None: seqs = self.layer_norm(seqs) diff --git a/src/fairseq2/nn/transformer/multihead_attention.py b/src/fairseq2/nn/transformer/multihead_attention.py index 43e2ac878..e02924490 100644 --- a/src/fairseq2/nn/transformer/multihead_attention.py +++ b/src/fairseq2/nn/transformer/multihead_attention.py @@ -632,11 +632,13 @@ def __init__(self, k: Tensor, v: Tensor, max_seq_len: int) -> None: @finaloverride def append(self, k: Tensor, v: Tensor) -> None: pos = self.seq_len + assert k.shape[2] == v.shape[2] + append_len = k.shape[2] - self.k[:, :, pos : pos + 1] = k - self.v[:, :, pos : pos + 1] = v + self.k[:, :, pos : pos + append_len] = k + self.v[:, :, pos : pos + append_len] = v - self.seq_len += 1 + self.seq_len += append_len @finaloverride def get(self) -> Tuple[Tensor, Tensor]: diff --git a/src/fairseq2/utils/early_exit_loss.py b/src/fairseq2/utils/early_exit_loss.py new file mode 100644 index 000000000..17caad34b --- /dev/null +++ b/src/fairseq2/utils/early_exit_loss.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import copy +import numpy as np +import torch +from enum import Enum +from typing import Dict, List, Optional +from torch import Tensor + +from fairseq2.nn.padding import PaddingMask + +hidden_states_dict: Dict[int, Tensor] = {} + +def hook( + layer_idx: int, + layer_output: Tensor, + layer_padding_mask: Optional[PaddingMask], + num_layers: int, +) -> bool: + global hidden_states_dict + + # TODO: handle only doing early exit loss on specific layers + hidden_states_dict[layer_idx] = layer_output + + return True + + +class LossScaleType(str, Enum): + ONE = "one" + L = "l" + L2 = "l2" + SUM_L = "sum_l" + SUM_L2 = "sum_l2" + INV_L = "inv_l" + SQRT_L = "sqrt_l" + INV_SQRT_L = "inv_sqrt_l" + +def early_exit_loss(model, hidden_states_dict, batch, loss_fn, e_scale: float=20.0, loss_scale_type=LossScaleType.ONE): + hidden_states = tuple(hidden_states_dict.values()) + hidden_layer_ids = tuple(hidden_states_dict.keys()) + + losses_early = [] + for hidden_state in hidden_states: + logits_early = model.module.model.final_proj(model.module.model.text_decoder.layer_norm(hidden_state)) + # FIXME: assuming that the other output is None, which is not always the case + loss_early = loss_fn(batch, *(logits_early, None)) + losses_early.append(loss_early) + + losses_early = torch.stack(losses_early, dim=0) + losses_scales = layer_ids_to_loss_scales(torch.Tensor(hidden_layer_ids).to(losses_early), len(model.module.model.text_decoder.layers), loss_scale_type, e_scale) + + return torch.sum(losses_scales * losses_early) + +def layer_ids_to_loss_scales(layer_ids, n_layers, loss_scale_type: LossScaleType, e_scale: float): + match loss_scale_type: + case LossScaleType.ONE: + loss_scales = torch.ones(len(layer_ids)) + case LossScaleType.L: + loss_scales = torch.Tensor(layer_ids+1) + case LossScaleType.L2: + loss_scales = torch.Tensor((layer_ids+1)**2) + case LossScaleType.SUM_L: + # TODO: should we change to sum 0:i ? Perhaps create a new scale_type + loss_scales = torch.cumsum(layer_ids+1, dim=0) + case LossScaleType.SUM_L2: + # TODO: should we change to sum 0:i ? Perhaps create a new scale_type + loss_scales = torch.cumsum((layer_ids+1)**2, dim=0) + case LossScaleType.SQRT_L: + loss_scales = torch.sqrt(layer_ids+1) + case LossScaleType.INV_L: + loss_scales = 1.0 / (layer_ids+1) + case LossScaleType.INV_SQRT_L: + loss_scales = 1.0 / torch.sqrt(layer_ids+1) + case _: + raise ValueError(f"Unsupported loss_scale type {loss_scale_type}") + + loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) + # normalize loss scales to ensure that their sum is 1.0 + loss_scales = loss_scales / torch.sum(loss_scales) + assert torch.isclose(torch.sum(loss_scales), torch.Tensor([1.0]).to(loss_scales)) + + return loss_scales + +class EarlyExitCurriculumType(str, Enum): + NONE = "none" + ROTATIONAL = "rot" + GRADUAL = "gradual" + +def build_early_exit_curriculum(early_exit_curriculum: EarlyExitCurriculumType, *args, **kwargs): + match early_exit_curriculum: + case EarlyExitCurriculumType.NONE: + return None + + case EarlyExitCurriculumType.ROTATIONAL: + return RotationalEarlyExitCurriculum(*args, **kwargs) + + case EarlyExitCurriculumType.GRADUAL: + return GradualEarlyExitCurriculum(*args, **kwargs) + + case _: + raise ValueError(f"Unsupported early loss curriculum {early_exit_curriculum}.") + + +# TODO: create a base curriculum class that can be used for other aspects, e.g., dropout, datasets, etc. 
+class EarlyExitCurriculum(): + def __init__(self, output_hidden_states, max_steps, verbose=False): + self._init_output_hidden_states = output_hidden_states + self.output_hidden_states = output_hidden_states + self.verbose = verbose + self.max_steps = max_steps + + def step(self): + pass + + def get(self): + return self.output_hidden_states + +class RotationalEarlyExitCurriculum(EarlyExitCurriculum): + def __init__(self, output_hidden_states, max_steps, verbose=False): + super().__init__(output_hidden_states, max_steps, verbose) + + def step(self): + self.output_hidden_states = np.roll(self.output_hidden_states, -1) + if self.verbose: + print(f"Updating self.output_hidden_states to {self.output_hidden_states}.") + +class GradualEarlyExitCurriculum(EarlyExitCurriculum): + def __init__(self, output_hidden_states, max_steps, verbose=False): + super().__init__(output_hidden_states, max_steps, verbose) + self._step = 0 + + def step(self): + percent_trained = self._step / self.max_steps + n_layers = len(self.output_hidden_states) + for layer_index in range(len(self.output_hidden_states)): + # TODO: replace 2 with an argument + should_train = (percent_trained * 2) >= ((n_layers - 1 - layer_index) / (n_layers - 1)) + self.output_hidden_states[layer_index] = should_train + + # TODO: move this to step() in parent class? + # TODO: how to ensure we always call parent step() in derived class? + self._step += 1 + if self.verbose: + print(f"Updating self.output_hidden_states to {self.output_hidden_states}.") diff --git a/src/fairseq2/utils/scales.py b/src/fairseq2/utils/scales.py new file mode 100644 index 000000000..cad198ffb --- /dev/null +++ b/src/fairseq2/utils/scales.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Optional +import math + +def slice_str_to_array(slice_str: str, length: int): + """ + Converts a slice string to a boolean array where True indicates the index is included in the slice. + :param slice_str: + A string representing a slice. The format should be "start:end:step", where + each part is optional. Examples include "1:5", ":5", "::2". + :param length: + The length of the resulting boolean array. This is the total number of elements + in the array, and should be a non-negative integer. + :return: + A list of boolean values where each element is True if its index falls within the specified slice. + :raises ValueError: + If any part of `slice_str` is not convertible to an integer. 
+    Examples:
+        >>> slice_str_to_array("1:5", 10)
+        [False, True, True, True, True, False, False, False, False, False]
+        >>> slice_str_to_array("::2", 5)
+        [True, False, True, False, True]
+        >>> slice_str_to_array("3:", 5)
+        [False, False, False, True, True]
+    """
+    # Parse the slice string
+    parts = slice_str.split(':')
+    start, end, step = None, None, None
+
+    if len(parts) == 1 and parts[0] != '':
+        start = int(parts[0])
+    elif len(parts) == 2:
+        start = int(parts[0]) if parts[0] != '' else None
+        end = int(parts[1]) if parts[1] != '' else None
+    elif len(parts) == 3:
+        start = int(parts[0]) if parts[0] != '' else None
+        end = int(parts[1]) if parts[1] != '' else None
+        step = int(parts[2]) if parts[2] != '' else None
+
+    # Create a boolean array based on the slice
+    result = [False] * length
+    slice_indices = range(start if start is not None else 0,
+                          end if end is not None else length,
+                          step if step is not None else 1)
+
+    for i in slice_indices:
+        if 0 <= i < length:
+            result[i] = True
+
+    return result
+
+class ScaleType(str, Enum):
+    UNIFORM = "uniform"
+    EXP = "exp"
+    LINEAR = "linear"
+    LOG = "log"
+    SIN = "sin"
+    SIGMOID = "sigmoid"
+    STEP = "step"
+
+def get_scale(scale_type: ScaleType, scale_period: int, idx: int):
+    """
+    Calculates a scaling factor based on the specified scale type, scale period, and index.
+    :param scale_type:
+        A member of the :class:`ScaleType` enum that specifies the type of scaling to apply.
+    :param scale_period:
+        An integer representing the period over which the scaling is applied. This is used
+        as the denominator in scaling calculations to normalize `idx`.
+    :param idx:
+        An integer representing the current index for which the scaling factor is calculated.
+        This value should be within the range [0, scale_period].
+    :return:
+        A float representing the scaling factor. This factor is calculated based on the `scale_type`.
+        The scaling factor is designed to be 0 when `idx` is 0 and 1 when `idx` is `scale_period`,
+        except for `ScaleType.UNIFORM` where it is always 1. A `scale_period` of 0 also returns 1.
+    :raises KeyError:
+        If `scale_type` has no entry in the scale table (currently the case for `ScaleType.STEP`).
+    Examples:
+        >>> get_scale(ScaleType.LINEAR, 10, 5)
+        0.5
+        >>> get_scale(ScaleType.EXP, 10, 3)
+        0.231144
+        >>> get_scale(ScaleType.LOG, 10, 2)
+        0.458157
+        >>> get_scale(ScaleType.SIN, 10, 5)
+        0.707107
+        >>> get_scale(ScaleType.SIGMOID, 10, 5)
+        0.5
+    """
+    if scale_period == 0:
+        return 1
+
+    # all the equations below aim to make scale = 0 when idx=0, and scale = 1 when idx=scale_period
+    return {
+        ScaleType.UNIFORM: 1,
+        ScaleType.EXP: math.exp(idx * math.log(2) / scale_period) - 1,
+        ScaleType.LINEAR: idx / scale_period,
+        ScaleType.LOG: math.log(idx + 1) / math.log(scale_period + 1),
+        ScaleType.SIN: math.sin(0.5 * math.pi * idx / scale_period),
+        ScaleType.SIGMOID: 1 / (1 + math.exp(-10 * (idx / scale_period - 0.5))),
+    }[scale_type]
+
+def get_values(scale_type: ScaleType, scale_period: int, max_val: float = 0.0, slice_str: Optional[str] = None):
+    """
+    Generates a list of values scaled according to the specified scale type and period, optionally filtered by a slice string.
+    :param scale_type:
+        A member of the :class:`ScaleType` enum that specifies the type of scaling to apply.
+    :param scale_period:
+        An integer representing the period over which the scaling is applied. This is used
+        to determine the number of values in the result list.
+    :param max_val:
+        A float representing the maximum possible value in the result list. Defaults to 0.0.
+    :param slice_str:
+        An optional string representing a slice of indices to include in the scaling. If provided,
+        only indices that fall within the slice are scaled; others are set to 0.0. If None, all
+        indices are included. Defaults to None.
+    :return:
+        A list of floats where each element is a scaled value based on `scale_type`. The scaling
+        is applied only to indices specified by `slice_str`, and all values are guaranteed to be
+        between 0 and `max_val`.
+    :raises AssertionError:
+        If any calculated value is not within the range [0, `max_val`].
+    Examples:
+        >>> get_values(ScaleType.LINEAR, 5, 10)
+        [0.0, 2.5, 5.0, 7.5, 10.0]
+        >>> get_values(ScaleType.EXP, 5, 10, "1:3")
+        [0.0, 1.892071, 4.142136, 0.0, 0.0]
+        >>> get_values(ScaleType.LOG, 5, 10, "0:5:2")
+        [0.0, 0.0, 6.826062, 0.0, 10.0]
+    """
+    vals = []
+    has_val = slice_str_to_array(slice_str, scale_period) if slice_str else [True] * scale_period
+
+    for idx in range(scale_period):
+        val = max_val * get_scale(
+            scale_type=scale_type,
+            scale_period=scale_period - 1,
+            idx=idx,
+        ) if has_val[idx] else 0.0
+        assert 0.0 <= val <= max_val, f"val={val} should be between 0 and {max_val}"
+        vals.append(val)
+
+    return vals
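
The verification logic in `_SpeculativeSamplingSeq2SeqGeneratorOp.__call__` follows the usual speculative-sampling acceptance rule (its TODOs note that parts are still hard-coded): a drafted token with draft probability p and target probability q is kept with probability min(1, q / p), and the first rejected position is resampled from the renormalized residual max(q - p, 0). A minimal, self-contained sketch of that rule for a single sequence; the function name and tensor layout here are illustrative, not fairseq2 API:

import torch

def accept_draft_tokens(p_draft, q_target, draft_tokens):
    """p_draft, q_target: (K, V) draft/target probabilities for the K drafted
    positions; draft_tokens: (K,) drafted token ids. Returns the accepted token
    ids and, if a draft was rejected, the resampled replacement token."""
    for i, tok in enumerate(draft_tokens.tolist()):
        p = p_draft[i, tok]  # assumed > 0, since the draft model sampled this token
        q = q_target[i, tok]
        # q >= p: always accept; q < p: accept with probability q / p.
        if torch.rand(()) <= torch.clamp(q / p, max=1.0):
            continue
        # First rejection: resample from the renormalized residual max(q - p, 0).
        residual = torch.clamp(q_target[i] - p_draft[i], min=0.0)
        replacement = torch.multinomial(residual / residual.sum(), 1)
        return draft_tokens[:i].tolist(), int(replacement)
    # Every draft token was accepted; the caller samples the next token as usual.
    return draft_tokens.tolist(), None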
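
`ModuleList.drop_p` now accepts one LayerDrop probability per layer, and the new `fairseq2.utils.scales.get_values` helper produces exactly that kind of per-layer schedule. A usage sketch, assuming the two are meant to be combined this way (the layer count and the 0.2 ceiling are arbitrary):

import torch
from torch.nn import Linear

from fairseq2.nn.module_list import ModuleList
from fairseq2.utils.scales import ScaleType, get_values

# Increase LayerDrop linearly from 0.0 at the first layer to 0.2 at the last.
drop_schedule = get_values(ScaleType.LINEAR, scale_period=6, max_val=0.2)

layers = ModuleList([Linear(16, 16) for _ in range(6)], drop_p=drop_schedule)
layers.train()

x = torch.randn(4, 16)
for layer in layers.drop_iter():  # skips layer i with probability drop_p[i]
    x = layer(x)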
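
The early-exit loss in `fairseq2/utils/early_exit_loss.py` weights the auxiliary loss of each exit layer by a normalized scale. A short usage sketch of `layer_ids_to_loss_scales`; the exit layers and decoder depth are illustrative:

import torch
from fairseq2.utils.early_exit_loss import LossScaleType, layer_ids_to_loss_scales

# Auxiliary exits at layers 3, 7 and 11 of a 12-layer decoder. `e_scale` boosts
# every exit except the final layer before the scales are normalized to sum to 1.
layer_ids = torch.tensor([3.0, 7.0, 11.0])
scales = layer_ids_to_loss_scales(layer_ids, 12, LossScaleType.L, 1.0)

# With the depth-proportional "l" scale type this yields roughly
# [0.167, 0.333, 0.500]: deeper exits contribute more to the total loss.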
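
The curricula at the bottom of `early_exit_loss.py` control which exits participate in the loss as training progresses; the gradual variant enables exits from the deepest layer toward the shallowest. A brief sketch, assuming a four-exit setup and 1000 training steps (the masking comment shows the intended use, not an existing trainer hook):

from fairseq2.utils.early_exit_loss import (
    EarlyExitCurriculumType,
    build_early_exit_curriculum,
)

# Start with every auxiliary exit disabled and enable them gradually.
curriculum = build_early_exit_curriculum(
    EarlyExitCurriculumType.GRADUAL, [False, False, False, False], 1000
)

for step in range(1000):
    # ... run a training step, masking early-exit losses with curriculum.get() ...
    curriculum.step()

print(curriculum.get())  # all four exits are enabled by the end of training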