From a1594d78ce4567ced66ab886ecff86ff3b6f7d08 Mon Sep 17 00:00:00 2001 From: Hzfinfdu Date: Wed, 24 Jul 2024 19:13:57 +0800 Subject: [PATCH 1/4] feat(model): support llama3_1 --- .../loading_from_pretrained.py | 21 +++++++++++++++++++ pdm.lock | 10 ++++----- pyproject.toml | 1 + src/lm_saes/sae.py | 4 ++-- 4 files changed, 29 insertions(+), 7 deletions(-) diff --git a/TransformerLens/transformer_lens/loading_from_pretrained.py b/TransformerLens/transformer_lens/loading_from_pretrained.py index 49c678d..a200e73 100644 --- a/TransformerLens/transformer_lens/loading_from_pretrained.py +++ b/TransformerLens/transformer_lens/loading_from_pretrained.py @@ -122,7 +122,9 @@ "CodeLlama-7b-Python-hf", "CodeLlama-7b-Instruct-hf", "meta-llama/Meta-Llama-3-8B", + "meta-llama/Meta-Llama-3.1-8B", "meta-llama/Meta-Llama-3-8B-Instruct", + "meta-llama/Meta-Llama-3.1-8B-Instruct", "meta-llama/Meta-Llama-3-70B", "meta-llama/Meta-Llama-3-70B-Instruct", "Baidicoot/Othello-GPT-Transformer-Lens", @@ -809,6 +811,25 @@ def convert_hf_model_config(model_name: str, **kwargs): "final_rms": True, "gated_mlp": True, } + elif "Meta-Llama-3.1-8B" in official_model_name: + cfg_dict = { + "d_model": 4096, + "d_head": 128, + "n_heads": 32, + "d_mlp": 14336, + "n_layers": 32, + "n_ctx": 8192, + "eps": 1e-5, + "d_vocab": 128256, + "act_fn": "silu", + "n_key_value_heads": 8, + "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "rotary_adjacent_pairs": False, + "rotary_dim": 128, + "final_rms": True, + "gated_mlp": True, + } elif "Meta-Llama-3-70B" in official_model_name: cfg_dict = { "d_model": 8192, diff --git a/pdm.lock b/pdm.lock index 5b12cf3..d18e932 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:912a3ded5baf368f138e3f3b1ce24003ffffcf59a027a164404f5448a03733ea" +content_hash = "sha256:5c71418ab629971629ad4f44f415d2e9f8dddf153f32c46bc11770d739427917" [[package]] name = "accelerate" @@ -2509,13 +2509,13 @@ dependencies = [ [[package]] name = "transformers" -version = "4.41.2" +version = "4.43.1" requires_python = ">=3.8.0" summary = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" groups = ["default", "dev"] dependencies = [ "filelock", - "huggingface-hub<1.0,>=0.23.0", + "huggingface-hub<1.0,>=0.23.2", "numpy>=1.17", "packaging>=20.0", "pyyaml>=5.1", @@ -2526,8 +2526,8 @@ dependencies = [ "tqdm>=4.27", ] files = [ - {file = "transformers-4.41.2-py3-none-any.whl", hash = "sha256:05555d20e43f808de1ef211ab64803cdb513170cef70d29a888b589caebefc67"}, - {file = "transformers-4.41.2.tar.gz", hash = "sha256:80a4db216533d573e9cc7388646c31ed9480918feb7c55eb211249cb23567f87"}, + {file = "transformers-4.43.1-py3-none-any.whl", hash = "sha256:eb44b731902e062acbaff196ae4896d7cb3494ddf38275aa00a5fcfb5b34f17d"}, + {file = "transformers-4.43.1.tar.gz", hash = "sha256:662252c4d0e31b6684f68f68d5cc8206dd7f83da80eb3235be3dc5b3c9fdbdbd"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index b7d8f01..efadea9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ authors = [ ] dependencies = [ "datasets>=2.17.0", + "transformers>=4.43.0", "einops>=0.7.0", "fastapi>=0.110.0", "matplotlib>=3.8.3", diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index 416ae46..df7f8a7 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -624,8 +624,8 @@ def from_initialization_searching( cfg: LanguageModelSAETrainingConfig, ): test_batch = activation_store.next( - 
batch_size=cfg.train_batch_size * 8 - ) # just random hard code xd + batch_size=cfg.train_batch_size + ) activation_in, activation_out = test_batch[cfg.sae.hook_point_in], test_batch[cfg.sae.hook_point_out] # type: ignore if ( From ebb8cc8e881d9197c182086b1b06fe1bc6bfc172 Mon Sep 17 00:00:00 2001 From: Hzfinfdu Date: Wed, 24 Jul 2024 23:05:03 +0800 Subject: [PATCH 2/4] fix(sae): fix post process in transform_to_unit_decoder_norm --- src/lm_saes/sae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index df7f8a7..d238422 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -515,7 +515,7 @@ def transform_to_unit_decoder_norm(self): decoder_norm = self.decoder_norm() # (d_sae,) self.encoder.weight.data = self.encoder.weight.data * decoder_norm[:, None] - self.decoder.weight.data = self.decoder.weight.data.T / decoder_norm + self.decoder.weight.data = self.decoder.weight.data / decoder_norm self.encoder.bias.data = self.encoder.bias.data * decoder_norm From ecb8fa515c4228e7fdd0a0b2a789095f2a4b23bb Mon Sep 17 00:00:00 2001 From: Hzfinfdu Date: Thu, 25 Jul 2024 12:35:35 +0800 Subject: [PATCH 3/4] fix(training): use clip grad value instead of norm --- src/lm_saes/config.py | 2 +- src/lm_saes/sae_training.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py index 75b50c3..82e6769 100644 --- a/src/lm_saes/config.py +++ b/src/lm_saes/config.py @@ -324,7 +324,7 @@ class LanguageModelSAETrainingConfig(LanguageModelSAERunnerConfig): lr_warm_up_steps: int | float = 0.1 lr_cool_down_steps: int | float = 0.1 train_batch_size: int = 4096 - clip_grad_norm: float = 0.0 + clip_grad_value: float = 0.0 remove_gradient_parallel_to_decoder_directions: bool = False finetuning: bool = False diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py index b55746d..dc583ac 100644 --- a/src/lm_saes/sae_training.py +++ b/src/lm_saes/sae_training.py @@ -145,8 +145,8 @@ def train_sae( if cfg.finetuning: loss = loss_data["l_rec"].mean() loss.backward() - if cfg.clip_grad_norm > 0: - torch.nn.utils.clip_grad_norm_(sae.parameters(), cfg.clip_grad_norm) + if cfg.clip_grad_value > 0: + torch.nn.utils.clip_grad_value_(sae.parameters(), cfg.clip_grad_value) if cfg.remove_gradient_parallel_to_decoder_directions: sae.remove_gradient_parallel_to_decoder_directions() optimizer.step() From 90c1b9a4ab3c90325f00af5525c909b9551b820a Mon Sep 17 00:00:00 2001 From: StarConnor <105139493+StarConnor@users.noreply.github.com> Date: Sat, 27 Jul 2024 19:42:18 +0800 Subject: [PATCH 4/4] Add feature of side navigation --- ui/src/components/app/sidenav.tsx | 64 ++++++++++++++++++++++ ui/src/components/feature/feature-card.tsx | 17 ++++-- ui/src/globals.css | 16 ++++++ ui/src/routes/features/page.tsx | 4 +- 4 files changed, 95 insertions(+), 6 deletions(-) create mode 100644 ui/src/components/app/sidenav.tsx diff --git a/ui/src/components/app/sidenav.tsx b/ui/src/components/app/sidenav.tsx new file mode 100644 index 0000000..98db50e --- /dev/null +++ b/ui/src/components/app/sidenav.tsx @@ -0,0 +1,64 @@ +import { Card, CardContent, CardHeader, CardTitle } from "../ui/card"; +import { useEffect, useState } from 'react'; + +export const SideNav = ({logitsExist} : {logitsExist : boolean}) => { + const [activeId, setActiveId] = useState(''); + // const idList = [{item:'Top'}, {item:'Hist.'}, {item:'Logits'}, {item:'Act.'}] + let idList + if (logitsExist){ + // idList = [{id:'Top'}, {id:'Hist.'}, 
{id:'Logits'}, {id:'Act.'}]
+    idList = ['Top', 'Hist.', 'Logits', 'Act.']
+  } else{
+    idList = ['Top', 'Hist.', 'Act.']
+  }
+
+  const handleScroll = () => {
+    const sections = document.querySelectorAll('div[id]');
+    let currentSectionId = '';
+
+    sections.forEach(section => {
+      if (idList.indexOf(section.id) != -1){
+        const rect = section.getBoundingClientRect();
+        if (rect.top <= window.innerHeight / 2 && rect.bottom >= window.innerHeight / 2) {
+          currentSectionId = section.id;
+        }
+      }
+    });
+
+    setActiveId(currentSectionId);
+  };
+
+  useEffect(() => {
+    window.addEventListener('scroll', handleScroll);
+
+    // Run the handler to set the initial active section
+    handleScroll();
+
+    return () => {
+      window.removeEventListener('scroll', handleScroll);
+    };
+  }, );
+
+  return (
+    <Card className="side-nav">
+      <CardHeader>
+        <CardTitle>
+          CONTENTS
+        </CardTitle>
+      </CardHeader>
+      <CardContent>
+        <nav>
+          <ul>
+            {idList.map((item) => (
+              <li key={item}>
+                <a href={`#${item}`}>
+                  {item}
+                </a>
+                {activeId === item &&
+                  <span>•</span>
+                }
+              </li>
+            ))}
+          </ul>
+        </nav>
+      </CardContent>
+    </Card>
+  );
+};
diff --git a/ui/src/components/feature/feature-card.tsx b/ui/src/components/feature/feature-card.tsx
index fecc1a5..38cf273 100644
--- a/ui/src/components/feature/feature-card.tsx
+++ b/ui/src/components/feature/feature-card.tsx
@@ -88,7 +88,7 @@ export const FeatureCard = ({ feature }: { feature: Feature }) => {
   const [showCustomInput, setShowCustomInput] = useState(false);
 
   return (
-
+
@@ -108,7 +108,7 @@ export const FeatureCard = ({ feature }: { feature: Feature }) => {
-
+

 Activation Histogram

 {
 {feature.logits && (
+

 Logits

@@ -180,10 +180,17 @@ export const FeatureCard = ({ feature }: { feature: Feature }) => {
)} -
+
- {feature.sampleGroups.map((sampleGroup) => ( + {feature.sampleGroups.slice(0,feature.sampleGroups.length/2).map((sampleGroup) => ( + + {analysisNameMap(sampleGroup.analysisName)} + + ))} + + + {feature.sampleGroups.slice(feature.sampleGroups.length/2,feature.sampleGroups.length).map((sampleGroup) => ( {analysisNameMap(sampleGroup.analysisName)} diff --git a/ui/src/globals.css b/ui/src/globals.css index 9e24275..bdca90d 100644 --- a/ui/src/globals.css +++ b/ui/src/globals.css @@ -74,3 +74,19 @@ @apply bg-background text-foreground; } } + + +html { + scroll-behavior: smooth; +} +/* Side navigation styles */ +.side-nav { + position: fixed; + right: 0; + top: 0; + width: 125px; + height: 100%; + background-color: #f4f4f4; + /* padding: 10px; */ + box-shadow: -2px 0 5px rgba(0, 0, 0, 0.1); +} diff --git a/ui/src/routes/features/page.tsx b/ui/src/routes/features/page.tsx index f55925f..387747e 100644 --- a/ui/src/routes/features/page.tsx +++ b/ui/src/routes/features/page.tsx @@ -1,5 +1,6 @@ import { AppNavbar } from "@/components/app/navbar"; import { FeatureCard } from "@/components/feature/feature-card"; +import { SideNav} from "@/components/app/sidenav"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; @@ -90,7 +91,8 @@ export const FeaturesPage = () => { return (
-
+
+
Select dictionary: