Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interleaving sliding window for Ministral-8B-Instruct-2410 #10591

Merged
merged 6 commits into from
Nov 30, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,24 @@
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
)

layer_idx: int = int(prefix.split(".")[0])
if isinstance(config.interleaved_sliding_window, int):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to check hasattr(config, "interleaved_sliding_window")

sliding_window = config.interleaved_sliding_window
elif isinstance(config.interleaved_sliding_window, list):
sw_idx = layer_idx % len(sliding_window)
sliding_window = config.interleaved_sliding_window[sw_idx]
else:
None

Check failure on line 178 in vllm/model_executor/models/llama.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (B018)

vllm/model_executor/models/llama.py:178:13: B018 Found useless expression. Either assign it to a variable or remove it.
patrickvonplaten marked this conversation as resolved.
Show resolved Hide resolved

self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=sliding_window,
prefix=f"{prefix}.attn",
)

Expand Down
Loading