From 83f863cc9c3da41f39ecd78743af7d57fb6a5a07 Mon Sep 17 00:00:00 2001 From: Frankstein <20307140057@fudan.edu.cn> Date: Thu, 13 Jun 2024 20:08:38 +0800 Subject: [PATCH] feat(runner): load tokenizer manually --- pdm.lock | 12 +++---- pyproject.toml | 2 +- src/lm_saes/runner.py | 75 +++++++++++++++++++++++++++++++++--------- ui/bun.lockb | Bin 257868 -> 257820 bytes 4 files changed, 66 insertions(+), 23 deletions(-) diff --git a/pdm.lock b/pdm.lock index b13fe5f0..83c73ec2 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:5266c91187a20b13682380660c9795b4ea9c2f2c2ad5370e97ab83ec920ece84" +content_hash = "sha256:4a00d257bb6f7996524921a49a47cd158a4c00501fdd3fa89589d4d0751fb434" [[package]] name = "accelerate" @@ -351,7 +351,7 @@ name = "exceptiongroup" version = "1.2.1" requires_python = ">=3.7" summary = "Backport of PEP 654 (exception groups)" -groups = ["default"] +groups = ["default", "dev"] marker = "python_version < \"3.11\"" files = [ {file = "exceptiongroup-1.2.1-py3-none-any.whl", hash = "sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad"}, @@ -615,7 +615,7 @@ name = "iniconfig" version = "2.0.0" requires_python = ">=3.7" summary = "brain-dead simple config-ini parsing" -groups = ["default"] +groups = ["dev"] files = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, @@ -1231,7 +1231,7 @@ name = "pluggy" version = "1.5.0" requires_python = ">=3.8" summary = "plugin and hook calling mechanisms for python" -groups = ["default"] +groups = ["dev"] files = [ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, @@ -1406,7 +1406,7 @@ name = "pytest" version = "8.2.1" requires_python = ">=3.8" summary = "pytest: simple powerful testing with Python" -groups = ["default"] +groups = ["dev"] dependencies = [ "colorama; sys_platform == \"win32\"", "exceptiongroup>=1.0.0rc8; python_version < \"3.11\"", @@ -1861,7 +1861,7 @@ name = "tomli" version = "2.0.1" requires_python = ">=3.7" summary = "A lil' TOML parser" -groups = ["default", "dev"] +groups = ["dev"] marker = "python_version < \"3.11\"" files = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, diff --git a/pyproject.toml b/pyproject.toml index 608ce125..1b499663 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,6 @@ dependencies = [ "numpy>=1.26.4", "pandas>=2.2.1", "pymongo>=4.6.3", - "pytest>=8.0.1", "tensorboardX>=2.6.2.2", "torch>=2.2.0", "tqdm>=4.66.2", @@ -41,6 +40,7 @@ license = {text = "MIT"} dev = [ "-e file:///${PROJECT_ROOT}/TransformerLens#egg=transformer-lens", "mypy>=1.10.0", + "pytest>=8.0.1", ] [tool.mypy] diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py index 336ebb9d..ad868a8d 100644 --- a/src/lm_saes/runner.py +++ b/src/lm_saes/runner.py @@ -147,11 +147,23 @@ def language_model_sae_eval_runner(cfg: LanguageModelSAERunnerConfig): hf_model = AutoModelForCausalLM.from_pretrained( cfg.model_name, cache_dir=cfg.cache_dir, local_files_only=cfg.local_files_only ) + + hf_tokenizer = AutoTokenizer.from_pretrained( + ( + cfg.model_name + if cfg.model_from_pretrained_path is None + else cfg.model_from_pretrained_path + ), + trust_remote_code=True, + use_fast=True, + add_bos_token=True, + ) model = HookedTransformer.from_pretrained( cfg.model_name, device=cfg.device, cache_dir=cfg.cache_dir, hf_model=hf_model, + tokenizer=hf_tokenizer, dtype=cfg.dtype, ) model.eval() @@ -206,38 +218,58 @@ def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig): cache_dir=cfg.cache_dir, local_files_only=cfg.local_files_only, ) + hf_tokenizer = AutoTokenizer.from_pretrained( + ( + cfg.model_name + if cfg.model_from_pretrained_path is None + else cfg.model_from_pretrained_path + ), + trust_remote_code=True, + use_fast=True, + add_bos_token=True, + ) model = HookedTransformer.from_pretrained( cfg.model_name, device=cfg.device, cache_dir=cfg.cache_dir, hf_model=hf_model, + tokenizer=hf_tokenizer, dtype=cfg.dtype, ) model.eval() client = MongoClient(cfg.mongo_uri, cfg.mongo_db) + client.remove_dictionary(cfg.exp_name, cfg.exp_series) client.create_dictionary(cfg.exp_name, cfg.d_sae, cfg.exp_series) for chunk_id in range(cfg.n_sae_chunks): activation_store = ActivationStore.from_config(model=model, cfg=cfg) - result = sample_feature_activations(sae, model, activation_store, cfg, chunk_id, cfg.n_sae_chunks) + result = sample_feature_activations( + sae, model, activation_store, cfg, chunk_id, cfg.n_sae_chunks + ) for i in range(len(result["index"].cpu().numpy().tolist())): - client.update_feature(cfg.exp_name, result["index"][i].item(), { - "act_times": result["act_times"][i].item(), - "max_feature_acts": result["max_feature_acts"][i].item(), - "feature_acts_all": result["feature_acts_all"][i] - .cpu() - .float() - .numpy(), # use .float() to convert bfloat16 to float32 - "analysis": [ - { - "name": v["name"], - "feature_acts": v["feature_acts"][i].cpu().float().numpy(), - "contexts": v["contexts"][i].cpu().numpy(), - } for v in result["analysis"] - ] - }, dictionary_series=cfg.exp_series) + client.update_feature( + cfg.exp_name, + result["index"][i].item(), + { + "act_times": result["act_times"][i].item(), + "max_feature_acts": result["max_feature_acts"][i].item(), + "feature_acts_all": result["feature_acts_all"][i] + .cpu() + .float() + .numpy(), # use .float() to convert bfloat16 to float32 + "analysis": [ + { + "name": v["name"], + "feature_acts": v["feature_acts"][i].cpu().float().numpy(), + "contexts": v["contexts"][i].cpu().numpy(), + } + for v in result["analysis"] + ], + }, + dictionary_series=cfg.exp_series, + ) del result del activation_store @@ -257,11 +289,22 @@ def features_to_logits_runner(cfg: FeaturesDecoderConfig): cache_dir=cfg.cache_dir, local_files_only=cfg.local_files_only, ) + hf_tokenizer = AutoTokenizer.from_pretrained( + ( + cfg.model_name + if cfg.model_from_pretrained_path is None + else cfg.model_from_pretrained_path + ), + trust_remote_code=True, + use_fast=True, + add_bos_token=True, + ) model = HookedTransformer.from_pretrained( cfg.model_name, device=cfg.device, cache_dir=cfg.cache_dir, hf_model=hf_model, + tokenizer=hf_tokenizer, dtype=cfg.dtype, ) model.eval() diff --git a/ui/bun.lockb b/ui/bun.lockb index 1230fbfe649c403ce3ce000c85837588791d6b5f..6ee4bfe920e3f0ad8446c937eaafc9ba7fefda40 100755 GIT binary patch delta 5628 zcmXxjX|x+<9mnyxDTKDr6j~aVmh^TZEl>taZNP=FD6t|2tr$@3bfFDP%d`{-$dcOy z1}e9q1<4h~NVP^>8pS2IZgC&C8uT1b&hdq#uktd7;}R-r@R{HDaZaA+^ZdV=Hj{hX z6Vcokqq&EQW0SAk6#i{&?78&cZ!x>CJ=C80M_ek}yV|q=j7wE}PrEyeOQ1c{o_jqm zHSK-v`IB*}YaeJY{3|XE?L+OwH{#ONKGI(LcU)T9lmCg&p7yr()PLhr*51*c9>vAi z9%|3L8JCLoZoEIA9gFv>_FlYCxnm-L2yzL<8uTGQL$MA6D4e3$fFTr5RcyitN=d~Q zBu~@iK^syt70b|p^ejamLdeWktUwpCrz=*W2kse)0Ys3SqgaDJY6fpm3pL1BOu4e}kzejG(kaceWt8Qj-U5NL{2@h7P1JR`el+%q5By z=t6dtVikJeUaA;C1UXl+27Sn1rdWpo6yBoPfFTrDD>h*SrOOptkX)n5gEpkzs#t~& zq}M9?5JD!WSb;8NuTZQ)58Nvi1Bf8^HpLqBA^It-xj4#frxq4-Y4CXAqTm0}B$ zc}*U)A$7H489I=@M$v~5GS@0rpbOdS6syn!_g#tsM38&8Vh#F`zh1Ep11J;}8!&|8 zdlZ{6g3>z079`)R$%8heZcr>k2h#6T^dW@IjfxfMLiYWNRp^0RR16@3+)au#=tKSk zigg%3;e(0|7(#KqViQJC`jBD^k~eGepbe=HE0&=H>5`%kA!Ke*tUwpCw<=bl2ku7{ z1Bf8^QN~-!u?!tZ z?@;t1gv>pP73f0NSFA!0+?|R6M3CF1Sc5*~?^UeB01CSm8!&|8eTq#OLFsOr|3fnnf;0t=tA~!#VYi`eL^vS2y%g94f>FOQn3yL zC_JUufFTqQC^lgPrB5ifAo;W=589BLR4hXW(gziN2q9BbtUwpC&nQ-*2ks%o03yg8 zR;)oE^3N*PVE~2a6dN#v;wKfGFoM$aiY-XiHF?m6)RbZwI*|U9q7NZtUQnz+7qXvL ztU?dm&nN~ELGDGx8uTImS;aaGpwLikzz~X`Q*6QrN-rt4AbCWS2W?1wUa<@vNPj`m zhY&JH6)Vt%>=zZQ&;z%r7(fKMFDcfb5BZlB>o9=AmlYc@gyL5en=pdXR~1{3{F){Y z+K@V?ScVRyTZ%q}kome|1-g*^hGG?Z;J%_5Km@sOD%PM6`EM!KVE~11D>h&V#qTIK zVFabNVhfVfnmlMj>bPPVI*|UZq7NZtzNc7$E@Z#2ScM+AKTr%Hg4_=kYtV;$N3jkA zDEvsV0YfPMSg{EsDE&mS1<9Xk@}Lc=pDC811L>bD`Vd0q7m5|=LN-*aLJ!4ahnlE2gBK^szC#WHjt{i>o5A!L59Sb@WT z99us<{)e&NvBYcprcPE9>n6N})toO^ahxr-CC$9nw%ctx&2c82Yi-%PFO+sulus=Y5`&M(-%Z^AV5au%Dl$F{xpzH!sG(I%XIHZC!7 zyS;J0ZA(p?G#?y4z~it^+yquM0ZHa@qp4N?ybCEf3un#zC+r_5sG%fy?2W`8=oDZ1?TxzV@w#uA$ zn-xFkGqzo7&i*xKiyz>SjjlOwH7kCA!?s;!&b!Qq#2@yoY3AjO$3G%-?V4V)Ke2em z*u3ehg2dIyIq^5%WLErzwoE-XnYd!=_+(=F)L$nPw~lR`y8d8d#hjhyy&Y!xW^J9? zdoZzV)=Tljn`_n7F*C-$-|1HmCJxS@wdn|N_|w~tB*MhZG4D!u^;&n$^qQlI?T^kr gJ|5reI474*zdSx;Ph!?e{bBRlnqIwZ#=;Z-1I1@AD*ylh delta 5698 zcmXxjYml2&9mnx!QwVLLX=!O_X$jjdq+CjXLW7qY#7nFQK>-8yl3wTrO3TyI(pV6) zwS|XU8(K=Uf*7gRcx@ELEGowPB&gcBUlWops zwl8*PpX$y&oEv%V!j@Z4dy<_*u)QAY63+aer1$vN~rdWkOWY18n!2oh+D%N2L`KV$8 zqGxF;Koerq6^qb<_zXoKLP*S1EI}KRXDgPW1F3Tq1L#6}mSP2ZkclZ)p%2;FiZvKO zZjNFdhLAs3u>sM!nhMZ_*gVA|v><+-q7NY?<|~$<4avA-89IMqXhJNZScDeDFH-a&gv6*~3EGfcqF9Cwq?RfM z(1rAvVg-7Txmd9ZeaK#-Sc3uN^xt5z4nxQ<(>ogwy;M^Hnh?88u?Q`QU#{pw2#G5c zOVEboa>X)qAa$i;09{C@6f4k!%vFk2=tK5u#TpDCca35lhLB&O*nsG@nhMZ_*ma6U zXhD3Xq7NY?(uyT$L-KmXGISvIHpKwCkbb*j1$vNqhhi1_kbS3O4F-@~rC5g{@^M-lABB4y1C50dygKt6~Lu zkomA;75b34nxS_uGoO+9hwTzgxE(Fi_n62UeSjT5_c+=pbg2p6wAELcXBbfanHI1!zKSqhb+S5Z|QeLkNk@ ziX~`6@@~a4bRe}wF@P?l#}zBkgG^De3Vq1lqgaChTA-z+v0zJq)pjd@IWOpgnU;w$@igg%5{z1hC zME7VaKoepm#UivI{&7VgLP+dYEI}KR4=I+R1F44<1L#8f6N(k+LFN&~D)b@yNyQor zAXiqb!w~WliVcYF(^P;a#6G21gcij2EBX*Z;(%fa+K_xyu?!tZJ*F5y7t(=Z1$vNq zT(Js$$UdQ1g8}58RII}g@=qx?Ao{eX0yH6ZP_YOth#yk)A%sLlu>@^MKBHKM4x|n% z2GE7{5ycAhAoHwZ75b1ps#t>o`?-6$9u(`b&xx=t1VoidE=C_A81t7(ni;igg%5{%eX2 zh<;sD0h$mysaS*-#2bn}gpl}#VhP%i{H9_VI*@uBhn+>aFNFogV%6&n!!iKYTHA@)GesXlNc>!}1Z_x$ie>0P z>KBRubRqpq#R~Kw^DD(F^db9e#TpDC_Z!7J3?Y9?u>sNFYAQezVr|7Dv>^UFMIS;) zyslU>Odk2e$n8_3zaQzGj4U`X`DQt?dSXq)^F}<+TV$4f{y&@SB~(0bt$WWjGoL(z zxXv{@-sakR*UmKUulCL6DYzCj=M^fRx52fuJa58V>EcEgrpDD)n?gen`?8;`A_bHx4Smaw61IS+h#uBJI}0TF8c0`^GyrfhwN}I zZrTac?4@?PcD^~!ckg?^wFRc#VV1qrF4q>CbKWfb>~`$}bDsFK`Cxmi2VJy(2RFIP z>~U?eX*au8GR?fa3(fkYS@zlM-e><_o-~)Smw(7K`%PYC)>^ad^N4BouA^o>&%Za} z?K9E5yd`Gsb04zb1bbX+*0^g2TpKfOn`@7{cCl#(&A)6f@R)0tm~+Loz_p}l>&>#y z<7(#nUuK4aS@xz+m}p+!rDkn#?J4)emzjB!S@wC_z3*~!-fWir*+Z^fVb1H_1)p(0 zyWF%LrY$iZcI`@go-jAGw?E=y%A9wZWpDbdYgd`GZ_FTInGkV?V?yV1x?yx_4n~rZ(<76ugON3}?3ZJ|jE!b(o?Lh+vTXiVbKGIpcC+l&>{Z4mZ#fhh zoB6u?q;GUKi>!YflFW78I$`X4TPGWGxf