Skip to content

Commit

Permalink
feat(model): support llama3_1
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfinfdu committed Jul 24, 2024
1 parent 9264e4b commit a1594d7
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 7 deletions.
21 changes: 21 additions & 0 deletions TransformerLens/transformer_lens/loading_from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@
"CodeLlama-7b-Python-hf",
"CodeLlama-7b-Instruct-hf",
"meta-llama/Meta-Llama-3-8B",
"meta-llama/Meta-Llama-3.1-8B",
"meta-llama/Meta-Llama-3-8B-Instruct",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"meta-llama/Meta-Llama-3-70B",
"meta-llama/Meta-Llama-3-70B-Instruct",
"Baidicoot/Othello-GPT-Transformer-Lens",
Expand Down Expand Up @@ -809,6 +811,25 @@ def convert_hf_model_config(model_name: str, **kwargs):
"final_rms": True,
"gated_mlp": True,
}
elif "Meta-Llama-3.1-8B" in official_model_name:
cfg_dict = {
"d_model": 4096,
"d_head": 128,
"n_heads": 32,
"d_mlp": 14336,
"n_layers": 32,
"n_ctx": 8192,
"eps": 1e-5,
"d_vocab": 128256,
"act_fn": "silu",
"n_key_value_heads": 8,
"normalization_type": "RMS",
"positional_embedding_type": "rotary",
"rotary_adjacent_pairs": False,
"rotary_dim": 128,
"final_rms": True,
"gated_mlp": True,
}
elif "Meta-Llama-3-70B" in official_model_name:
cfg_dict = {
"d_model": 8192,
Expand Down
10 changes: 5 additions & 5 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ authors = [
]
dependencies = [
"datasets>=2.17.0",
"transformers>=4.43.0",
"einops>=0.7.0",
"fastapi>=0.110.0",
"matplotlib>=3.8.3",
Expand Down
4 changes: 2 additions & 2 deletions src/lm_saes/sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,8 @@ def from_initialization_searching(
cfg: LanguageModelSAETrainingConfig,
):
test_batch = activation_store.next(
- batch_size=cfg.train_batch_size * 8
- ) # just random hard code xd
+ batch_size=cfg.train_batch_size
+ )
activation_in, activation_out = test_batch[cfg.sae.hook_point_in], test_batch[cfg.sae.hook_point_out] # type: ignore

if (
Expand Down

0 comments on commit a1594d7

Please sign in to comment.