From 81c84aa0eefde5a7273aab056f008f76186473f1 Mon Sep 17 00:00:00 2001
From: Frankstein <20307140057@fudan.edu.cn>
Date: Thu, 13 Jun 2024 21:36:48 +0800
Subject: [PATCH] feat(runner): load tokenizer manually

---
 src/lm_saes/runner.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py
index c13eb9a..e5f0224 100644
--- a/src/lm_saes/runner.py
+++ b/src/lm_saes/runner.py
@@ -119,11 +119,22 @@ def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig):
         cache_dir=cfg.lm.cache_dir,
         local_files_only=cfg.lm.local_files_only,
     )
+    hf_tokenizer = AutoTokenizer.from_pretrained(
+        (
+            cfg.lm.model_name
+            if cfg.lm.model_from_pretrained_path is None
+            else cfg.lm.model_from_pretrained_path
+        ),
+        trust_remote_code=True,
+        use_fast=True,
+        add_bos_token=True,
+    )
     model = HookedTransformer.from_pretrained(
         cfg.lm.model_name,
         device=cfg.lm.device,
         cache_dir=cfg.lm.cache_dir,
         hf_model=hf_model,
+        tokenizer=hf_tokenizer,
         dtype=cfg.lm.dtype,
     )
     model.eval()
@@ -239,11 +250,22 @@ def activation_generation_runner(cfg: ActivationGenerationConfig):
         cache_dir=cfg.lm.cache_dir,
         local_files_only=cfg.lm.local_files_only,
     )
+    hf_tokenizer = AutoTokenizer.from_pretrained(
+        (
+            cfg.lm.model_name
+            if cfg.lm.model_from_pretrained_path is None
+            else cfg.lm.model_from_pretrained_path
+        ),
+        trust_remote_code=True,
+        use_fast=True,
+        add_bos_token=True,
+    )
     model = HookedTransformer.from_pretrained(
         cfg.lm.model_name,
         device=cfg.lm.device,
         cache_dir=cfg.lm.cache_dir,
         hf_model=hf_model,
+        tokenizer=hf_tokenizer,
         dtype=cfg.lm.dtype,
     )
     model.eval()
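
For context, the pattern this patch applies in both runners, shown as a minimal standalone sketch. It assumes the `transformers` and `transformer_lens` packages; `model_name` and `model_from_pretrained_path` are stand-ins for the `cfg.lm.*` fields used in `runner.py`, and `"gpt2"` is a placeholder checkpoint.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

model_name = "gpt2"               # stand-in for cfg.lm.model_name
model_from_pretrained_path = None  # stand-in for cfg.lm.model_from_pretrained_path

# Load the tokenizer from the local checkpoint path when one is given,
# falling back to the hub name otherwise, so the tokenizer always
# matches the weights backing hf_model.
hf_tokenizer = AutoTokenizer.from_pretrained(
    (
        model_name
        if model_from_pretrained_path is None
        else model_from_pretrained_path
    ),
    trust_remote_code=True,
    use_fast=True,
    add_bos_token=True,
)

hf_model = AutoModelForCausalLM.from_pretrained(
    model_name
    if model_from_pretrained_path is None
    else model_from_pretrained_path
)

# Passing the tokenizer explicitly stops HookedTransformer from
# resolving a tokenizer from model_name alone, which could mismatch
# a model loaded from a local path.
model = HookedTransformer.from_pretrained(
    model_name,
    hf_model=hf_model,
    tokenizer=hf_tokenizer,
)
model.eval()
```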