From e97eb60e99ff411a38accd3affb8e986c908e023 Mon Sep 17 00:00:00 2001 From: Hzfinfdu Date: Wed, 31 Jul 2024 12:18:42 +0800 Subject: [PATCH 01/20] feat(sae): add a utils func to merge pre-enc bias into enc bias --- src/lm_saes/utils/convert_pre_enc_bias.py | 12 ++++++++++++ tests/unit/test_convert_pre_enc_bias.py | 15 +++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 src/lm_saes/utils/convert_pre_enc_bias.py create mode 100644 tests/unit/test_convert_pre_enc_bias.py diff --git a/src/lm_saes/utils/convert_pre_enc_bias.py b/src/lm_saes/utils/convert_pre_enc_bias.py new file mode 100644 index 0000000..03e7b6f --- /dev/null +++ b/src/lm_saes/utils/convert_pre_enc_bias.py @@ -0,0 +1,12 @@ +from lm_saes.sae import SparseAutoEncoder +import torch + + +@torch.no_grad() +def merge_pre_enc_bias_to_enc_bias(sae: SparseAutoEncoder): + assert sae.cfg.apply_decoder_bias_to_pre_encoder + + sae.cfg.apply_decoder_bias_to_pre_encoder = False + sae.encoder.bias.data = sae.encoder.bias.data - sae.encoder.weight.data @ sae.decoder.bias.data + + return sae \ No newline at end of file diff --git a/tests/unit/test_convert_pre_enc_bias.py b/tests/unit/test_convert_pre_enc_bias.py new file mode 100644 index 0000000..f5cbb92 --- /dev/null +++ b/tests/unit/test_convert_pre_enc_bias.py @@ -0,0 +1,15 @@ +from lm_saes.sae import SparseAutoEncoder +from lm_saes.config import SAEConfig +from lm_saes.utils.convert_pre_enc_bias import merge_pre_enc_bias_to_enc_bias +import torch + +cfg = SAEConfig( + d_model=512, + expansion_factor=4, + apply_decoder_bias_to_pre_encoder=True, +) + +sae = SparseAutoEncoder(cfg) +sample = torch.randn(4, cfg.d_model) + +assert (sae(sample) == merge_pre_enc_bias_to_enc_bias(sae)(sample)).all() \ No newline at end of file From fd6c6ed11cfe841b686305788f63cc41723f793c Mon Sep 17 00:00:00 2001 From: Hzfinfdu Date: Wed, 31 Jul 2024 14:30:11 +0800 Subject: [PATCH 02/20] fix(sae): do not init device mesh in single device mode --- src/lm_saes/sae.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index d238422..278211e 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -60,9 +60,10 @@ def __init__(self, cfg: SAEConfig): ) torch.nn.init.kaiming_uniform_(self.encoder.weight) torch.nn.init.zeros_(self.encoder.bias) - self.device_mesh = init_device_mesh( - "cuda", (cfg.ddp_size, cfg.tp_size), mesh_dim_names=("ddp", "tp") - ) + if cfg.tp_size > 1 or cfg.ddp_size > 1: + self.device_mesh = init_device_mesh( + "cuda", (cfg.ddp_size, cfg.tp_size), mesh_dim_names=("ddp", "tp") + ) if cfg.use_glu_encoder: From ea6dc74dd50a1fc7a677c493f9949928b39310ab Mon Sep 17 00:00:00 2001 From: Dest1n1 Date: Fri, 2 Aug 2024 01:30:14 +0800 Subject: [PATCH 03/20] feat(ui): create model page --- ui/src/components/app/navbar.tsx | 9 ++++++ ui/src/components/model/model-card.tsx | 40 ++++++++++++++++++++++++++ ui/src/main.tsx | 5 ++++ ui/src/routes/models/page.tsx | 15 ++++++++++ 4 files changed, 69 insertions(+) create mode 100644 ui/src/components/model/model-card.tsx create mode 100644 ui/src/routes/models/page.tsx diff --git a/ui/src/components/app/navbar.tsx b/ui/src/components/app/navbar.tsx index 34cf336..6279562 100644 --- a/ui/src/components/app/navbar.tsx +++ b/ui/src/components/app/navbar.tsx @@ -28,6 +28,15 @@ export const AppNavbar = () => { > Dictionaries + + Models + diff --git a/ui/src/components/model/model-card.tsx b/ui/src/components/model/model-card.tsx new file mode 100644 index 0000000..6774e4c --- /dev/null +++ b/ui/src/components/model/model-card.tsx @@ -0,0 +1,40 @@ +import { useState } from "react"; +import { Button } from "../ui/button"; +import { Card, CardContent, CardHeader, CardTitle } from "../ui/card"; +import { Textarea } from "../ui/textarea"; + +const ModelCustomInputArea = () => { + const [customInput, setCustomInput] = useState(""); + const submit = async () => {}; + const disabled = false; + return ( +
+

Custom Input

+