From 68e83eacf5bb369f19842993845dda63905718e8 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Mon, 8 Jul 2024 20:17:03 +0000 Subject: [PATCH 01/19] set probabilities to be a list --- src/fairseq2/nn/module_list.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/fairseq2/nn/module_list.py b/src/fairseq2/nn/module_list.py index ea0ff6a6b..edd1bc2bf 100644 --- a/src/fairseq2/nn/module_list.py +++ b/src/fairseq2/nn/module_list.py @@ -52,24 +52,24 @@ def __init__( """ super().__init__(modules) - self.drop_p = drop_p + self.drop_p = [drop_p] * len(modules) def drop_iter(self) -> Iterator[Module]: """Return an iterator that drops a random set of submodules.""" - if self.drop_p > 0.0 and self.training: + if any(drop_p > 0.0 for drop_p in self.drop_p) and self.training: prob_dist = torch.rand(len(self), device="cpu", dtype=torch.float32) else: prob_dist = None for idx, m in enumerate(super().__iter__()): - if prob_dist is None or prob_dist[idx] > self.drop_p: + if prob_dist is None or prob_dist[idx] > self.drop_p[idx]: yield m def extra_repr(self) -> str: """:meta private:""" s = super().extra_repr() - if self.drop_p > 0.0: + if any(drop_p > 0.0 for drop_p in self.drop_p): s = f"{s}, drop_p={self.drop_p}" return s From 671779d5dea0ca45e399b72153809bf40568bd4f Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Mon, 8 Jul 2024 22:36:48 +0000 Subject: [PATCH 02/19] use setter wrapper for drop_p --- src/fairseq2/nn/module_list.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/fairseq2/nn/module_list.py b/src/fairseq2/nn/module_list.py index edd1bc2bf..76672bd7e 100644 --- a/src/fairseq2/nn/module_list.py +++ b/src/fairseq2/nn/module_list.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterable, Iterator, Optional, final +from typing import Iterable, Iterator, List, Optional, Union, final import torch from torch.nn import Module @@ -39,10 +39,10 @@ class ModuleList(ModuleListBase): ... x = layer(x) """ - drop_p: float + _drop_p: List[float] def __init__( - self, modules: Optional[Iterable[Module]] = None, *, drop_p: float = 0.0 + self, modules: Optional[Iterable[Module]] = None, *, drop_p: Union[float, Iterable[float]] = 0.0 ) -> None: """ :param modules: @@ -52,7 +52,7 @@ def __init__( """ super().__init__(modules) - self.drop_p = [drop_p] * len(modules) + self.drop_p = drop_p def drop_iter(self) -> Iterator[Module]: """Return an iterator that drops a random set of submodules.""" @@ -73,3 +73,19 @@ def extra_repr(self) -> str: s = f"{s}, drop_p={self.drop_p}" return s + + @property + def drop_p(self): + """Get probability of dropping each layer.""" + return self._drop_p + + @drop_p.setter + def drop_p(self, drop_p: Union[float, Iterable[float]]): + """Set probability of dropping layers using either a single value or a list of values.""" + if isinstance(drop_p, Iterable): + assert len(drop_p) == len(self) + self._drop_p = drop_p + elif isinstance(drop_p, float): + self._drop_p = [drop_p] * len(self) + else: + raise ValueError(f"Unsupported type for drop rate {drop_p}. 
Expecting either float or list of floats.") From f3f3591a4672bed39027325c1a5c1373695cb0ef Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Tue, 9 Jul 2024 00:58:02 +0000 Subject: [PATCH 03/19] implement layer wise schedule for dropout --- src/fairseq2/utils/scales.py | 149 +++++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 src/fairseq2/utils/scales.py diff --git a/src/fairseq2/utils/scales.py b/src/fairseq2/utils/scales.py new file mode 100644 index 000000000..cad198ffb --- /dev/null +++ b/src/fairseq2/utils/scales.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Optional +import math + +def slice_str_to_array(slice_str: str, length: int): + """ + Converts a slice string to a boolean array where True indicates the index is included in the slice. + :param slice_str: + A string representing a slice. The format should be "start:end:step", where + each part is optional. Examples include "1:5", ":5", "::2". + :param length: + The length of the resulting boolean array. This is the total number of elements + in the array, and should be a non-negative integer. + :return: + A list of boolean values where each element is True if its index falls within the specified slice. + :raises ValueError: + If any part of `slice_str` is not convertible to an integer. + Examples: + >>> slice_str_to_array("1:5", 10) + [False, True, True, True, True, False, False, False, False, False] + >>> slice_str_to_array("::2", 5) + [True, False, True, False, True] + >>> slice_str_to_array("3:", 5) + [False, False, False, True, True] + """ + # Parse the slice string + parts = slice_str.split(':') + start, end, step = None, None, None + + if len(parts) == 1 and parts[0] != '': + start = int(parts[0]) + elif len(parts) == 2: + start = int(parts[0]) if parts[0] != '' else None + end = int(parts[1]) if parts[1] != '' else None + elif len(parts) == 3: + start = int(parts[0]) if parts[0] != '' else None + end = int(parts[1]) if parts[1] != '' else None + step = int(parts[2]) if parts[2] != '' else None + + # Create a boolean array based on the slice + result = [False] * length + slice_indices = range(start if start is not None else 0, + end if end is not None else length, + step if step is not None else 1) + + for i in slice_indices: + if 0 <= i < length: + result[i] = True + + return result + +class ScaleType(str, Enum): + UNIFORM = "uniform" + EXP = "exp" + LINEAR = "linear" + LOG = "log" + SIN = "sin" + SIGMOID = "sigmoid" + STEP = "step" + +def get_scale(scale_type: ScaleType, scale_period: int, idx: int): + """ + Calculates a scaling factor based on the specified scale type, scale period, and value. + :param scale_type: + A member of the :class:`ScaleType` enum that specifies the type of scaling to apply. + :param scale_period: + An integer representing the period over which the scaling is applied. This is used + as the denominator in scaling calculations to normalize the `val`. + :param idx: + An integer representing the current index for which the scaling factor is calculated. + This value should be within the range [0, scale_period]. + :return: + A float representing the scaling factor. This factor is calculated based on the `scale_type`. 
+ The scaling factor is designed to be 0 when `val` is 0 and 1 when `val` is `scale_period`, + except for `ScaleType.UNIFORM` where it is always 1. + :raises ValueError: + If `scale_period` is 0, as division by zero in scaling calculations is not allowed. + Examples: + >>> get_scale(ScaleType.LINEAR, 10, 5) + 0.5 + >>> get_scale(ScaleType.EXP, 10, 3) + 0.2362900883445226 + >>> get_scale(ScaleType.LOG, 10, 2) + 0.3562071871080222 + >>> get_scale(ScaleType.SIN, 10, 5) + 1.0 + >>> get_scale(ScaleType.SIGMOID, 10, 5) + 0.5 + """ + if scale_period == 0: + return 1 + + # all the equations below aim to make scale = 0 when val=0, and scale = 1 when val=scale_period + return { + ScaleType.UNIFORM: 1, + ScaleType.EXP: math.exp(idx * math.log(2) / scale_period) - 1, + ScaleType.LINEAR: idx / scale_period, + ScaleType.LOG: math.log(idx + 1) / math.log(scale_period + 1), + ScaleType.SIN: math.sin(0.5 * math.pi * idx / scale_period), + ScaleType.SIGMOID: 1 / (1 + math.exp(-10 * (idx / scale_period - 0.5))), + }[scale_type] + +def get_values(scale_type: ScaleType, scale_period: int, max_val: float= 0.0, slice_str: Optional[str] = None): + """ + Generates a list of values scaled according to the specified scale type and period, optionally filtered by a slice string. + :param scale_type: + A member of the :class:`ScaleType` enum that specifies the type of scaling to apply. + :param scale_period: + An integer representing the period over which the scaling is applied. This is used + to determine the number of values in the result list. + :param max_val: + A float representing the maximum possible value in the result list. Defaults to 0.0. + :param slice_str: + An optional string representing a slice of indices to include in the scaling. If provided, + only indices that fall within the slice are scaled; others are set to 0.0. If None, all + indices are included. Defaults to None. + :return: + A list of floats where each element is a scaled value based on `scale_type`. The scaling + is applied only to indices specified by `slice_str`, and all values are guaranteed to be + between 0 and `max_val`. + :raises AssertionError: + If any calculated value is not within the range [0, `max_val`]. 
+ Examples: + >>> get_values(ScaleType.LINEAR, 5, 10) + [0.0, 2.5, 5.0, 7.5, 10.0] + >>> get_values(ScaleType.EXP, 5, 10, "1:3") + [0.0, 2.371373705661655, 4.894348370484656, 0.0, 0.0] + >>> get_values(ScaleType.LOG, 5, 10, "0:5:2") + [0.0, 0.0, 5.0, 0.0, 10.0] + """ + vals = [] + has_val = slice_str_to_array(slice_str, scale_period) if slice_str else [True] * scale_period + + for idx in range(scale_period): + val = max_val * get_scale( + scale_type = scale_type, + scale_period = scale_period - 1, + idx = idx, + ) if has_val[idx] else 0.0 + assert val >= 0.0 and val <= max_val, f"val={val} should be between 0 and {max_val}" + vals.append(val) + + return vals From 1c77e33647f9176114ff317374287bb04c6c0427 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Thu, 11 Jul 2024 04:55:27 +0000 Subject: [PATCH 04/19] ensure layer drop condition works --- src/fairseq2/nn/transformer/decoder.py | 2 +- src/fairseq2/nn/transformer/encoder.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fairseq2/nn/transformer/decoder.py b/src/fairseq2/nn/transformer/decoder.py index 8677bfc3e..fe5c733e2 100644 --- a/src/fairseq2/nn/transformer/decoder.py +++ b/src/fairseq2/nn/transformer/decoder.py @@ -215,7 +215,7 @@ def forward( *, state_bag: Optional[IncrementalStateBag] = None, ) -> Tuple[Tensor, Optional[PaddingMask]]: - if self._layer_output_hooks and self.layers.drop_p > 0.0: + if self._layer_output_hooks and any(drop_p > 0.0 for drop_p in self.layers.drop_p): raise RuntimeError( "The layer output hooks cannot be run when LayerDrop is enabled." ) diff --git a/src/fairseq2/nn/transformer/encoder.py b/src/fairseq2/nn/transformer/encoder.py index 7814120d0..f0c174c1e 100644 --- a/src/fairseq2/nn/transformer/encoder.py +++ b/src/fairseq2/nn/transformer/encoder.py @@ -179,7 +179,7 @@ def __init__( def forward( self, seqs: Tensor, padding_mask: Optional[PaddingMask] ) -> Tuple[Tensor, Optional[PaddingMask]]: - if self._layer_output_hooks and self.layers.drop_p > 0.0: + if self._layer_output_hooks and any(drop_p > 0.0 for drop_p in self.layers.drop_p): raise RuntimeError( "The layer output hooks cannot be run when LayerDrop is enabled." ) From 19005962a133261d42446c7db388abbb6e107d3f Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Thu, 11 Jul 2024 04:56:16 +0000 Subject: [PATCH 05/19] start early_exit_loss.py --- src/fairseq2/utils/early_exit_loss.py | 156 ++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 src/fairseq2/utils/early_exit_loss.py diff --git a/src/fairseq2/utils/early_exit_loss.py b/src/fairseq2/utils/early_exit_loss.py new file mode 100644 index 000000000..483810f0b --- /dev/null +++ b/src/fairseq2/utils/early_exit_loss.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import copy +import numpy as np +import torch +from enum import Enum +from typing import Dict, List, Optional +from torch import Tensor + +from fairseq2.nn.padding import PaddingMask + +hidden_states_dict: Dict[int, Tensor] = {} + +def hook( + layer_idx: int, + layer_output: Tensor, + layer_padding_mask: Optional[PaddingMask], + num_layers: int, +) -> bool: + global hidden_states_dict + + # TODO: handle only doing early exit loss on specific layers + hidden_states_dict[layer_idx] = layer_output + + return True + + +class LossScaleType(str, Enum): + ONE = "one" + L = "l" + SUM_L = "sum_l" + INV_L = "inv_l" + SQRT_L = "sqrt_l" + INV_SQRT_L = "inv_sqrt_l" + +def early_exit_loss(model, hidden_states_dict, labels, loss_fn, e_scale: float=1.0, loss_scale_type=LossScaleType.SUM_L): + batch_loss_fn = copy.deepcopy(loss_fn) + batch_loss_fn.reduction = "none" + + e = len(hidden_states_dict) + # List of e tensors with shape [b, s, d] + hidden_states = tuple(hidden_states_dict.values()) + hidden_layer_ids = tuple(hidden_states_dict.keys()) + # Shape: [e, b, s, d] + hidden_states_stacked = torch.stack(hidden_states) + # Shape: [e, b, s, out_dim] + logits_early = model.output(model.norm(hidden_states_stacked)) + logits_early = logits_early[..., :-1, :].contiguous() + # Shape: [e*b, s, out_dim] + logits_early = logits_early.flatten(0, 1) + logits_early = logits_early.transpose(1, 2) + # Shape: [e, b*s] + labels_repeated = labels.repeat(e, 1) + # Compute early losses: Shape: [e*b, s] + losses_early = batch_loss_fn(logits_early, labels_repeated) + # Shape: [e, b*s] + losses_early = losses_early.view(e, -1) + # Shape: [e] + s_unpadded = (labels != loss_fn.ignore_index).sum() + losses_early = losses_early.float().sum(-1) / s_unpadded + # Shape: [e] + # losses_scales = 0.1 * torch.Tensor(hidden_layer_ids).to(losses_early) / len(model.layers) + losses_scales = layer_ids_to_loss_scales(torch.Tensor(hidden_layer_ids).to(losses_early), len(model.layers), loss_scale_type, e_scale) + + return torch.sum(losses_scales * losses_early) + +def layer_ids_to_loss_scales(layer_ids, n_layers, loss_scale_type: LossScaleType, e_scale: float): + match loss_scale_type: + case LossScaleType.ONE: + loss_scales = torch.ones(len(layer_ids)) + case LossScaleType.L: + loss_scales = torch.Tensor(layer_ids+1) + case LossScaleType.SUM_L: + # TODO: should we change to sum 0:i ? 
Perhaps create a new scale_type + loss_scales = torch.cumsum(layer_ids+1, dim=0) + case LossScaleType.SQRT_L: + loss_scales = torch.sqrt(layer_ids+1) + case LossScaleType.INV_L: + loss_scales = 1.0 / (layer_ids+1) + case LossScaleType.INV_SQRT_L: + loss_scales = 1.0 / torch.sqrt(layer_ids+1) + case _: + raise ValueError(f"Unsupported loss_scale type {loss_scale_type}") + + loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) + # normalize loss scales to ensure that their sum is 1.0 + loss_scales = loss_scales / torch.sum(loss_scales) + assert torch.isclose(torch.sum(loss_scales), torch.Tensor([1.0]).to(loss_scales)) + + return loss_scales + +class EarlyExitCurriculumType(str, Enum): + NONE = "none" + ROTATIONAL = "rot" + GRADUAL = "gradual" + +def build_early_exit_curriculum(early_exit_curriculum: EarlyExitCurriculumType, *args, **kwargs): + match early_exit_curriculum: + case EarlyExitCurriculumType.NONE: + return None + + case EarlyExitCurriculumType.ROTATIONAL: + return RotationalEarlyExitCurriculum(*args, **kwargs) + + case EarlyExitCurriculumType.GRADUAL: + return GradualEarlyExitCurriculum(*args, **kwargs) + + case _: + raise ValueError(f"Unsupported early loss curriculum {early_exit_curriculum}.") + + +# TODO: create a base curriculum class that can be used for other aspects, e.g., dropout, datasets, etc. +class EarlyExitCurriculum(): + def __init__(self, output_hidden_states, max_steps, verbose=False): + self._init_output_hidden_states = output_hidden_states + self.output_hidden_states = output_hidden_states + self.verbose = verbose + self.max_steps = max_steps + + def step(self): + pass + + def get(self): + return self.output_hidden_states + +class RotationalEarlyExitCurriculum(EarlyExitCurriculum): + def __init__(self, output_hidden_states, max_steps, verbose=False): + super().__init__(output_hidden_states, max_steps, verbose) + + def step(self): + self.output_hidden_states = np.roll(self.output_hidden_states, -1) + if self.verbose: + print(f"Updating self.output_hidden_states to {self.output_hidden_states}.") + +class GradualEarlyExitCurriculum(EarlyExitCurriculum): + def __init__(self, output_hidden_states, max_steps, verbose=False): + super().__init__(output_hidden_states, max_steps, verbose) + self._step = 0 + + def step(self): + percent_trained = self._step / self.max_steps + n_layers = len(self.output_hidden_states) + for layer_index in range(len(self.output_hidden_states)): + # TODO: replace 2 with an argument + should_train = (percent_trained * 2) >= ((n_layers - 1 - layer_index) / (n_layers - 1)) + self.output_hidden_states[layer_index] = should_train + + # TODO: move this to step() in parent class? + # TODO: how to ensure we always call parent step() in derived class? 
+ self._step += 1 + if self.verbose: + print(f"Updating self.output_hidden_states to {self.output_hidden_states}.") From 3d806c436145c294a0ede99c15c05ef3c209f4ec Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Tue, 16 Jul 2024 02:49:57 +0000 Subject: [PATCH 06/19] get early exit loss to work with seamless --- src/fairseq2/utils/early_exit_loss.py | 37 ++++++++------------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/src/fairseq2/utils/early_exit_loss.py b/src/fairseq2/utils/early_exit_loss.py index 483810f0b..9ec0bef49 100644 --- a/src/fairseq2/utils/early_exit_loss.py +++ b/src/fairseq2/utils/early_exit_loss.py @@ -37,34 +37,19 @@ class LossScaleType(str, Enum): SQRT_L = "sqrt_l" INV_SQRT_L = "inv_sqrt_l" -def early_exit_loss(model, hidden_states_dict, labels, loss_fn, e_scale: float=1.0, loss_scale_type=LossScaleType.SUM_L): - batch_loss_fn = copy.deepcopy(loss_fn) - batch_loss_fn.reduction = "none" - - e = len(hidden_states_dict) - # List of e tensors with shape [b, s, d] +def early_exit_loss(model, hidden_states_dict, batch, loss_fn, e_scale: float=1.0, loss_scale_type=LossScaleType.SUM_L): hidden_states = tuple(hidden_states_dict.values()) hidden_layer_ids = tuple(hidden_states_dict.keys()) - # Shape: [e, b, s, d] - hidden_states_stacked = torch.stack(hidden_states) - # Shape: [e, b, s, out_dim] - logits_early = model.output(model.norm(hidden_states_stacked)) - logits_early = logits_early[..., :-1, :].contiguous() - # Shape: [e*b, s, out_dim] - logits_early = logits_early.flatten(0, 1) - logits_early = logits_early.transpose(1, 2) - # Shape: [e, b*s] - labels_repeated = labels.repeat(e, 1) - # Compute early losses: Shape: [e*b, s] - losses_early = batch_loss_fn(logits_early, labels_repeated) - # Shape: [e, b*s] - losses_early = losses_early.view(e, -1) - # Shape: [e] - s_unpadded = (labels != loss_fn.ignore_index).sum() - losses_early = losses_early.float().sum(-1) / s_unpadded - # Shape: [e] - # losses_scales = 0.1 * torch.Tensor(hidden_layer_ids).to(losses_early) / len(model.layers) - losses_scales = layer_ids_to_loss_scales(torch.Tensor(hidden_layer_ids).to(losses_early), len(model.layers), loss_scale_type, e_scale) + + losses_early = [] + for hidden_state in hidden_states: + logits_early = model.module.model.final_proj(model.module.model.text_decoder.layer_norm(hidden_state)) + # FIXME: assuming that the other output is None, which is not always the case + loss_early = loss_fn(batch, *(logits_early, None)) + losses_early.append(loss_early) + + losses_early = torch.stack(losses_early, dim=0) + losses_scales = layer_ids_to_loss_scales(torch.Tensor(hidden_layer_ids).to(losses_early), len(model.module.model.text_decoder.layers), loss_scale_type, e_scale) return torch.sum(losses_scales * losses_early) From e138eae59ba2903a5f6a6b464562aef4f5bc678b Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Tue, 16 Jul 2024 04:42:58 +0000 Subject: [PATCH 07/19] fix exit hooks --- src/fairseq2/nn/transformer/decoder.py | 7 ++++++- src/fairseq2/nn/transformer/encoder.py | 5 +++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/fairseq2/nn/transformer/decoder.py b/src/fairseq2/nn/transformer/decoder.py index fe5c733e2..ba4baba3f 100644 --- a/src/fairseq2/nn/transformer/decoder.py +++ b/src/fairseq2/nn/transformer/decoder.py @@ -239,11 +239,16 @@ def forward( state_bag=state_bag, ) gpu_util.append(torch.cuda.utilization(torch.cuda.current_device())) - + + early_exit = False for hook in self._layer_output_hooks.values(): if not hook(layer_idx, seqs, 
padding_mask, num_layers): + early_exit = True break + if early_exit: + break + if self.layer_norm is not None: seqs = self.layer_norm(seqs) diff --git a/src/fairseq2/nn/transformer/encoder.py b/src/fairseq2/nn/transformer/encoder.py index f0c174c1e..d6fba2f8c 100644 --- a/src/fairseq2/nn/transformer/encoder.py +++ b/src/fairseq2/nn/transformer/encoder.py @@ -197,10 +197,15 @@ def forward( seqs, padding_mask = layer(seqs, padding_mask, self_attn_mask) gpu_util.append(torch.cuda.utilization(torch.cuda.current_device())) + early_exit = False for hook in self._layer_output_hooks.values(): if not hook(layer_idx, seqs, padding_mask, num_layers): + early_exit = True break + if early_exit: + break + if self.layer_norm is not None: seqs = self.layer_norm(seqs) From 3e8cd675ad98961ce998e4144e8b4c954b01e788 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Tue, 27 Aug 2024 18:04:37 +0000 Subject: [PATCH 08/19] add profiling code for SamplingGenerator --- src/fairseq2/generation/sampling.py | 54 ++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 13 deletions(-) diff --git a/src/fairseq2/generation/sampling.py b/src/fairseq2/generation/sampling.py index 2081e8a89..b3749fe4f 100644 --- a/src/fairseq2/generation/sampling.py +++ b/src/fairseq2/generation/sampling.py @@ -8,10 +8,12 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional, Sequence, Tuple, Union, final +import time import torch from torch import Tensor from torch.nn.functional import softmax +import numpy as np from fairseq2.data import VocabularyInfo from fairseq2.generation.generator import ( @@ -259,10 +261,18 @@ def __call__( prompt_seqs: Tensor, prompt_padding_mask: Optional[PaddingMask], ) -> Seq2SeqGeneratorOutput: + seq_len = dict() + timer_result = dict() + # (P, S) - encoder_output, encoder_padding_mask = self.model.encode( + torch.cuda.synchronize() + start_time = time.time() + (encoder_output, encoder_padding_mask, gpu_util), src_seq_len = self.model.encode( source_seqs, source_padding_mask ) + torch.cuda.synchronize() + timer_result["Encoder"] = (time.time()-start_time)*1000 + seq_len["Encoder"] = [src_seq_len, encoder_output.shape[1], 1] if source_padding_mask is None: max_source_len = source_seqs.size(1) @@ -285,6 +295,8 @@ def __call__( f"`min_gen_len` must be less than or equal to `max_gen_len` ({max_gen_len}), but is {self.min_gen_len} instead. Adjust your `max_gen_len` argument." 
) + torch.cuda.synchronize() + start_time = time.time() op = _SamplingSeq2SeqGeneratorOp( self.model, encoder_output, @@ -306,9 +318,13 @@ def __call__( self._step_hooks, ) - hypotheses = op() + hypotheses, decoding_step, gpu_util2 = op() + torch.cuda.synchronize() + timer_result["Decoder"] = (time.time()-start_time)*1000 - return Seq2SeqGeneratorOutput(hypotheses, encoder_output, encoder_padding_mask) + seq_len["Decoder"] = [op.min_prompt_len-1] + seq_len["Decoder"] += [np.average([h[0].seq.shape[0] for h in hypotheses]), decoding_step] + return Seq2SeqGeneratorOutput(hypotheses, encoder_output, encoder_padding_mask), timer_result, seq_len, np.average(gpu_util+gpu_util2) class Sampler(ABC): @@ -543,23 +559,31 @@ def __init__( self.output = [[] for _ in range(num_prompts)] def __call__(self) -> List[List[Hypothesis]]: - self._prepare_state() + gpu_utils = [] + gpu_util = self._prepare_state() + gpu_utils.append(gpu_util) + decoding_step = 0 for self.step_nr in range(self.min_prompt_len, self.max_seq_len): - if not self._step(): + output, gpu_util = self._step() + gpu_utils.append(gpu_util) + decoding_step+=1 + if not output: break + # breakpoint() if self.compute_scores: # Sort the hypotheses by their scores before returning. for hypotheses in self.output: hypotheses.sort(key=lambda h: h.score, reverse=True) # type: ignore[arg-type, return-value] - return self.output + return self.output, decoding_step, gpu_utils def _prepare_state(self) -> None: # Fast-forward to the first step that needs to be generated. + gpu_util = [] if self.min_prompt_len > 1: - self._prefill() + gpu_util = self._prefill() # Fan out the state to `num_prompts` x `num_gens`. if self.num_gens > 1: @@ -573,10 +597,12 @@ def _prepare_state(self) -> None: self._reorder_state(fan_out) + return gpu_util + def _prefill(self) -> None: prefill_len = self.min_prompt_len - model_output = self._decode(self.seqs[:, : prefill_len - 1]) + model_output, gpu_util = self._decode(self.seqs[:, : prefill_len - 1]) self.state_bag.increment_step_nr(prefill_len - 1) @@ -610,9 +636,11 @@ def _prefill(self) -> None: for hook in self.step_hooks.values(): hook(self.prompt_indices, seqs, step_scores, prefill=True) + return gpu_util + def _step(self) -> bool: # Generate the next step output. - model_output = self._decode(self.seqs[:, self.step_nr - 1 : self.step_nr]) + model_output, gpu_util = self._decode(self.seqs[:, self.step_nr - 1 : self.step_nr]) self.state_bag.increment_step_nr() @@ -713,13 +741,13 @@ def _step(self) -> bool: # No sequence left, we can return. if len(active_seq_indices) == 0: - return False + return False, gpu_util # Otherwise, remove the sequences that have reached EOS from the # state and continue generating the remaining ones. self._reorder_state(active_seq_indices) - return True + return True, gpu_util @abstractmethod def _decode(self, seqs: Tensor) -> SequenceModelOutput: @@ -894,7 +922,7 @@ def __init__( @override def _decode(self, seqs: Tensor) -> SequenceModelOutput: - decoder_output, decoder_padding_mask = self.model.decode( + decoder_output, decoder_padding_mask, gpu_util = self.model.decode( seqs, None, # We never use PAD in incremental decoding. 
self.encoder_output, @@ -902,7 +930,7 @@ def _decode(self, seqs: Tensor) -> SequenceModelOutput: state_bag=self.state_bag, ) - return self.model.project(decoder_output, decoder_padding_mask) + return self.model.project(decoder_output, decoder_padding_mask), np.average(gpu_util) @override def _reorder_state(self, new_order: Tensor) -> None: From baebda3941e76fe14819aecabefc8bcf88d5555e Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Tue, 27 Aug 2024 18:05:07 +0000 Subject: [PATCH 09/19] add other loss scaling types and make uniform default --- src/fairseq2/utils/early_exit_loss.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/fairseq2/utils/early_exit_loss.py b/src/fairseq2/utils/early_exit_loss.py index 9ec0bef49..17caad34b 100644 --- a/src/fairseq2/utils/early_exit_loss.py +++ b/src/fairseq2/utils/early_exit_loss.py @@ -32,12 +32,14 @@ def hook( class LossScaleType(str, Enum): ONE = "one" L = "l" + L2 = "l2" SUM_L = "sum_l" + SUM_L2 = "sum_l2" INV_L = "inv_l" SQRT_L = "sqrt_l" INV_SQRT_L = "inv_sqrt_l" -def early_exit_loss(model, hidden_states_dict, batch, loss_fn, e_scale: float=1.0, loss_scale_type=LossScaleType.SUM_L): +def early_exit_loss(model, hidden_states_dict, batch, loss_fn, e_scale: float=20.0, loss_scale_type=LossScaleType.ONE): hidden_states = tuple(hidden_states_dict.values()) hidden_layer_ids = tuple(hidden_states_dict.keys()) @@ -59,9 +61,14 @@ def layer_ids_to_loss_scales(layer_ids, n_layers, loss_scale_type: LossScaleType loss_scales = torch.ones(len(layer_ids)) case LossScaleType.L: loss_scales = torch.Tensor(layer_ids+1) + case LossScaleType.L2: + loss_scales = torch.Tensor((layer_ids+1)**2) case LossScaleType.SUM_L: # TODO: should we change to sum 0:i ? Perhaps create a new scale_type loss_scales = torch.cumsum(layer_ids+1, dim=0) + case LossScaleType.SUM_L2: + # TODO: should we change to sum 0:i ? 
Perhaps create a new scale_type + loss_scales = torch.cumsum((layer_ids+1)**2, dim=0) case LossScaleType.SQRT_L: loss_scales = torch.sqrt(layer_ids+1) case LossScaleType.INV_L: From dcaf66e5dbe8417f051f32c95c19e3ba99b34e6e Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Thu, 29 Aug 2024 18:10:30 +0000 Subject: [PATCH 10/19] start speculative decoding --- src/fairseq2/generation/__init__.py | 2 + src/fairseq2/generation/sampling.py | 229 +++++++++++++++++++++++++++- 2 files changed, 228 insertions(+), 3 deletions(-) diff --git a/src/fairseq2/generation/__init__.py b/src/fairseq2/generation/__init__.py index ec5a9faba..16b422343 100644 --- a/src/fairseq2/generation/__init__.py +++ b/src/fairseq2/generation/__init__.py @@ -30,9 +30,11 @@ from fairseq2.generation.sampling import Sampler as Sampler from fairseq2.generation.sampling import ( SamplingSeq2SeqGenerator as SamplingSeq2SeqGenerator, + SpeculativeSamplingSeq2SeqGenerator as SpeculativeSamplingSeq2SeqGenerator, ) from fairseq2.generation.sampling import ( SamplingSequenceGenerator as SamplingSequenceGenerator, + SpeculativeSamplingSequenceGenerator as SpeculativeSamplingSequenceGenerator, ) from fairseq2.generation.sampling import TopKSampler as TopKSampler from fairseq2.generation.sampling import TopPSampler as TopPSampler diff --git a/src/fairseq2/generation/sampling.py b/src/fairseq2/generation/sampling.py index b3749fe4f..a65067264 100644 --- a/src/fairseq2/generation/sampling.py +++ b/src/fairseq2/generation/sampling.py @@ -34,7 +34,6 @@ from fairseq2.typing import finaloverride, override -@final class SamplingSequenceGenerator(SequenceGenerator): """Represents a sequence generator based on sampling.""" @@ -162,7 +161,54 @@ def __call__( return SequenceGeneratorOutput(hypotheses) -@final +class SpeculativeSamplingSequenceGenerator(SamplingSequenceGenerator): + """Represents a speculative sequence generator based on sampling.""" + + def __init__( + self, + model: DecoderModel, + model_draft: DecoderModel, + *args, + **kwargs, + ) -> None: + """ + :param model: + The decoder model to use for generation. + :param model_draft: + The decoder model to use to generate draft tokens. + """ + super().__init__(model, *args, **kwargs) + self.model_draft = model_draft + + @torch.inference_mode() + def __call__( + self, prompt_seqs: Tensor, prompt_padding_mask: Optional[PaddingMask] + ) -> SequenceGeneratorOutput: + op = _SpeculativeSamplingSequenceGeneratorOp( + self.model, + self.model_draft, + prompt_seqs, + prompt_padding_mask, + self.sampler, + self.num_gens, + self.min_gen_len, + self.max_gen_len, + self.max_seq_len, + self.echo_prompt, + self.compute_scores, + self.normalize_scores, + self.temperature, + self.unk_penalty, + self.len_penalty, + self.step_processors, + self._step_hooks, + ) + + hypotheses = op() + + return SequenceGeneratorOutput(hypotheses) + + class SamplingSeq2SeqGenerator(Seq2SeqGenerator): """Represents a sequence-to-sequence generator based on sampling.""" @@ -327,6 +373,103 @@ def __call__( return Seq2SeqGeneratorOutput(hypotheses, encoder_output, encoder_padding_mask), timer_result, seq_len, np.average(gpu_util+gpu_util2) +class SpeculativeSamplingSeq2SeqGenerator(SamplingSeq2SeqGenerator): + """Represents a sequence-to-sequence generator based on sampling.""" + + model_draft: DecoderModel + + def __init__( + self, + model: EncoderDecoderModel, + model_draft: DecoderModel, + *args, + **kwargs, + ) -> None: + """ + :param model: + The encoder-decoder model to use for generation. 
+ :param model_draft: + The decoder draft model to generate draft tokens. + """ + super().__init__(model, *args, **kwargs) + self.model_draft = model_draft + + @finaloverride + @torch.inference_mode() + def __call__( + self, + source_seqs: Tensor, + source_padding_mask: Optional[PaddingMask], + prompt_seqs: Tensor, + prompt_padding_mask: Optional[PaddingMask], + ) -> Seq2SeqGeneratorOutput: + seq_len = dict() + timer_result = dict() + + # (P, S) + torch.cuda.synchronize() + start_time = time.time() + (encoder_output, encoder_padding_mask, gpu_util), src_seq_len = self.model.encode( + source_seqs, source_padding_mask + ) + torch.cuda.synchronize() + timer_result["Encoder"] = (time.time()-start_time)*1000 + seq_len["Encoder"] = [src_seq_len, encoder_output.shape[1], 1] + + if source_padding_mask is None: + max_source_len = source_seqs.size(1) + else: + max_source_len = int(source_padding_mask.seq_lens.max()) + + a_term, b_term = self.max_gen_len + + # In seq2seq generation, the maximum generation length is relative to + # the source sequence length. + max_gen_len = int(a_term * max_source_len + b_term) + + if max_gen_len < 1: + raise ValueError( + f"`max_gen_len` must be greater than or equal to 1, but is {max_gen_len} instead. Adjust your `max_gen_len` argument." + ) + + if self.min_gen_len > max_gen_len: + raise ValueError( + f"`min_gen_len` must be less than or equal to `max_gen_len` ({max_gen_len}), but is {self.min_gen_len} instead. Adjust your `max_gen_len` argument." + ) + + torch.cuda.synchronize() + start_time = time.time() + op = _SpeculativeSamplingSeq2SeqGeneratorOp( + self.model, + self.model_draft, + encoder_output, + encoder_padding_mask, + prompt_seqs, + prompt_padding_mask, + self.sampler, + self.num_gens, + self.min_gen_len, + max_gen_len, + self.max_seq_len, + self.echo_prompt, + self.compute_scores, + self.normalize_scores, + self.temperature, + self.unk_penalty, + self.len_penalty, + self.step_processors, + self._step_hooks, + ) + + hypotheses, decoding_step, gpu_util2 = op() + torch.cuda.synchronize() + timer_result["Decoder"] = (time.time()-start_time)*1000 + + seq_len["Decoder"] = [op.min_prompt_len-1] + seq_len["Decoder"] += [np.average([h[0].seq.shape[0] for h in hypotheses]), decoding_step] + return Seq2SeqGeneratorOutput(hypotheses, encoder_output, encoder_padding_mask), timer_result, seq_len, np.average(gpu_util+gpu_util2) + + class Sampler(ABC): """Represents a sampling algorithm.""" @@ -571,7 +714,6 @@ def __call__(self) -> List[List[Hypothesis]]: if not output: break - # breakpoint() if self.compute_scores: # Sort the hypotheses by their scores before returning. 
for hypotheses in self.output: @@ -870,6 +1012,62 @@ def _decode(self, seqs: Tensor) -> SequenceModelOutput: return self.model.project(decoder_output, decoder_padding_mask) +class _SpeculativeSamplingSequenceGeneratorOp(_SamplingSequenceGeneratorOpBase): + model: DecoderModel + model_draft: DecoderModel + + def __init__( + self, + model: DecoderModel, + model_draft: DecoderModel, + prompt_seqs: Tensor, + prompt_padding_mask: Optional[PaddingMask], + sampler: Sampler, + num_gens: int, + min_gen_len: int, + max_gen_len: int, + max_seq_len: int, + echo_prompt: bool, + compute_scores: bool, + normalize_scores: bool, + temperature: float, + unk_penalty: float, + len_penalty: float, + step_processors: Sequence[StepProcessor], + step_hooks: Dict[int, StepHook], + ) -> None: + super().__init__( + prompt_seqs, + prompt_padding_mask, + sampler, + model.vocab_info, + num_gens, + min_gen_len, + max_gen_len, + max_seq_len, + echo_prompt, + compute_scores, + normalize_scores, + temperature, + unk_penalty, + len_penalty, + step_processors, + step_hooks, + ) + + self.model = model + self.model_draft = model_draft + + def _decode_draft(self, seqs: Tensor) -> SequenceModelOutput: + decoder_output, decoder_padding_mask = self.model_draft.decode( + seqs, + None, # We never use PAD in incremental decoding. + state_bag=self.state_bag, + ) + + return self.model_draft.project(decoder_output, decoder_padding_mask) + + class _SamplingSeq2SeqGeneratorOp(_SamplingSequenceGeneratorOpBase): model: EncoderDecoderModel @@ -947,3 +1145,28 @@ def _reorder_state(self, new_order: Tensor) -> None: self.encoder_padding_mask = PaddingMask( encoder_seq_lens, batch_seq_len=self.encoder_output.size(1) ) + +class _SpeculativeSamplingSeq2SeqGeneratorOp(_SamplingSeq2SeqGeneratorOp): + model_draft: DecoderModel + + def __init__( + self, + model: EncoderDecoderModel, + model_draft: DecoderModel, + *args, + **kwargs, + ) -> None: + super().__init__(model, *args, **kwargs) + + self.model_draft = model_draft + + def _decode_draft(self, seqs: Tensor) -> SequenceModelOutput: + decoder_output, decoder_padding_mask, gpu_util = self.model_draft.decode( + seqs, + None, # We never use PAD in incremental decoding. + self.encoder_output, + self.encoder_padding_mask, + state_bag=self.state_bag, + ) + + return self.model_draft.project(decoder_output, decoder_padding_mask), np.average(gpu_util) From 9c4f52bc9653cffe79c8abc2732aa9bc36964738 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sat, 31 Aug 2024 03:19:20 +0000 Subject: [PATCH 11/19] end-to-end speculative decoding but not correct --- src/fairseq2/generation/sampling.py | 186 ++++++++++++++++++++++++++-- 1 file changed, 178 insertions(+), 8 deletions(-) diff --git a/src/fairseq2/generation/sampling.py b/src/fairseq2/generation/sampling.py index a65067264..863ac1e94 100644 --- a/src/fairseq2/generation/sampling.py +++ b/src/fairseq2/generation/sampling.py @@ -376,12 +376,13 @@ def __call__( class SpeculativeSamplingSeq2SeqGenerator(SamplingSeq2SeqGenerator): """Represents a sequence-to-sequence generator based on sampling.""" - model_draft: DecoderModel + model_draft: EncoderDecoderModel def __init__( self, model: EncoderDecoderModel, - model_draft: DecoderModel, + model_draft: EncoderDecoderModel, + k_speculate: int, *args, **kwargs, ) -> None: @@ -389,10 +390,11 @@ def __init__( :param model: The encoder-decoder model to use for generation. :param model_draft: - The decoder draft model to generate draft tokens. + The encoder-decoder draft model to generate draft tokens. 
""" super().__init__(model, *args, **kwargs) self.model_draft = model_draft + self.k_speculate = k_speculate @finaloverride @torch.inference_mode() @@ -442,6 +444,7 @@ def __call__( op = _SpeculativeSamplingSeq2SeqGeneratorOp( self.model, self.model_draft, + self.k_speculate, encoder_output, encoder_padding_mask, prompt_seqs, @@ -1058,6 +1061,26 @@ def __init__( self.model = model self.model_draft = model_draft + def __call__(self) -> List[List[Hypothesis]]: + gpu_utils = [] + gpu_util = self._prepare_state() + gpu_utils.append(gpu_util) + + decoding_step = 0 + for self.step_nr in range(self.min_prompt_len, self.max_seq_len): + output, gpu_util = self._step() + gpu_utils.append(gpu_util) + decoding_step+=1 + if not output: + break + + if self.compute_scores: + # Sort the hypotheses by their scores before returning. + for hypotheses in self.output: + hypotheses.sort(key=lambda h: h.score, reverse=True) # type: ignore[arg-type, return-value] + + return self.output, decoding_step, gpu_utils + def _decode_draft(self, seqs: Tensor) -> SequenceModelOutput: decoder_output, decoder_padding_mask = self.model_draft.decode( seqs, @@ -1147,18 +1170,52 @@ def _reorder_state(self, new_order: Tensor) -> None: ) class _SpeculativeSamplingSeq2SeqGeneratorOp(_SamplingSeq2SeqGeneratorOp): - model_draft: DecoderModel + model_draft: EncoderDecoderModel + k_speculate: int + state_bag_draft: IncrementalStateBag def __init__( self, model: EncoderDecoderModel, - model_draft: DecoderModel, + model_draft: EncoderDecoderModel, + k_speculate: int, + encoder_output: Tensor, + encoder_padding_mask: Optional[PaddingMask], *args, **kwargs, ) -> None: - super().__init__(model, *args, **kwargs) - + _SamplingSeq2SeqGeneratorOp.__init__(self, model, encoder_output, encoder_padding_mask, *args, **kwargs) self.model_draft = model_draft + self.k_speculate = k_speculate + self.state_bag_draft = IncrementalStateBag(self.max_seq_len) + self.step_nr_draft = None + + def __call__(self) -> List[List[Hypothesis]]: + gpu_utils = [] + gpu_util = self._prepare_state() + gpu_utils.append(gpu_util) + + decoding_step = 0 + self.step_nr = self.min_prompt_len + while self.step_nr < self.max_seq_len: + for self.step_nr_draft in range(self.step_nr, min(self.step_nr + self.k_speculate, self.max_seq_len)): + output, gpu_util = self._step_draft() + gpu_utils.append(gpu_util) + decoding_step+=1 + if not output: + break + + num_draft_tokens = self.step_nr_draft - self.step_nr + 1 + # TODO: change that + num_accepted_tokens = num_draft_tokens + self.step_nr += num_accepted_tokens + + if self.compute_scores: + # Sort the hypotheses by their scores before returning. + for hypotheses in self.output: + hypotheses.sort(key=lambda h: h.score, reverse=True) # type: ignore[arg-type, return-value] + + return self.output, decoding_step, gpu_utils def _decode_draft(self, seqs: Tensor) -> SequenceModelOutput: decoder_output, decoder_padding_mask, gpu_util = self.model_draft.decode( @@ -1166,7 +1223,120 @@ def _decode_draft(self, seqs: Tensor) -> SequenceModelOutput: None, # We never use PAD in incremental decoding. self.encoder_output, self.encoder_padding_mask, - state_bag=self.state_bag, + state_bag=self.state_bag_draft, ) return self.model_draft.project(decoder_output, decoder_padding_mask), np.average(gpu_util) + + def _step_draft(self) -> bool: + # Generate the next step output. 
+ model_output, gpu_util = self._decode_draft(self.seqs[:, self.step_nr_draft - 1 : self.step_nr_draft]) + + self.state_bag_draft.increment_step_nr() + + logits = model_output.logits + + if self.temperature != 1.0: + logits /= self.temperature + + # (N, 1, V) + probs = softmax(logits, dim=-1, dtype=torch.float32) + + # (N, 1, V) -> (N, V) + probs.squeeze_(1) + + # If we are generating the last possible step, force it to be EOS + # regardless of its score. + if self.step_nr_draft == self.max_seq_len - 1: + batch_size = self.seqs.size(0) + + # (N) + vocab_indices = self.seqs.new_full((batch_size,), self.eos_idx) + else: + # Process `probs` in-place if requested. + for processor in self.step_processors: + processor(self.seqs[:, : self.step_nr_draft], probs) + + # Apply UNK penalty. + if self.unk_idx is not None: + probs[:, self.unk_idx] -= self.unk_penalty + + # Never allow PAD. + if self.pad_idx is not None: + probs[:, self.pad_idx] = 0 + + # Do not allow EOS till we reach the minimum sequence length. + if self.step_nr_draft < self.min_seq_len - 1: + probs[:, self.eos_idx] = 0 + + # (N) + vocab_indices = self.sampler(probs) + + # EOS mask of the current step. + # (N) + eos_mask = vocab_indices == self.eos_idx + + # Ignore the generated indices for the prompt sequences. + if self.step_nr_draft < self.max_prompt_len: + assert self.prompt_mask is not None + + # (N) + mask = self.prompt_mask[:, self.step_nr_draft] + + # Override the generated indices. + vocab_indices[mask] = self.seqs[mask, self.step_nr_draft] + + # Ignore EOS in the prompt sequences. + eos_mask[mask] = False + else: + self.prompt_mask = None # Not needed anymore, release. + + # Record the current step. + self.seqs[:, self.step_nr_draft] = vocab_indices + + if self.step_scores is not None: + # (N, 1) + scores = torch.gather(probs, dim=-1, index=vocab_indices[:, None]) + + # Record the scores of the current step. + self.step_scores[:, self.step_nr_draft] = scores.squeeze(1) + + if self.step_hooks: + seqs = self.seqs[:, : self.step_nr_draft + 1] + + if self.step_scores is None: + step_scores = None + else: + step_scores = self.step_scores[:, : self.step_nr_draft + 1] + + for hook in self.step_hooks.values(): + hook(self.prompt_indices, seqs, step_scores, prefill=False) + + # Retrieve the indices of the sequences that have reached EOS. + # (F, 1) + eos_seq_indices = eos_mask.nonzero() + + # If one or more sequences have reached EOS, move them to the output and + # continue generating the remaining sequences. + if len(eos_seq_indices) > 0: + # Move the sequences that have reached EOS to the output. + for seq_idx in eos_seq_indices: + self._finish_sequence(int(seq_idx)) + + # (N) + active_seq_mask = ~eos_mask + + # (N - F, 1) -> (N - F) + active_seq_indices = active_seq_mask.nonzero().squeeze(-1) + + # No sequence left, we can return. + if len(active_seq_indices) == 0: + return False, gpu_util + + # Otherwise, remove the sequences that have reached EOS from the + # state and continue generating the remaining ones. + ## TODO: move that on main model only? 
+ self._reorder_state(active_seq_indices) + self.state_bag_draft.reorder(active_seq_indices) + + return True, gpu_util From 3721aca4356d012bcc2e15e98e2287e2f401aef6 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sat, 31 Aug 2024 12:28:29 +0000 Subject: [PATCH 12/19] create functions for prefill for draft model --- src/fairseq2/generation/sampling.py | 159 +++++++++++++++++++++++++--- 1 file changed, 144 insertions(+), 15 deletions(-) diff --git a/src/fairseq2/generation/sampling.py b/src/fairseq2/generation/sampling.py index 863ac1e94..6061df858 100644 --- a/src/fairseq2/generation/sampling.py +++ b/src/fairseq2/generation/sampling.py @@ -1064,6 +1064,7 @@ def __init__( def __call__(self) -> List[List[Hypothesis]]: gpu_utils = [] gpu_util = self._prepare_state() + gpu_util = self._prepare_state_draft() gpu_utils.append(gpu_util) decoding_step = 0 @@ -1081,6 +1082,65 @@ def __call__(self) -> List[List[Hypothesis]]: return self.output, decoding_step, gpu_utils + def _prepare_state_draft(self) -> None: + # Fast-forward to the first step that needs to be generated. + gpu_util = [] + if self.min_prompt_len > 1: + gpu_util = self._prefill_draft() + + # Fan out the state to `num_prompts` x `num_gens`. + if self.num_gens > 1: + num_prompts = self.seqs.size(0) + + # (P) + fan_out = torch.arange(num_prompts, device=self.seqs.device) + + # (P) -> (P x G) + fan_out = repeat_interleave(fan_out, dim=0, repeat=self.num_gens) + + self._reorder_state(fan_out) + + return gpu_util + + def _prefill_draft(self) -> None: + prefill_len = self.min_prompt_len + + model_output, gpu_util = self._decode_draft(self.seqs[:, : prefill_len - 1]) + + self.state_bag_draft.increment_step_nr(prefill_len - 1) + + if self.step_scores is not None: + logits = model_output.logits + + if self.temperature != 1.0: + logits /= self.temperature + + # (P, S_prm - 1, V) + probs = softmax(logits, dim=-1, dtype=torch.float32) + + # Fetch the scores of the next prompt step. + # (P, S_prm - 1, 1) + prompt_scores = torch.gather( + probs, dim=-1, index=self.seqs[:, 1:prefill_len].unsqueeze(-1) + ) + + # Bootstrap the step scores. 
+ # (P, S_prm - 1) + self.step_scores[:, 1:prefill_len] = prompt_scores.squeeze(-1) + + if self.step_hooks: + seqs = self.seqs[:, :prefill_len] + + if self.step_scores is None: + step_scores = None + else: + step_scores = self.step_scores[:, :prefill_len] + + for hook in self.step_hooks.values(): + hook(self.prompt_indices, seqs, step_scores, prefill=True) + + return gpu_util + def _decode_draft(self, seqs: Tensor) -> SequenceModelOutput: decoder_output, decoder_padding_mask = self.model_draft.decode( seqs, @@ -1193,22 +1253,32 @@ def __init__( def __call__(self) -> List[List[Hypothesis]]: gpu_utils = [] gpu_util = self._prepare_state() + gpu_util = self._prepare_state_draft() gpu_utils.append(gpu_util) decoding_step = 0 - self.step_nr = self.min_prompt_len - while self.step_nr < self.max_seq_len: - for self.step_nr_draft in range(self.step_nr, min(self.step_nr + self.k_speculate, self.max_seq_len)): - output, gpu_util = self._step_draft() - gpu_utils.append(gpu_util) - decoding_step+=1 - if not output: - break - - num_draft_tokens = self.step_nr_draft - self.step_nr + 1 - # TODO: change that - num_accepted_tokens = num_draft_tokens - self.step_nr += num_accepted_tokens + # self.step_nr = self.min_prompt_len + # while self.step_nr < self.max_seq_len: + # for self.step_nr_draft in range(self.step_nr, min(self.step_nr + self.k_speculate, self.max_seq_len)): + # output, gpu_util = self._step_draft() + # gpu_utils.append(gpu_util) + # decoding_step+=1 + # if not output: + # break + + # num_draft_tokens = self.step_nr_draft - self.step_nr + 1 + # # TODO: change that + # num_accepted_tokens = num_draft_tokens + # self.step_nr += num_accepted_tokens + + # for testing + for self.step_nr in range(self.min_prompt_len, self.max_seq_len): + self.step_nr_draft = self.step_nr + output, gpu_util = self._step_draft() + gpu_utils.append(gpu_util) + decoding_step+=1 + if not output: + break if self.compute_scores: # Sort the hypotheses by their scores before returning. @@ -1217,6 +1287,65 @@ def __call__(self) -> List[List[Hypothesis]]: return self.output, decoding_step, gpu_utils + def _prepare_state_draft(self) -> None: + # Fast-forward to the first step that needs to be generated. + gpu_util = [] + if self.min_prompt_len > 1: + gpu_util = self._prefill_draft() + + # Fan out the state to `num_prompts` x `num_gens`. + if self.num_gens > 1: + num_prompts = self.seqs.size(0) + + # (P) + fan_out = torch.arange(num_prompts, device=self.seqs.device) + + # (P) -> (P x G) + fan_out = repeat_interleave(fan_out, dim=0, repeat=self.num_gens) + + self._reorder_state(fan_out) + + return gpu_util + + def _prefill_draft(self) -> None: + prefill_len = self.min_prompt_len + + model_output, gpu_util = self._decode_draft(self.seqs[:, : prefill_len - 1]) + + self.state_bag_draft.increment_step_nr(prefill_len - 1) + + if self.step_scores is not None: + logits = model_output.logits + + if self.temperature != 1.0: + logits /= self.temperature + + # (P, S_prm - 1, V) + probs = softmax(logits, dim=-1, dtype=torch.float32) + + # Fetch the scores of the next prompt step. + # (P, S_prm - 1, 1) + prompt_scores = torch.gather( + probs, dim=-1, index=self.seqs[:, 1:prefill_len].unsqueeze(-1) + ) + + # Bootstrap the step scores. 
+ # (P, S_prm - 1) + self.step_scores[:, 1:prefill_len] = prompt_scores.squeeze(-1) + + if self.step_hooks: + seqs = self.seqs[:, :prefill_len] + + if self.step_scores is None: + step_scores = None + else: + step_scores = self.step_scores[:, :prefill_len] + + for hook in self.step_hooks.values(): + hook(self.prompt_indices, seqs, step_scores, prefill=True) + + return gpu_util + def _decode_draft(self, seqs: Tensor) -> SequenceModelOutput: decoder_output, decoder_padding_mask, gpu_util = self.model_draft.decode( seqs, @@ -1336,7 +1465,7 @@ def _step_draft(self) -> bool: # Otherwise, remove the sequences that have reached EOS from the # state and continue generating the remaining ones. ## TODO: move that on main model only? - self._reorder_state(active_seq_indices) - self.state_bag_draft.reorder(active_seq_indices) + # self._reorder_state(active_seq_indices) + # self.state_bag_draft.reorder(active_seq_indices) return True, gpu_util From 5a2e0215ddbe45d71a314e8a36449620e44d425b Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sat, 31 Aug 2024 12:38:00 +0000 Subject: [PATCH 13/19] cleanup and get draft model to actually decode --- src/fairseq2/generation/sampling.py | 35 +++++++++++------------------ 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/src/fairseq2/generation/sampling.py b/src/fairseq2/generation/sampling.py index 6061df858..55bd88b6d 100644 --- a/src/fairseq2/generation/sampling.py +++ b/src/fairseq2/generation/sampling.py @@ -1257,28 +1257,19 @@ def __call__(self) -> List[List[Hypothesis]]: gpu_utils.append(gpu_util) decoding_step = 0 - # self.step_nr = self.min_prompt_len - # while self.step_nr < self.max_seq_len: - # for self.step_nr_draft in range(self.step_nr, min(self.step_nr + self.k_speculate, self.max_seq_len)): - # output, gpu_util = self._step_draft() - # gpu_utils.append(gpu_util) - # decoding_step+=1 - # if not output: - # break - - # num_draft_tokens = self.step_nr_draft - self.step_nr + 1 - # # TODO: change that - # num_accepted_tokens = num_draft_tokens - # self.step_nr += num_accepted_tokens - - # for testing - for self.step_nr in range(self.min_prompt_len, self.max_seq_len): - self.step_nr_draft = self.step_nr - output, gpu_util = self._step_draft() - gpu_utils.append(gpu_util) - decoding_step+=1 - if not output: - break + self.step_nr = self.min_prompt_len + while self.step_nr < self.max_seq_len: + for self.step_nr_draft in range(self.step_nr, min(self.step_nr + self.k_speculate, self.max_seq_len)): + output, gpu_util = self._step_draft() + gpu_utils.append(gpu_util) + decoding_step+=1 + if not output: + break + + num_draft_tokens = self.step_nr_draft - self.step_nr + 1 + # TODO: change that + num_accepted_tokens = num_draft_tokens + self.step_nr += num_accepted_tokens if self.compute_scores: # Sort the hypotheses by their scores before returning. 
From 8aa1a91e642c883086d75ac1e4131407b9fe687e Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sat, 31 Aug 2024 19:30:30 +0000 Subject: [PATCH 14/19] verify k_speculate tokens --- src/fairseq2/generation/sampling.py | 43 ++++++++++++++++++- .../nn/transformer/multihead_attention.py | 8 ++-- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/src/fairseq2/generation/sampling.py b/src/fairseq2/generation/sampling.py index 55bd88b6d..bf62613e5 100644 --- a/src/fairseq2/generation/sampling.py +++ b/src/fairseq2/generation/sampling.py @@ -1259,6 +1259,7 @@ def __call__(self) -> List[List[Hypothesis]]: decoding_step = 0 self.step_nr = self.min_prompt_len while self.step_nr < self.max_seq_len: + # Run draft model to generate draft tokens for self.step_nr_draft in range(self.step_nr, min(self.step_nr + self.k_speculate, self.max_seq_len)): output, gpu_util = self._step_draft() gpu_utils.append(gpu_util) @@ -1267,10 +1268,20 @@ def __call__(self) -> List[List[Hypothesis]]: break num_draft_tokens = self.step_nr_draft - self.step_nr + 1 - # TODO: change that + + # Run draft tokens by main model + # TODO: receive probabilities or tokens from main + self._verify_main() + + # Count how many draft tokens are accepted by main model + # TODO: run verification rather than hard coding number of accepted tokens num_accepted_tokens = num_draft_tokens self.step_nr += num_accepted_tokens + # TODO: move post decoding processing here (e.g., self.step_scores, hooks, etc.) + self.state_bag.increment_step_nr(num_accepted_tokens) + self.state_bag_draft.increment_step_nr(- (num_draft_tokens - num_accepted_tokens)) + if self.compute_scores: # Sort the hypotheses by their scores before returning. for hypotheses in self.output: @@ -1298,6 +1309,32 @@ def _prepare_state_draft(self) -> None: return gpu_util + def _verify_main(self) -> None: + model_output, gpu_util = self._decode(self.seqs[:, self.step_nr - 1 : self.step_nr_draft]) + + # TODO: move to post-acceptance? + if self.step_scores is not None: + logits = model_output.logits + + if self.temperature != 1.0: + logits /= self.temperature + + # TODO: return probs regardless of self.step_scores? + # (P, S_prm - 1, V) + probs = softmax(logits, dim=-1, dtype=torch.float32) + + # Fetch the scores of the next prompt step. + # (P, S_prm - 1, 1) + prompt_scores = torch.gather( + probs, dim=-1, index=self.seqs[:, self.step_nr+1: self.step_nr_draft+1].unsqueeze(-1) + ) + + # Bootstrap the step scores. + # (P, S_prm - 1) + self.step_scores[:, self.step_nr+1: self.step_nr_draft+1] = prompt_scores.squeeze(-1) + + return gpu_util + def _prefill_draft(self) -> None: prefill_len = self.min_prompt_len @@ -1414,6 +1451,7 @@ def _step_draft(self) -> bool: # Record the current step. self.seqs[:, self.step_nr_draft] = vocab_indices + # TODO: move to apply on accepted tokens? if self.step_scores is not None: # (N, 1) scores = torch.gather(probs, dim=-1, index=vocab_indices[:, None]) @@ -1421,6 +1459,7 @@ def _step_draft(self) -> bool: # Record the scores of the current step. self.step_scores[:, self.step_nr_draft] = scores.squeeze(1) + # TODO: move to apply on accepted tokens? if self.step_hooks: seqs = self.seqs[:, : self.step_nr_draft + 1] @@ -1432,10 +1471,12 @@ def _step_draft(self) -> bool: for hook in self.step_hooks.values(): hook(self.prompt_indices, seqs, step_scores, prefill=False) + # TODO: move to apply on accepted tokens? # Retrieve the indices of the sequences that have reached EOS. 
# (F, 1) eos_seq_indices = eos_mask.nonzero() + # TODO: move to apply on accepted tokens? # If one or more sequences have reached EOS, move them to the output and # continue generating the remaining sequences. if len(eos_seq_indices) > 0: diff --git a/src/fairseq2/nn/transformer/multihead_attention.py b/src/fairseq2/nn/transformer/multihead_attention.py index 43e2ac878..e02924490 100644 --- a/src/fairseq2/nn/transformer/multihead_attention.py +++ b/src/fairseq2/nn/transformer/multihead_attention.py @@ -632,11 +632,13 @@ def __init__(self, k: Tensor, v: Tensor, max_seq_len: int) -> None: @finaloverride def append(self, k: Tensor, v: Tensor) -> None: pos = self.seq_len + assert k.shape[2] == v.shape[2] + append_len = k.shape[2] - self.k[:, :, pos : pos + 1] = k - self.v[:, :, pos : pos + 1] = v + self.k[:, :, pos : pos + append_len] = k + self.v[:, :, pos : pos + append_len] = v - self.seq_len += 1 + self.seq_len += append_len @finaloverride def get(self) -> Tuple[Tensor, Tensor]: From 5480af5a5bad3b401ba61097c37684f25a5106d7 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 1 Sep 2024 12:54:21 +0000 Subject: [PATCH 15/19] remove scores calculation for draft prefill --- src/fairseq2/generation/sampling.py | 30 ----------------------------- 1 file changed, 30 deletions(-) diff --git a/src/fairseq2/generation/sampling.py b/src/fairseq2/generation/sampling.py index bf62613e5..2bdabeb4c 100644 --- a/src/fairseq2/generation/sampling.py +++ b/src/fairseq2/generation/sampling.py @@ -1342,36 +1342,6 @@ def _prefill_draft(self) -> None: self.state_bag_draft.increment_step_nr(prefill_len - 1) - if self.step_scores is not None: - logits = model_output.logits - - if self.temperature != 1.0: - logits /= self.temperature - - # (P, S_prm - 1, V) - probs = softmax(logits, dim=-1, dtype=torch.float32) - - # Fetch the scores of the next prompt step. - # (P, S_prm - 1, 1) - prompt_scores = torch.gather( - probs, dim=-1, index=self.seqs[:, 1:prefill_len].unsqueeze(-1) - ) - - # Bootstrap the step scores. 
- # (P, S_prm - 1) - self.step_scores[:, 1:prefill_len] = prompt_scores.squeeze(-1) - - if self.step_hooks: - seqs = self.seqs[:, :prefill_len] - - if self.step_scores is None: - step_scores = None - else: - step_scores = self.step_scores[:, :prefill_len] - - for hook in self.step_hooks.values(): - hook(self.prompt_indices, seqs, step_scores, prefill=True) - return gpu_util def _decode_draft(self, seqs: Tensor) -> SequenceModelOutput: From 8331df283e97851a79ed6a740555207f09e43720 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 1 Sep 2024 14:33:03 +0000 Subject: [PATCH 16/19] make verification step return vocab --- src/fairseq2/generation/sampling.py | 82 ++++++++++++++++++++++------- 1 file changed, 62 insertions(+), 20 deletions(-) diff --git a/src/fairseq2/generation/sampling.py b/src/fairseq2/generation/sampling.py index 2bdabeb4c..63c119a25 100644 --- a/src/fairseq2/generation/sampling.py +++ b/src/fairseq2/generation/sampling.py @@ -1271,7 +1271,8 @@ def __call__(self) -> List[List[Hypothesis]]: # Run draft tokens by main model # TODO: receive probabilities or tokens from main - self._verify_main() + probs, vocab_indices, gpu_util = self._verify_main() + gpu_utils.append(gpu_util) # Count how many draft tokens are accepted by main model # TODO: run verification rather than hard coding number of accepted tokens @@ -1312,28 +1313,68 @@ def _prepare_state_draft(self) -> None: def _verify_main(self) -> None: model_output, gpu_util = self._decode(self.seqs[:, self.step_nr - 1 : self.step_nr_draft]) - # TODO: move to post-acceptance? - if self.step_scores is not None: - logits = model_output.logits + logits = model_output.logits - if self.temperature != 1.0: - logits /= self.temperature + if self.temperature != 1.0: + logits /= self.temperature - # TODO: return probs regardless of self.step_scores? - # (P, S_prm - 1, V) - probs = softmax(logits, dim=-1, dtype=torch.float32) + # TODO: return probs regardless of self.step_scores? + # (P, S_prm - 1, V) + probs = softmax(logits, dim=-1, dtype=torch.float32) - # Fetch the scores of the next prompt step. - # (P, S_prm - 1, 1) - prompt_scores = torch.gather( - probs, dim=-1, index=self.seqs[:, self.step_nr+1: self.step_nr_draft+1].unsqueeze(-1) - ) + # Fetch the scores of the next prompt step. + # (P, S_prm - 1, 1) + prompt_scores = torch.gather( + probs, dim=-1, index=self.seqs[:, self.step_nr+1: self.step_nr_draft+1].unsqueeze(-1) + ) - # Bootstrap the step scores. - # (P, S_prm - 1) - self.step_scores[:, self.step_nr+1: self.step_nr_draft+1] = prompt_scores.squeeze(-1) + # Copied from self._draft_step() + # If we are generating the last possible step, force it to be EOS + # regardless of its score. + # TODO: move this to after verification? + # Process `probs` in-place if requested. + for processor in self.step_processors: + processor(self.seqs[:, : self.step_nr_draft], probs) - return gpu_util + # Apply UNK penalty. + if self.unk_idx is not None: + probs[:, self.unk_idx] -= self.unk_penalty + + # Never allow PAD. + if self.pad_idx is not None: + probs[:, self.pad_idx] = 0 + + # Do not allow EOS till we reach the minimum sequence length. + if self.step_nr_draft < self.min_seq_len - 1: + probs[:, self.eos_idx] = 0 + + # (N) + vocab_indices = self.sampler(probs) + + # EOS mask of the current step. + # (N) + eos_mask = vocab_indices == self.eos_idx + + # Ignore the generated indices for the prompt sequences. 
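        # Note: this block is copied from the single-step draft path, where
        # probs is (N, V) and one position is masked at a time; here probs and
        # vocab_indices span the whole draft window, so the penalty, sampling,
        # and prompt-mask indexing above and below likely need an explicit
        # per-position vocab dimension (e.g. probs[:, :, self.unk_idx]).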
+ if self.step_nr_draft < self.max_prompt_len: + assert self.prompt_mask is not None + + # (N) + mask = self.prompt_mask[:, self.step_nr_draft] + + # Override the generated indices. + vocab_indices[mask] = self.seqs[mask, self.step_nr_draft] + + # Ignore EOS in the prompt sequences. + eos_mask[mask] = False + else: + self.prompt_mask = None # Not needed anymore, release. + + # TODO: commenting for now. Move to accept step. + ## Record the current step. + # self.seqs[:, self.step_nr - 1 : self.step_nr_draft] = vocab_indices + + return probs, vocab_indices, gpu_util def _prefill_draft(self) -> None: prefill_len = self.min_prompt_len @@ -1460,9 +1501,10 @@ def _step_draft(self) -> bool: # (N - F, 1) -> (N - F) active_seq_indices = active_seq_mask.nonzero().squeeze(-1) + # TODO: commenting for now as we will move this to post-acceptance logic # No sequence left, we can return. - if len(active_seq_indices) == 0: - return False, gpu_util + # if len(active_seq_indices) == 0: + # return False, gpu_util # Otherwise, remove the sequences that have reached EOS from the # state and continue generating the remaining ones. From a844e9975b99a47d51e3b980663bd5f94524115f Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 1 Sep 2024 18:04:12 +0000 Subject: [PATCH 17/19] calculate q p as per gpt-fast speculative decoding algorithm --- src/fairseq2/generation/sampling.py | 34 +++++++++++++++++++---------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/fairseq2/generation/sampling.py b/src/fairseq2/generation/sampling.py index 63c119a25..3110113a5 100644 --- a/src/fairseq2/generation/sampling.py +++ b/src/fairseq2/generation/sampling.py @@ -1258,24 +1258,37 @@ def __call__(self) -> List[List[Hypothesis]]: decoding_step = 0 self.step_nr = self.min_prompt_len + bs = self.seqs.shape[0] + vocab_indices_draft = torch.zeros(bs, self.k_speculate, device=self.seqs.device, dtype=self.seqs.dtype) + probs_draft = torch.zeros(bs, self.k_speculate, self.model.final_proj.output_dim, device=self.seqs.device, dtype=self.encoder_output.dtype) + output_draft = [False] * self.k_speculate while self.step_nr < self.max_seq_len: # Run draft model to generate draft tokens - for self.step_nr_draft in range(self.step_nr, min(self.step_nr + self.k_speculate, self.max_seq_len)): - output, gpu_util = self._step_draft() - gpu_utils.append(gpu_util) + for idx_draft, self.step_nr_draft in enumerate(range(self.step_nr, min(self.step_nr + self.k_speculate, self.max_seq_len))): + output_draft[idx_draft], probs_draft[:, idx_draft ,:], vocab_indices_draft[:, idx_draft], gpu_util_draft = self._step_draft() + gpu_utils.append(gpu_util_draft) decoding_step+=1 - if not output: - break + # TODO: Move this to post-acceptance logic + # if not output_draft: + # break num_draft_tokens = self.step_nr_draft - self.step_nr + 1 # Run draft tokens by main model # TODO: receive probabilities or tokens from main - probs, vocab_indices, gpu_util = self._verify_main() - gpu_utils.append(gpu_util) + probs_main, vocab_indices_main, gpu_util_main = self._verify_main() + gpu_utils.append(gpu_util_main) # Count how many draft tokens are accepted by main model # TODO: run verification rather than hard coding number of accepted tokens + # q: target prob, p: draft prob + # q >= p: always accept draft token + # q < p: q/p prob to accept draft token + p = probs_draft[:, torch.arange(0, num_draft_tokens, device=probs_draft.device), vocab_indices_draft[:, :num_draft_tokens]] + # print(f"probs_main.shape: {probs_main.shape}, torch.arange(0, 
num_draft_tokens, device=probs_draft.device): {torch.arange(0, num_draft_tokens, device=probs_draft.device)}, vocab_indices_draft: {vocab_indices_draft[:, :num_draft_tokens]}") + q = probs_main[:, torch.arange(0, num_draft_tokens, device=probs_main.device), vocab_indices_draft[:, :num_draft_tokens]] + accept_draft_prob = torch.minimum(torch.ones(()), q[:, :, :num_draft_tokens]/ p) + rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero() num_accepted_tokens = num_draft_tokens self.step_nr += num_accepted_tokens @@ -1501,10 +1514,9 @@ def _step_draft(self) -> bool: # (N - F, 1) -> (N - F) active_seq_indices = active_seq_mask.nonzero().squeeze(-1) - # TODO: commenting for now as we will move this to post-acceptance logic # No sequence left, we can return. - # if len(active_seq_indices) == 0: - # return False, gpu_util + if len(active_seq_indices) == 0: + return False, probs, vocab_indices, gpu_util # Otherwise, remove the sequences that have reached EOS from the # state and continue generating the remaining ones. @@ -1512,4 +1524,4 @@ def _step_draft(self) -> bool: # self._reorder_state(active_seq_indices) # self.state_bag_draft.reorder(active_seq_indices) - return True, gpu_util + return True, probs, vocab_indices, gpu_util From 4ac712b6d989598a4da2cef75ae2db5eea237848 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 1 Sep 2024 18:52:03 +0000 Subject: [PATCH 18/19] fix shapes of tensors --- src/fairseq2/generation/sampling.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/fairseq2/generation/sampling.py b/src/fairseq2/generation/sampling.py index 3110113a5..df6764920 100644 --- a/src/fairseq2/generation/sampling.py +++ b/src/fairseq2/generation/sampling.py @@ -1284,11 +1284,13 @@ def __call__(self) -> List[List[Hypothesis]]: # q: target prob, p: draft prob # q >= p: always accept draft token # q < p: q/p prob to accept draft token - p = probs_draft[:, torch.arange(0, num_draft_tokens, device=probs_draft.device), vocab_indices_draft[:, :num_draft_tokens]] - # print(f"probs_main.shape: {probs_main.shape}, torch.arange(0, num_draft_tokens, device=probs_draft.device): {torch.arange(0, num_draft_tokens, device=probs_draft.device)}, vocab_indices_draft: {vocab_indices_draft[:, :num_draft_tokens]}") - q = probs_main[:, torch.arange(0, num_draft_tokens, device=probs_main.device), vocab_indices_draft[:, :num_draft_tokens]] - accept_draft_prob = torch.minimum(torch.ones(()), q[:, :, :num_draft_tokens]/ p) - rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero() + p = probs_draft[:, torch.arange(0, num_draft_tokens, device=probs_draft.device), :] + p = torch.gather(p, dim=2, index=vocab_indices_draft[:, :num_draft_tokens].unsqueeze(-1)) + q = probs_main[:, torch.arange(0, num_draft_tokens, device=probs_main.device), :] + q = torch.gather(q, dim=2, index=vocab_indices_main[:, :num_draft_tokens].unsqueeze(-1)) + accept_draft_prob = torch.minimum(torch.ones(()), q[:, :, :num_draft_tokens]/ p).squeeze(dim=-1) + is_accepted = torch.rand_like(accept_draft_prob) > accept_draft_prob + rejected_locations = torch.where(is_accepted == False)[1] num_accepted_tokens = num_draft_tokens self.step_nr += num_accepted_tokens From e6ce2e722d97dd83bb67003b9b0989a3715094f2 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 4 Sep 2024 17:05:44 +0000 Subject: [PATCH 19/19] add more changes --- src/fairseq2/generation/sampling.py | 30 +++++++++++++++++++++++----- src/fairseq2/nn/incremental_state.py | 5 
+++++ 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/fairseq2/generation/sampling.py b/src/fairseq2/generation/sampling.py index df6764920..39f2389aa 100644 --- a/src/fairseq2/generation/sampling.py +++ b/src/fairseq2/generation/sampling.py @@ -1290,13 +1290,32 @@ def __call__(self) -> List[List[Hypothesis]]: q = torch.gather(q, dim=2, index=vocab_indices_main[:, :num_draft_tokens].unsqueeze(-1)) accept_draft_prob = torch.minimum(torch.ones(()), q[:, :, :num_draft_tokens]/ p).squeeze(dim=-1) is_accepted = torch.rand_like(accept_draft_prob) > accept_draft_prob - rejected_locations = torch.where(is_accepted == False)[1] - num_accepted_tokens = num_draft_tokens - self.step_nr += num_accepted_tokens + + rejected_locations = torch.argmax((~is_accepted).int(), dim=-1) + no_false = torch.all(is_accepted, dim=1) + rejected_locations[no_false] = -1 + min_rejected_locations = torch.min(rejected_locations) + + if min_rejected_locations == -1: # All draft tokens have been accepted + # TODO: Remove Hack + min_rejected_locations = 0 + + if True: + accept_length = min_rejected_locations + p = probs_draft[:, accept_length, :] + q = probs_main[:, accept_length, :] + probs_new = q - p + probs_new = torch.where(probs_new > 0, probs_new, 0.0) + probs_new = probs_new / probs_new.sum() + # (N) + vocab_indices = self.sampler(probs_new) + self.seqs[:, self.step_nr + accept_length] = vocab_indices + + self.step_nr += accept_length+1 # TODO: move post decoding processing here (e.g., self.step_scores, hooks, etc.) - self.state_bag.increment_step_nr(num_accepted_tokens) - self.state_bag_draft.increment_step_nr(- (num_draft_tokens - num_accepted_tokens)) + self.state_bag.increment_step_nr(- (num_draft_tokens - accept_length - 1)) + self.state_bag_draft.increment_step_nr(- (num_draft_tokens - accept_length - 1)) if self.compute_scores: # Sort the hypotheses by their scores before returning. @@ -1327,6 +1346,7 @@ def _prepare_state_draft(self) -> None: def _verify_main(self) -> None: model_output, gpu_util = self._decode(self.seqs[:, self.step_nr - 1 : self.step_nr_draft]) + self.state_bag.increment_step_nr(self.step_nr_draft - (self.step_nr - 1)) logits = model_output.logits diff --git a/src/fairseq2/nn/incremental_state.py b/src/fairseq2/nn/incremental_state.py index 935219d74..13f336553 100644 --- a/src/fairseq2/nn/incremental_state.py +++ b/src/fairseq2/nn/incremental_state.py @@ -74,6 +74,11 @@ def increment_step_nr(self, value: int = 1) -> None: self.step_nr = step_nr + # Update seq_len of attention modules + for module in self._module_states.values(): + if hasattr(module, "seq_len"): + module.seq_len = step_nr + def get_state(self, m: Module, kls: Type[T]) -> Optional[T]: """Get the state of ``m`` if present in the bag.
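
For reference, the acceptance rule that PATCH 17 and PATCH 18 build up, following the gpt-fast-style speculative decoding cited in the commit message, can be sketched in isolation as follows. This is a minimal standalone sketch, not the fairseq2 implementation: the function name, tensor names, and (N, K, V) shapes are illustrative assumptions, and both p and q are evaluated at the drafted token ids as in the standard formulation.

import torch

def count_accepted(probs_draft, probs_main, draft_tokens):
    """Count how many of the K drafted tokens each sequence keeps.

    probs_draft, probs_main: (N, K, V) probabilities from the draft and main
    models at the K speculative positions; draft_tokens: (N, K) token ids.
    """
    # p: draft-model probability of each drafted token, q: main-model probability.
    idx = draft_tokens.unsqueeze(-1)                              # (N, K, 1)
    p = torch.gather(probs_draft, dim=2, index=idx).squeeze(-1)   # (N, K)
    q = torch.gather(probs_main, dim=2, index=idx).squeeze(-1)    # (N, K)

    # Accept with probability min(1, q / p): always accept when q >= p,
    # otherwise accept with probability q / p.
    accept_prob = (q / p).clamp(max=1.0)
    accepted = torch.rand_like(accept_prob) < accept_prob         # (N, K) bool

    # Each sequence keeps its drafted tokens up to, but not including, the
    # first rejected position.
    first_reject = torch.argmax((~accepted).int(), dim=-1)        # 0 when nothing is rejected
    all_accepted = accepted.all(dim=-1)
    k = accepted.shape[1]
    return torch.where(all_accepted, torch.full_like(first_reject, k), first_reject)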
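
PATCH 19 then replaces the token at the first rejected position by sampling from the clipped difference of the two distributions, which keeps the overall output distribution equal to the main model's. A compact sketch of that correction step under the same illustrative assumptions (torch.multinomial stands in for the generator's configured sampler):

import torch

def resample_at_rejection(probs_main_t, probs_draft_t):
    """Sample a replacement token at the first rejected position.

    probs_main_t, probs_draft_t: (N, V) main/draft probabilities at that
    single position.
    """
    residual = (probs_main_t - probs_draft_t).clamp(min=0.0)       # max(0, q - p)
    residual = residual / residual.sum(dim=-1, keepdim=True)       # renormalize per row
    return torch.multinomial(residual, num_samples=1).squeeze(-1)  # (N,)

After the replacement token is written, both state bags are rewound by num_draft_tokens - accept_length - 1 positions; that is also why this patch makes IncrementalStateBag.increment_step_nr propagate the new step number into each attention module's cached seq_len, so the static KV caches are rolled back together with the step counter.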