diff --git a/TransformerLens/transformer_lens/loading_from_pretrained.py b/TransformerLens/transformer_lens/loading_from_pretrained.py
index 49c678d..a200e73 100644
--- a/TransformerLens/transformer_lens/loading_from_pretrained.py
+++ b/TransformerLens/transformer_lens/loading_from_pretrained.py
@@ -122,7 +122,9 @@
     "CodeLlama-7b-Python-hf",
     "CodeLlama-7b-Instruct-hf",
     "meta-llama/Meta-Llama-3-8B",
+    "meta-llama/Meta-Llama-3.1-8B",
     "meta-llama/Meta-Llama-3-8B-Instruct",
+    "meta-llama/Meta-Llama-3.1-8B-Instruct",
     "meta-llama/Meta-Llama-3-70B",
     "meta-llama/Meta-Llama-3-70B-Instruct",
     "Baidicoot/Othello-GPT-Transformer-Lens",
@@ -809,6 +811,25 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "final_rms": True,
             "gated_mlp": True,
         }
+    elif "Meta-Llama-3.1-8B" in official_model_name:
+        cfg_dict = {
+            "d_model": 4096,
+            "d_head": 128,
+            "n_heads": 32,
+            "d_mlp": 14336,
+            "n_layers": 32,
+            "n_ctx": 8192,
+            "eps": 1e-5,
+            "d_vocab": 128256,
+            "act_fn": "silu",
+            "n_key_value_heads": 8,
+            "normalization_type": "RMS",
+            "positional_embedding_type": "rotary",
+            "rotary_adjacent_pairs": False,
+            "rotary_dim": 128,
+            "final_rms": True,
+            "gated_mlp": True,
+        }
     elif "Meta-Llama-3-70B" in official_model_name:
         cfg_dict = {
             "d_model": 8192,
diff --git a/pdm.lock b/pdm.lock
index 5b12cf3..d18e932 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@
 groups = ["default", "dev"]
 strategy = ["cross_platform", "inherit_metadata"]
 lock_version = "4.4.1"
-content_hash = "sha256:912a3ded5baf368f138e3f3b1ce24003ffffcf59a027a164404f5448a03733ea"
+content_hash = "sha256:5c71418ab629971629ad4f44f415d2e9f8dddf153f32c46bc11770d739427917"
 
 [[package]]
 name = "accelerate"
@@ -2509,13 +2509,13 @@ dependencies = [
 
 [[package]]
 name = "transformers"
-version = "4.41.2"
+version = "4.43.1"
 requires_python = ">=3.8.0"
 summary = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
 groups = ["default", "dev"]
 dependencies = [
     "filelock",
-    "huggingface-hub<1.0,>=0.23.0",
+    "huggingface-hub<1.0,>=0.23.2",
     "numpy>=1.17",
     "packaging>=20.0",
     "pyyaml>=5.1",
@@ -2526,8 +2526,8 @@ dependencies = [
     "tqdm>=4.27",
 ]
 files = [
-    {file = "transformers-4.41.2-py3-none-any.whl", hash = "sha256:05555d20e43f808de1ef211ab64803cdb513170cef70d29a888b589caebefc67"},
-    {file = "transformers-4.41.2.tar.gz", hash = "sha256:80a4db216533d573e9cc7388646c31ed9480918feb7c55eb211249cb23567f87"},
+    {file = "transformers-4.43.1-py3-none-any.whl", hash = "sha256:eb44b731902e062acbaff196ae4896d7cb3494ddf38275aa00a5fcfb5b34f17d"},
+    {file = "transformers-4.43.1.tar.gz", hash = "sha256:662252c4d0e31b6684f68f68d5cc8206dd7f83da80eb3235be3dc5b3c9fdbdbd"},
 ]
 
 [[package]]
diff --git a/pyproject.toml b/pyproject.toml
index b7d8f01..efadea9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,7 @@ authors = [
 ]
 dependencies = [
     "datasets>=2.17.0",
+    "transformers>=4.43.0",
     "einops>=0.7.0",
     "fastapi>=0.110.0",
     "matplotlib>=3.8.3",
diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py
index 416ae46..df7f8a7 100644
--- a/src/lm_saes/sae.py
+++ b/src/lm_saes/sae.py
@@ -624,8 +624,8 @@ def from_initialization_searching(
         cfg: LanguageModelSAETrainingConfig,
     ):
         test_batch = activation_store.next(
-            batch_size=cfg.train_batch_size * 8
-        )  # just random hard code xd
+            batch_size=cfg.train_batch_size
+        )
        activation_in, activation_out = test_batch[cfg.sae.hook_point_in], test_batch[cfg.sae.hook_point_out]  # type: ignore

        if (
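A minimal usage sketch (illustrative, not part of the diff): with the patched TransformerLens and transformers>=4.43 installed, the newly registered Llama 3.1 name should load like any other supported model. Access to the gated meta-llama checkpoint and the choice of device are assumptions here.

```python
# Sketch assuming the patched TransformerLens is importable and the gated
# meta-llama/Meta-Llama-3.1-8B weights are accessible (Hugging Face token).
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B",  # entry added to OFFICIAL_MODEL_NAMES above
    device="cpu",  # illustrative; an 8B model will usually want a GPU
)

# The new cfg_dict declares grouped-query attention: 8 KV heads for 32 query heads.
assert model.cfg.n_heads == 32
assert model.cfg.n_key_value_heads == 8
```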