diff --git a/examples/loading_llamascope_saes.ipynb b/examples/loading_llamascope_saes.ipynb index db7689a..0b263b1 100644 --- a/examples/loading_llamascope_saes.ipynb +++ b/examples/loading_llamascope_saes.ipynb @@ -27,15 +27,11 @@ } ], "source": [ - "import os\n", "import torch\n", - "import transformers\n", - "from transformers import AutoTokenizer, AutoModelForCausalLM\n", + "from transformer_lens import HookedTransformer\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "\n", - "from lm_saes import SparseAutoEncoder, SAEConfig\n", - "\n", - "from tqdm import tqdm\n", - "from transformer_lens import HookedTransformer, HookedTransformerConfig" + "from lm_saes import SparseAutoEncoder" ] }, { @@ -61,28 +57,18 @@ ], "source": [ "model_name = \"meta-llama/Llama-3.1-8B\"\n", - " \n", - "hf_model = AutoModelForCausalLM.from_pretrained(\n", - " (\n", - " model_name\n", - " if model_from_pretrained_path is None\n", - " else model_from_pretrained_path\n", - " ),\n", - ")\n", + "\n", + "hf_model = AutoModelForCausalLM.from_pretrained(model_name)\n", "\n", "hf_tokenizer = AutoTokenizer.from_pretrained(\n", - " (\n", - " model_name\n", - " if model_from_pretrained_path is None\n", - " else model_from_pretrained_path\n", - " ),\n", + " model_name,\n", " trust_remote_code=True,\n", " use_fast=True,\n", " add_bos_token=True,\n", ")\n", "model = HookedTransformer.from_pretrained_no_processing(\n", " model_name,\n", - " device='cuda',\n", + " device=\"cuda\",\n", " hf_model=hf_model,\n", " tokenizer=hf_tokenizer,\n", " dtype=torch.bfloat16,\n", @@ -127,7 +113,7 @@ } ], "source": [ - "sae = SparseAutoEncoder.from_pretrained('fnlp/Llama3_1-8B-Base-L15R-8x')" + "sae = SparseAutoEncoder.from_pretrained(\"fnlp/Llama3_1-8B-Base-L15R-8x\")" ] }, { @@ -179,7 +165,7 @@ ], "source": [ "# L0 Sparsity. 
The first token is extremely out-of-distribution.\n", - "(sae.compute_loss(cache['blocks.15.hook_resid_post'])[1][1]['feature_acts'] > 0).sum(-1)" + "(sae.compute_loss(cache[\"blocks.15.hook_resid_post\"])[1][1][\"feature_acts\"] > 0).sum(-1)" ] }, { @@ -201,7 +187,10 @@ } ], "source": [ "# Reconstruction loss\n", - "(sae.compute_loss(cache['blocks.15.hook_resid_post'][:, 1:])[1][1]['reconstructed'] - cache['blocks.15.hook_resid_post'][:, 1:]).pow(2).mean()" + "(\n", + " sae.compute_loss(cache[\"blocks.15.hook_resid_post\"][:, 1:])[1][1][\"reconstructed\"]\n", + " - cache[\"blocks.15.hook_resid_post\"][:, 1:]\n", + ").pow(2).mean()" ] }, { diff --git a/examples/programmatic/post_process_topk.py b/examples/programmatic/post_process_topk.py index 5bfca01..4372dbe 100644 --- a/examples/programmatic/post_process_topk.py +++ b/examples/programmatic/post_process_topk.py @@ -1,15 +1,20 @@ import os -from lm_saes import post_process_topk_to_jumprelu_runner, LanguageModelSAERunnerConfig, SAEConfig +from lm_saes import LanguageModelSAERunnerConfig, SAEConfig, post_process_topk_to_jumprelu_runner layer = 15 -hook_point_in = 'R' -hook_point_out = hook_point_in if hook_point_in != 'TC' else 'M' +hook_point_in = "R" +hook_point_out = hook_point_in if hook_point_in != "TC" else "M" exp_factor = 8 -HOOK_SUFFIX = {"M": "hook_mlp_out", "A": "hook_attn_out", "R": "hook_resid_post", "TC": "ln2.hook_normalized", - "Emb": "hook_resid_pre"} +HOOK_SUFFIX = { + "M": "hook_mlp_out", + "A": "hook_attn_out", + "R": "hook_resid_post", + "TC": "ln2.hook_normalized", + "Emb": "hook_resid_pre", +} hook_suffix_in = HOOK_SUFFIX[hook_point_in] @@ -19,7 +24,6 @@ sae_config = SAEConfig.from_pretrained(ckpt_path).to_dict() - model_name = "meta-llama/Llama-3.1-8B" # model_from_pretrained_path = "" @@ -65,4 +69,4 @@ ) ) -post_process_topk_to_jumprelu_runner(cfg) \ No newline at end of file +post_process_topk_to_jumprelu_runner(cfg) diff --git a/pdm.lock b/pdm.lock index a0b831c..33d4eb9 100644 ---
a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:9902a8f3bb69f60d12a856756f0e2a2bbef583d4f312d798114c2235c2d2f340" +content_hash = "sha256:5e4b3c69c9533bfcd678cadfd52b7e960d124dd0084481ee38295949dd8df3f2" [[metadata.targets]] requires_python = "==3.10.*" @@ -2044,6 +2044,33 @@ files = [ {file = "rpds_py-0.18.1.tar.gz", hash = "sha256:dc48b479d540770c811fbd1eb9ba2bb66951863e448efec2e2c102625328e92f"}, ] +[[package]] +name = "ruff" +version = "0.7.1" +requires_python = ">=3.7" +summary = "An extremely fast Python linter and code formatter, written in Rust." +groups = ["dev"] +files = [ + {file = "ruff-0.7.1-py3-none-linux_armv6l.whl", hash = "sha256:cb1bc5ed9403daa7da05475d615739cc0212e861b7306f314379d958592aaa89"}, + {file = "ruff-0.7.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:27c1c52a8d199a257ff1e5582d078eab7145129aa02721815ca8fa4f9612dc35"}, + {file = "ruff-0.7.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:588a34e1ef2ea55b4ddfec26bbe76bc866e92523d8c6cdec5e8aceefeff02d99"}, + {file = "ruff-0.7.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94fc32f9cdf72dc75c451e5f072758b118ab8100727168a3df58502b43a599ca"}, + {file = "ruff-0.7.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:985818742b833bffa543a84d1cc11b5e6871de1b4e0ac3060a59a2bae3969250"}, + {file = "ruff-0.7.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32f1e8a192e261366c702c5fb2ece9f68d26625f198a25c408861c16dc2dea9c"}, + {file = "ruff-0.7.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:699085bf05819588551b11751eff33e9ca58b1b86a6843e1b082a7de40da1565"}, + {file = "ruff-0.7.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:344cc2b0814047dc8c3a8ff2cd1f3d808bb23c6658db830d25147339d9bf9ea7"}, + {file = 
"ruff-0.7.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4316bbf69d5a859cc937890c7ac7a6551252b6a01b1d2c97e8fc96e45a7c8b4a"}, + {file = "ruff-0.7.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:79d3af9dca4c56043e738a4d6dd1e9444b6d6c10598ac52d146e331eb155a8ad"}, + {file = "ruff-0.7.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:c5c121b46abde94a505175524e51891f829414e093cd8326d6e741ecfc0a9112"}, + {file = "ruff-0.7.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8422104078324ea250886954e48f1373a8fe7de59283d747c3a7eca050b4e378"}, + {file = "ruff-0.7.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:56aad830af8a9db644e80098fe4984a948e2b6fc2e73891538f43bbe478461b8"}, + {file = "ruff-0.7.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:658304f02f68d3a83c998ad8bf91f9b4f53e93e5412b8f2388359d55869727fd"}, + {file = "ruff-0.7.1-py3-none-win32.whl", hash = "sha256:b517a2011333eb7ce2d402652ecaa0ac1a30c114fbbd55c6b8ee466a7f600ee9"}, + {file = "ruff-0.7.1-py3-none-win_amd64.whl", hash = "sha256:f38c41fcde1728736b4eb2b18850f6d1e3eedd9678c914dede554a70d5241307"}, + {file = "ruff-0.7.1-py3-none-win_arm64.whl", hash = "sha256:19aa200ec824c0f36d0c9114c8ec0087082021732979a359d6f3c390a6ff2a37"}, + {file = "ruff-0.7.1.tar.gz", hash = "sha256:9d8a41d4aa2dad1575adb98a82870cf5db5f76b2938cf2206c22c940034a36f4"}, +] + [[package]] name = "safetensors" version = "0.4.5" diff --git a/pyproject.toml b/pyproject.toml index 7503799..1ff5aed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ dev = [ "nbformat>=5.10.4", "kaleido==0.2.1", "pre-commit>=4.0.1", + "ruff>=0.7.1", ] [tool.mypy]