Skip to content

Commit

Permalink
fix(example): fix error in loading example
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfinfdu committed Oct 30, 2024
1 parent ba1dbcc commit eb5088a
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 32 deletions.
37 changes: 13 additions & 24 deletions examples/loading_llamascope_saes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,11 @@
}
],
"source": [
"import os\n",
"import torch\n",
"import transformers\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"from transformer_lens import HookedTransformer\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"from lm_saes import SparseAutoEncoder, SAEConfig\n",
"\n",
"from tqdm import tqdm\n",
"from transformer_lens import HookedTransformer, HookedTransformerConfig"
"from lm_saes import SparseAutoEncoder"
]
},
{
Expand All @@ -61,28 +57,18 @@
],
"source": [
"model_name = \"meta-llama/Llama-3.1-8B\"\n",
" \n",
"hf_model = AutoModelForCausalLM.from_pretrained(\n",
" (\n",
" model_name\n",
" if model_from_pretrained_path is None\n",
" else model_from_pretrained_path\n",
" ),\n",
")\n",
"\n",
"hf_model = AutoModelForCausalLM.from_pretrained(model_name)\n",
"\n",
"hf_tokenizer = AutoTokenizer.from_pretrained(\n",
" (\n",
" model_name\n",
" if model_from_pretrained_path is None\n",
" else model_from_pretrained_path\n",
" ),\n",
" model_name,\n",
" trust_remote_code=True,\n",
" use_fast=True,\n",
" add_bos_token=True,\n",
")\n",
"model = HookedTransformer.from_pretrained_no_processing(\n",
" model_name,\n",
" device='cuda',\n",
" device=\"cuda\",\n",
" hf_model=hf_model,\n",
" tokenizer=hf_tokenizer,\n",
" dtype=torch.bfloat16,\n",
Expand Down Expand Up @@ -127,7 +113,7 @@
}
],
"source": [
"sae = SparseAutoEncoder.from_pretrained('fnlp/Llama3_1-8B-Base-L15R-8x')"
"sae = SparseAutoEncoder.from_pretrained(\"fnlp/Llama3_1-8B-Base-L15R-8x\")"
]
},
{
Expand Down Expand Up @@ -179,7 +165,7 @@
],
"source": [
"# L0 Sparsity. The first token is <bos> which extremely out-of-distribution.\n",
"(sae.compute_loss(cache['blocks.15.hook_resid_post'])[1][1]['feature_acts'] > 0).sum(-1)"
"(sae.compute_loss(cache[\"blocks.15.hook_resid_post\"])[1][1][\"feature_acts\"] > 0).sum(-1)"
]
},
{
Expand All @@ -201,7 +187,10 @@
],
"source": [
"# Reconstruction loss\n",
"(sae.compute_loss(cache['blocks.15.hook_resid_post'][:, 1:])[1][1]['reconstructed'] - cache['blocks.15.hook_resid_post'][:, 1:]).pow(2).mean()"
"(\n",
" sae.compute_loss(cache[\"blocks.15.hook_resid_post\"][:, 1:])[1][1][\"reconstructed\"]\n",
" - cache[\"blocks.15.hook_resid_post\"][:, 1:]\n",
").pow(2).mean()"
]
},
{
Expand Down
18 changes: 11 additions & 7 deletions examples/programmatic/post_process_topk.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import os
from lm_saes import post_process_topk_to_jumprelu_runner, LanguageModelSAERunnerConfig, SAEConfig

from lm_saes import LanguageModelSAERunnerConfig, SAEConfig, post_process_topk_to_jumprelu_runner

layer = 15

hook_point_in = 'R'
hook_point_out = hook_point_in if hook_point_in != 'TC' else 'M'
hook_point_in = "R"
hook_point_out = hook_point_in if hook_point_in != "TC" else "M"
exp_factor = 8

HOOK_SUFFIX = {"M": "hook_mlp_out", "A": "hook_attn_out", "R": "hook_resid_post", "TC": "ln2.hook_normalized",
"Emb": "hook_resid_pre"}
HOOK_SUFFIX = {
"M": "hook_mlp_out",
"A": "hook_attn_out",
"R": "hook_resid_post",
"TC": "ln2.hook_normalized",
"Emb": "hook_resid_pre",
}


hook_suffix_in = HOOK_SUFFIX[hook_point_in]
Expand All @@ -19,7 +24,6 @@
sae_config = SAEConfig.from_pretrained(ckpt_path).to_dict()



model_name = "meta-llama/Llama-3.1-8B"
# model_from_pretrained_path = "<local_model_path>"

Expand Down Expand Up @@ -65,4 +69,4 @@
)
)

post_process_topk_to_jumprelu_runner(cfg)
post_process_topk_to_jumprelu_runner(cfg)
29 changes: 28 additions & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ dev = [
"nbformat>=5.10.4",
"kaleido==0.2.1",
"pre-commit>=4.0.1",
"ruff>=0.7.1",
]

[tool.mypy]
Expand Down

0 comments on commit eb5088a

Please sign in to comment.