Skip to content

Commit

Permalink
fix: resolve DDP-related synchronization bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Frankstein73 committed Jun 26, 2024
1 parent 74ef9dd commit c73c64a
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/lm_saes/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@
from lm_saes.sae_training import prune_sae, train_sae
from lm_saes.analysis.sample_feature_activations import sample_feature_activations
from lm_saes.analysis.features_to_logits import features_to_logits

from torch.nn.parallel import DistributedDataParallel as DDP

def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
if (not cfg.use_ddp) or cfg.rank == 0:
cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
sae = SparseAutoEncoder.from_config(cfg=cfg.sae)

if cfg.finetuning:
Expand Down Expand Up @@ -68,7 +69,9 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
tokenizer=hf_tokenizer,
dtype=cfg.lm.dtype,
)

if cfg.use_ddp:
_ = DDP(model, device_ids=[cfg.rank])
_ = DDP(sae, device_ids=[cfg.rank])
model.eval()
activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)

Expand Down

0 comments on commit c73c64a

Please sign in to comment.