diff --git a/pdm.lock b/pdm.lock
index b13fe5f..83c73ec 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@
 groups = ["default", "dev"]
 strategy = ["cross_platform", "inherit_metadata"]
 lock_version = "4.4.1"
-content_hash = "sha256:5266c91187a20b13682380660c9795b4ea9c2f2c2ad5370e97ab83ec920ece84"
+content_hash = "sha256:4a00d257bb6f7996524921a49a47cd158a4c00501fdd3fa89589d4d0751fb434"
 
 [[package]]
 name = "accelerate"
@@ -351,7 +351,7 @@ name = "exceptiongroup"
 version = "1.2.1"
 requires_python = ">=3.7"
 summary = "Backport of PEP 654 (exception groups)"
-groups = ["default"]
+groups = ["default", "dev"]
 marker = "python_version < \"3.11\""
 files = [
     {file = "exceptiongroup-1.2.1-py3-none-any.whl", hash = "sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad"},
@@ -615,7 +615,7 @@ name = "iniconfig"
 version = "2.0.0"
 requires_python = ">=3.7"
 summary = "brain-dead simple config-ini parsing"
-groups = ["default"]
+groups = ["dev"]
 files = [
     {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
     {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
@@ -1231,7 +1231,7 @@ name = "pluggy"
 version = "1.5.0"
 requires_python = ">=3.8"
 summary = "plugin and hook calling mechanisms for python"
-groups = ["default"]
+groups = ["dev"]
 files = [
     {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
     {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"},
@@ -1406,7 +1406,7 @@ name = "pytest"
 version = "8.2.1"
 requires_python = ">=3.8"
 summary = "pytest: simple powerful testing with Python"
-groups = ["default"]
+groups = ["dev"]
 dependencies = [
     "colorama; sys_platform == \"win32\"",
     "exceptiongroup>=1.0.0rc8; python_version < \"3.11\"",
@@ -1861,7 +1861,7 @@ name = "tomli"
 version = "2.0.1"
 requires_python = ">=3.7"
 summary = "A lil' TOML parser"
-groups = ["default", "dev"]
+groups = ["dev"]
 marker = "python_version < \"3.11\""
 files = [
     {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
diff --git a/pyproject.toml b/pyproject.toml
index 608ce12..1b49966 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,7 +18,6 @@ dependencies = [
     "numpy>=1.26.4",
     "pandas>=2.2.1",
     "pymongo>=4.6.3",
-    "pytest>=8.0.1",
     "tensorboardX>=2.6.2.2",
     "torch>=2.2.0",
     "tqdm>=4.66.2",
@@ -41,6 +40,7 @@ license = {text = "MIT"}
 dev = [
     "-e file:///${PROJECT_ROOT}/TransformerLens#egg=transformer-lens",
     "mypy>=1.10.0",
+    "pytest>=8.0.1",
 ]
 
 [tool.mypy]
diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py
index 336ebb9..ad868a8 100644
--- a/src/lm_saes/runner.py
+++ b/src/lm_saes/runner.py
@@ -147,11 +147,23 @@ def language_model_sae_eval_runner(cfg: LanguageModelSAERunnerConfig):
     hf_model = AutoModelForCausalLM.from_pretrained(
         cfg.model_name, cache_dir=cfg.cache_dir, local_files_only=cfg.local_files_only
     )
+
+    hf_tokenizer = AutoTokenizer.from_pretrained(
+        (
+            cfg.model_name
+            if cfg.model_from_pretrained_path is None
+            else cfg.model_from_pretrained_path
+        ),
+        trust_remote_code=True,
+        use_fast=True,
+        add_bos_token=True,
+    )
     model = HookedTransformer.from_pretrained(
         cfg.model_name,
         device=cfg.device,
         cache_dir=cfg.cache_dir,
         hf_model=hf_model,
+        tokenizer=hf_tokenizer,
         dtype=cfg.dtype,
     )
     model.eval()
@@ -206,38 +218,58 @@ def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig):
         cache_dir=cfg.cache_dir,
         local_files_only=cfg.local_files_only,
     )
+    hf_tokenizer = AutoTokenizer.from_pretrained(
+        (
+            cfg.model_name
+            if cfg.model_from_pretrained_path is None
+            else cfg.model_from_pretrained_path
+        ),
+        trust_remote_code=True,
+        use_fast=True,
+        add_bos_token=True,
+    )
     model = HookedTransformer.from_pretrained(
         cfg.model_name,
         device=cfg.device,
         cache_dir=cfg.cache_dir,
         hf_model=hf_model,
+        tokenizer=hf_tokenizer,
         dtype=cfg.dtype,
     )
     model.eval()
 
     client = MongoClient(cfg.mongo_uri, cfg.mongo_db)
+    client.remove_dictionary(cfg.exp_name, cfg.exp_series)
     client.create_dictionary(cfg.exp_name, cfg.d_sae, cfg.exp_series)
 
     for chunk_id in range(cfg.n_sae_chunks):
         activation_store = ActivationStore.from_config(model=model, cfg=cfg)
-        result = sample_feature_activations(sae, model, activation_store, cfg, chunk_id, cfg.n_sae_chunks)
+        result = sample_feature_activations(
+            sae, model, activation_store, cfg, chunk_id, cfg.n_sae_chunks
+        )
 
         for i in range(len(result["index"].cpu().numpy().tolist())):
-            client.update_feature(cfg.exp_name, result["index"][i].item(), {
-                "act_times": result["act_times"][i].item(),
-                "max_feature_acts": result["max_feature_acts"][i].item(),
-                "feature_acts_all": result["feature_acts_all"][i]
-                .cpu()
-                .float()
-                .numpy(),  # use .float() to convert bfloat16 to float32
-                "analysis": [
-                    {
-                        "name": v["name"],
-                        "feature_acts": v["feature_acts"][i].cpu().float().numpy(),
-                        "contexts": v["contexts"][i].cpu().numpy(),
-                    } for v in result["analysis"]
-                ]
-            }, dictionary_series=cfg.exp_series)
+            client.update_feature(
+                cfg.exp_name,
+                result["index"][i].item(),
+                {
+                    "act_times": result["act_times"][i].item(),
+                    "max_feature_acts": result["max_feature_acts"][i].item(),
+                    "feature_acts_all": result["feature_acts_all"][i]
+                    .cpu()
+                    .float()
+                    .numpy(),  # use .float() to convert bfloat16 to float32
+                    "analysis": [
+                        {
+                            "name": v["name"],
+                            "feature_acts": v["feature_acts"][i].cpu().float().numpy(),
+                            "contexts": v["contexts"][i].cpu().numpy(),
+                        }
+                        for v in result["analysis"]
+                    ],
+                },
+                dictionary_series=cfg.exp_series,
+            )
 
         del result
         del activation_store
@@ -257,11 +289,22 @@ def features_to_logits_runner(cfg: FeaturesDecoderConfig):
         cache_dir=cfg.cache_dir,
         local_files_only=cfg.local_files_only,
     )
+    hf_tokenizer = AutoTokenizer.from_pretrained(
+        (
+            cfg.model_name
+            if cfg.model_from_pretrained_path is None
+            else cfg.model_from_pretrained_path
+        ),
+        trust_remote_code=True,
+        use_fast=True,
+        add_bos_token=True,
+    )
     model = HookedTransformer.from_pretrained(
         cfg.model_name,
         device=cfg.device,
         cache_dir=cfg.cache_dir,
         hf_model=hf_model,
+        tokenizer=hf_tokenizer,
         dtype=cfg.dtype,
     )
     model.eval()
diff --git a/ui/bun.lockb b/ui/bun.lockb
index 1230fbf..6ee4bfe 100755
Binary files a/ui/bun.lockb and b/ui/bun.lockb differ
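The three `runner.py` hunks repeat one pattern: construct the Hugging Face tokenizer explicitly (preferring `cfg.model_from_pretrained_path` when it is set) and pass it to `HookedTransformer.from_pretrained` alongside the already-loaded HF model. Below is a minimal, standalone sketch of that pattern, assuming `transformers` and `transformer_lens` are installed; `"gpt2"` and the `None` path are placeholders, not values from the project config.

```python
# Sketch of the tokenizer-loading pattern used in runner.py (placeholder values).
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

model_name = "gpt2"          # stands in for cfg.model_name
pretrained_path = None       # stands in for cfg.model_from_pretrained_path

# Load the underlying HF model once, so HookedTransformer can reuse its weights.
hf_model = AutoModelForCausalLM.from_pretrained(model_name)

# Build the tokenizer explicitly instead of letting HookedTransformer resolve it,
# falling back to the local pretrained path when one is configured.
hf_tokenizer = AutoTokenizer.from_pretrained(
    model_name if pretrained_path is None else pretrained_path,
    trust_remote_code=True,
    use_fast=True,
    add_bos_token=True,
)

# Hand both the HF model and the tokenizer to TransformerLens.
model = HookedTransformer.from_pretrained(
    model_name,
    hf_model=hf_model,
    tokenizer=hf_tokenizer,
)
model.eval()
```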