From a1594d78ce4567ced66ab886ecff86ff3b6f7d08 Mon Sep 17 00:00:00 2001
From: Hzfinfdu
Date: Wed, 24 Jul 2024 19:13:57 +0800
Subject: [PATCH 1/3] feat(model): support llama3_1

---
 .../loading_from_pretrained.py | 21 +++++++++++++++++++
 pdm.lock                       | 10 ++++-----
 pyproject.toml                 |  1 +
 src/lm_saes/sae.py             |  4 ++--
 4 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/TransformerLens/transformer_lens/loading_from_pretrained.py b/TransformerLens/transformer_lens/loading_from_pretrained.py
index 49c678d..a200e73 100644
--- a/TransformerLens/transformer_lens/loading_from_pretrained.py
+++ b/TransformerLens/transformer_lens/loading_from_pretrained.py
@@ -122,7 +122,9 @@
     "CodeLlama-7b-Python-hf",
     "CodeLlama-7b-Instruct-hf",
     "meta-llama/Meta-Llama-3-8B",
+    "meta-llama/Meta-Llama-3.1-8B",
     "meta-llama/Meta-Llama-3-8B-Instruct",
+    "meta-llama/Meta-Llama-3.1-8B-Instruct",
     "meta-llama/Meta-Llama-3-70B",
     "meta-llama/Meta-Llama-3-70B-Instruct",
     "Baidicoot/Othello-GPT-Transformer-Lens",
@@ -809,6 +811,25 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "final_rms": True,
             "gated_mlp": True,
         }
+    elif "Meta-Llama-3.1-8B" in official_model_name:
+        cfg_dict = {
+            "d_model": 4096,
+            "d_head": 128,
+            "n_heads": 32,
+            "d_mlp": 14336,
+            "n_layers": 32,
+            "n_ctx": 8192,
+            "eps": 1e-5,
+            "d_vocab": 128256,
+            "act_fn": "silu",
+            "n_key_value_heads": 8,
+            "normalization_type": "RMS",
+            "positional_embedding_type": "rotary",
+            "rotary_adjacent_pairs": False,
+            "rotary_dim": 128,
+            "final_rms": True,
+            "gated_mlp": True,
+        }
     elif "Meta-Llama-3-70B" in official_model_name:
         cfg_dict = {
             "d_model": 8192,
diff --git a/pdm.lock b/pdm.lock
index 5b12cf3..d18e932 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@
 groups = ["default", "dev"]
 strategy = ["cross_platform", "inherit_metadata"]
 lock_version = "4.4.1"
-content_hash = "sha256:912a3ded5baf368f138e3f3b1ce24003ffffcf59a027a164404f5448a03733ea"
+content_hash = "sha256:5c71418ab629971629ad4f44f415d2e9f8dddf153f32c46bc11770d739427917"

 [[package]]
 name = "accelerate"
@@ -2509,13 +2509,13 @@ dependencies = [

 [[package]]
 name = "transformers"
-version = "4.41.2"
+version = "4.43.1"
 requires_python = ">=3.8.0"
 summary = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
 groups = ["default", "dev"]
 dependencies = [
     "filelock",
-    "huggingface-hub<1.0,>=0.23.0",
+    "huggingface-hub<1.0,>=0.23.2",
     "numpy>=1.17",
     "packaging>=20.0",
     "pyyaml>=5.1",
@@ -2526,8 +2526,8 @@ dependencies = [
     "tqdm>=4.27",
 ]
 files = [
-    {file = "transformers-4.41.2-py3-none-any.whl", hash = "sha256:05555d20e43f808de1ef211ab64803cdb513170cef70d29a888b589caebefc67"},
-    {file = "transformers-4.41.2.tar.gz", hash = "sha256:80a4db216533d573e9cc7388646c31ed9480918feb7c55eb211249cb23567f87"},
+    {file = "transformers-4.43.1-py3-none-any.whl", hash = "sha256:eb44b731902e062acbaff196ae4896d7cb3494ddf38275aa00a5fcfb5b34f17d"},
+    {file = "transformers-4.43.1.tar.gz", hash = "sha256:662252c4d0e31b6684f68f68d5cc8206dd7f83da80eb3235be3dc5b3c9fdbdbd"},
 ]

 [[package]]
diff --git a/pyproject.toml b/pyproject.toml
index b7d8f01..efadea9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,7 @@ authors = [
 ]
 dependencies = [
     "datasets>=2.17.0",
+    "transformers>=4.43.0",
     "einops>=0.7.0",
     "fastapi>=0.110.0",
     "matplotlib>=3.8.3",
diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py
index 416ae46..df7f8a7 100644
--- a/src/lm_saes/sae.py
+++ b/src/lm_saes/sae.py
@@ -624,8 +624,8 @@ def from_initialization_searching(
         cfg: LanguageModelSAETrainingConfig,
     ):
         test_batch = activation_store.next(
-            batch_size=cfg.train_batch_size * 8
-        )  # just random hard code xd
+            batch_size=cfg.train_batch_size
+        )
         activation_in, activation_out = test_batch[cfg.sae.hook_point_in], test_batch[cfg.sae.hook_point_out]  # type: ignore

         if (
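
A minimal usage sketch for the Llama 3.1 support added in PATCH 1/3 (not part of the patch itself). It assumes the vendored TransformerLens above and local access to the gated "meta-llama/Meta-Llama-3.1-8B" checkpoint; only the model name comes from the patch, the rest is illustrative:

    # Sketch only: the new cfg_dict is picked up by convert_hf_model_config via name matching.
    from transformer_lens import HookedTransformer

    model = HookedTransformer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
    tokens = model.to_tokens("Hello, world")
    logits = model(tokens)
    print(logits.shape)  # expected [batch, seq_len, d_vocab] with d_vocab = 128256

The transformers>=4.43.0 pin in pyproject.toml matches the upstream release that added Llama 3.1 support on the Hugging Face side.
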
From ebb8cc8e881d9197c182086b1b06fe1bc6bfc172 Mon Sep 17 00:00:00 2001
From: Hzfinfdu
Date: Wed, 24 Jul 2024 23:05:03 +0800
Subject: [PATCH 2/3] fix(sae): fix post process in transform_to_unit_decoder_norm

---
 src/lm_saes/sae.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py
index df7f8a7..d238422 100644
--- a/src/lm_saes/sae.py
+++ b/src/lm_saes/sae.py
@@ -515,7 +515,7 @@ def transform_to_unit_decoder_norm(self):
         decoder_norm = self.decoder_norm()  # (d_sae,)
         self.encoder.weight.data = self.encoder.weight.data * decoder_norm[:, None]
-        self.decoder.weight.data = self.decoder.weight.data.T / decoder_norm
+        self.decoder.weight.data = self.decoder.weight.data / decoder_norm

         self.encoder.bias.data = self.encoder.bias.data * decoder_norm

From ecb8fa515c4228e7fdd0a0b2a789095f2a4b23bb Mon Sep 17 00:00:00 2001
From: Hzfinfdu
Date: Thu, 25 Jul 2024 12:35:35 +0800
Subject: [PATCH 3/3] fix(training): use clip grad value instead of norm

---
 src/lm_saes/config.py       | 2 +-
 src/lm_saes/sae_training.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py
index 75b50c3..82e6769 100644
--- a/src/lm_saes/config.py
+++ b/src/lm_saes/config.py
@@ -324,7 +324,7 @@ class LanguageModelSAETrainingConfig(LanguageModelSAERunnerConfig):
     lr_warm_up_steps: int | float = 0.1
     lr_cool_down_steps: int | float = 0.1
     train_batch_size: int = 4096
-    clip_grad_norm: float = 0.0
+    clip_grad_value: float = 0.0
     remove_gradient_parallel_to_decoder_directions: bool = False

     finetuning: bool = False
diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py
index b55746d..dc583ac 100644
--- a/src/lm_saes/sae_training.py
+++ b/src/lm_saes/sae_training.py
@@ -145,8 +145,8 @@ def train_sae(
         if cfg.finetuning:
             loss = loss_data["l_rec"].mean()
         loss.backward()
-        if cfg.clip_grad_norm > 0:
-            torch.nn.utils.clip_grad_norm_(sae.parameters(), cfg.clip_grad_norm)
+        if cfg.clip_grad_value > 0:
+            torch.nn.utils.clip_grad_value_(sae.parameters(), cfg.clip_grad_value)
         if cfg.remove_gradient_parallel_to_decoder_directions:
             sae.remove_gradient_parallel_to_decoder_directions()
         optimizer.step()
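
For reference, a small sketch of the behavioural change in PATCH 3/3, using only standard PyTorch utilities (the tensors and values below are illustrative, not taken from lm_saes):

    import torch

    p = torch.nn.Parameter(torch.zeros(4))
    p.grad = torch.tensor([0.5, -3.0, 10.0, -0.1])
    q = torch.nn.Parameter(torch.zeros(4))
    q.grad = p.grad.clone()

    # clip_grad_norm_ rescales the whole gradient so its global L2 norm is at most max_norm,
    # shrinking every element by the same factor (the norm here is ~10.45).
    torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)

    # clip_grad_value_ clamps each element independently into [-clip_value, clip_value];
    # this is what cfg.clip_grad_value now controls in train_sae.
    torch.nn.utils.clip_grad_value_([q], clip_value=1.0)
    print(q.grad)  # tensor([ 0.5000, -1.0000,  1.0000, -0.1000])
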