Skip to content

Commit

Permalink
feat(runner): load tokenizer manually
Browse files (browse the repository at this point in the history)
  • Loading branch information
Frankstein73 committed Jun 13, 2024
1 parent 3aceba0 commit 81c84aa
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/lm_saes/runner.py
Original file line number | Diff line number | Diff line change
Expand Up
@@ -119,11 +119,22 @@ def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig):
cache_dir=cfg.lm.cache_dir,
local_files_only=cfg.lm.local_files_only,
)
hf_tokenizer = AutoTokenizer.from_pretrained(
(
cfg.lm.model_name
if cfg.lm.model_from_pretrained_path is None
else cfg.lm.model_from_pretrained_path
),
trust_remote_code=True,
use_fast=True,
add_bos_token=True,
)
model = HookedTransformer.from_pretrained(
cfg.lm.model_name,
device=cfg.lm.device,
cache_dir=cfg.lm.cache_dir,
hf_model=hf_model,
tokenizer=hf_tokenizer,
dtype=cfg.lm.dtype,
)
model.eval()
Expand Down
Expand Up
@@ -239,11 +250,22 @@ def activation_generation_runner(cfg: ActivationGenerationConfig):
cache_dir=cfg.lm.cache_dir,
local_files_only=cfg.lm.local_files_only,
)
hf_tokenizer = AutoTokenizer.from_pretrained(
(
cfg.lm.model_name
if cfg.lm.model_from_pretrained_path is None
else cfg.lm.model_from_pretrained_path
),
trust_remote_code=True,
use_fast=True,
add_bos_token=True,
)
model = HookedTransformer.from_pretrained(
cfg.lm.model_name,
device=cfg.lm.device,
cache_dir=cfg.lm.cache_dir,
hf_model=hf_model,
tokenizer=hf_tokenizer,
dtype=cfg.lm.dtype,
)
model.eval()
Expand Down

0 comments on commit 81c84aa

Please sign in to comment.