From 7f9e69ec600036710529ba28b8b9a6210e1b184e Mon Sep 17 00:00:00 2001
From: Frankstein <20307140057@fudan.edu.cn>
Date: Mon, 22 Jul 2024 17:19:55 +0800
Subject: [PATCH] fix: typo

---
 src/lm_saes/activation/activation_store.py | 6 +++---
 src/lm_saes/sae_training.py                | 5 ++---
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/lm_saes/activation/activation_store.py b/src/lm_saes/activation/activation_store.py
index 820f003..2105021 100644
--- a/src/lm_saes/activation/activation_store.py
+++ b/src/lm_saes/activation/activation_store.py
@@ -24,7 +24,7 @@ def __init__(self, act_source: ActivationSource, cfg: ActivationStoreConfig):
         self.act_source = act_source
         self.buffer_size = cfg.n_tokens_in_buffer
         self.device = cfg.device
-        self.ddp_size = cfg.ddp_size # 1 8
+        self.ddp_size = cfg.ddp_size
         self.tp_size = cfg.tp_size
         self._store: Dict[str, torch.Tensor] = {}
         self._all_gather_buffer: Dict[str, torch.Tensor] = {}
@@ -111,9 +111,9 @@ def next(self, batch_size) -> Dict[str, torch.Tensor] | None:

     def next_tokens(self, batch_size: int) -> torch.Tensor | None:
         if self.tp_size > 1:
-            # TODO
+            # TODO: only get next tokens from the root process
             next_tokens = self.act_source.next_tokens(batch_size)
-            # funcol.broadcast(next_tokens, src=0, group=self.device_mesh["tp"])
+            # funcol.broadcast does not work and we don't know why
             dist.broadcast(next_tokens, src=0)
             return next_tokens
         else:
diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py
index 312148c..3761997 100644
--- a/src/lm_saes/sae_training.py
+++ b/src/lm_saes/sae_training.py
@@ -85,10 +85,9 @@ def train_sae(
         sae.parallelize_plan = plan

     elif cfg.sae.ddp_size > 1:
+        # parallelize_module does not work with DDP
         _ = DDP(sae, device_mesh=sae.device_mesh["ddp"])
-        # sae = parallelize_module(
-        #     sae, device_mesh=sae.device_mesh["ddp"], parallelize_plan={}
-        # )
+
     optimizer = Adam(sae.parameters(), lr=cfg.lr, betas=cfg.betas)
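
Note on the TP branch of next_tokens: the TODO intends for only the root process to fetch tokens, with the other tensor-parallel ranks receiving them through dist.broadcast, which copies rank 0's buffer into every rank's tensor in place. A minimal sketch of that pattern follows; the helper name, the act_source argument, and the fixed (batch_size, seq_len) token shape are illustrative assumptions, not the repository's actual API:

    import torch
    import torch.distributed as dist

    def next_tokens_root_only(act_source, batch_size, seq_len, device):
        # Sketch of the TODO: only rank 0 pulls tokens from the source.
        if dist.get_rank() == 0:
            tokens = act_source.next_tokens(batch_size).to(device)
        else:
            # Other ranks allocate a same-shaped buffer to receive the broadcast.
            tokens = torch.empty(batch_size, seq_len, dtype=torch.long, device=device)
        # dist.broadcast is in-place: rank 0's data ends up in every rank's tensor.
        dist.broadcast(tokens, src=0)
        return tokens

This assumes the process group is already initialized and that every rank knows the token shape in advance; the current code instead calls act_source.next_tokens on all ranks and lets the broadcast overwrite the non-root results.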
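
Note on the DDP branch of train_sae: DDP is handed a 1-D sub-mesh sliced from the SAE's device mesh, and gradients are all-reduced across that mesh during backward. A standalone sketch of the same wiring, assuming an 8-GPU torchrun job split into a 4-way "ddp" by 2-way "tp" mesh and a stand-in nn.Linear model (both are assumptions for illustration, not the repository's setup):

    import os
    import torch
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.nn.parallel import DistributedDataParallel as DDP

    # Assumes the script is launched with torchrun on 8 GPUs (4-way DDP x 2-way TP).
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("ddp", "tp"))
    model = nn.Linear(512, 512).cuda()  # stand-in for the SAE module
    # Slicing the 2-D mesh by name gives the 1-D replication group that DDP uses;
    # gradients are all-reduced across the "ddp" dimension during backward.
    model = DDP(model, device_mesh=mesh["ddp"])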