
Commit

fix comment and typo
Signed-off-by: Sungjae Lee <[email protected]>
llsj14 committed Jan 3, 2025
1 parent aa183ff commit f2751c8
Showing 2 changed files with 6 additions and 5 deletions.
9 changes: 5 additions & 4 deletions vllm/model_executor/models/eagle.py
@@ -23,7 +23,7 @@ def forward(self, x):
return x


-class DummOutputNorm(nn.Module):
+class DummyOutputNorm(nn.Module):

def forward(self, x, residual):
x = x + residual
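For context, the class touched by this rename is a no-op norm that only folds the residual back into the hidden states. A minimal runnable sketch of the pattern, using plain-Python stand-ins rather than the real `torch.nn.Module` subclasses (the remainder of `DummyOutputNorm.forward` is collapsed in this view, so the return value here is an assumption):

```python
# Hypothetical stand-ins for the dummy-norm modules in eagle.py; the real
# classes subclass torch.nn.Module and operate on tensors. Plain floats are
# used so the sketch runs without torch.
class DummyInputLayerNorm:
    """Identity 'norm': passes hidden states through unchanged."""

    def forward(self, x):
        return x


class DummyOutputNorm:
    """No-op final norm that still folds the residual into the output,
    matching the `x = x + residual` line shown in the diff."""

    def forward(self, x, residual):
        x = x + residual
        return x
```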
@@ -36,8 +36,9 @@ class EAGLE(nn.Module):
Differences from reference implementation:
1. In reference, LlamaDecoderLayer implementation doesn't have
-   input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427)
-   but we do as HF implementation also does.
+   input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
+   Following this approach, our implementation also disables
+   the input_layernorm for the first decoder layer.
2. We allow any decoder layer to be used in EAGLE whereas in reference
decoder layer is fixed to be LlamaDecoderLayer.
3. We have an optional token_map which reduces draft vocab to most
@@ -67,7 +68,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Modify layer normalization and residual connections as suggested
# in the EAGLE framework: https://github.com/SafeAILab/EAGLE
self.model.model.layers[0].input_layernorm = DummyInputLayerNorm()
-        self.model.model.norm = DummOutputNorm()
+        self.model.model.norm = DummyOutputNorm()

self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size
2 changes: 1 addition & 1 deletion vllm/spec_decode/metrics.py
@@ -117,7 +117,7 @@ def _should_collect_rejsample_metrics(self, now: float) -> bool:
if self._rank != 0:
return False

-        return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s  # noqa: E501
+        return now - self._last_metrics_collect_time >= 0.1  # noqa: E501

def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
"""Copy rejection/typical-acceptance sampling metrics
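The changed line in metrics.py gates metrics collection by rank and elapsed time. A minimal runnable sketch of that check, where the method name and attribute names are copied from the diff but the class and constructor are invented for illustration:

```python
class SpecDecodeMetricsSketch:
    """Illustrative stand-in for the vLLM metrics collector; only the
    _should_collect_rejsample_metrics logic mirrors the diff."""

    def __init__(self, rank: int, last_collect_time: float = 0.0):
        self._rank = rank
        self._last_metrics_collect_time = last_collect_time

    def _should_collect_rejsample_metrics(self, now: float) -> bool:
        # Only rank 0 collects rejection-sampling metrics; other ranks skip.
        if self._rank != 0:
            return False
        # Collect once at least 0.1 s has elapsed since the last collection,
        # matching the interval on the changed line.
        return now - self._last_metrics_collect_time >= 0.1
```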
