Skip to content

Commit

Permalink
fix(sae): do not init device mesh in single device mode
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfinfdu authored and Hzfinfdu committed Jul 31, 2024
1 parent 7785d67 commit 63f0a13
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/lm_saes/sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ def __init__(self, cfg: SAEConfig):
)
torch.nn.init.kaiming_uniform_(self.encoder.weight)
torch.nn.init.zeros_(self.encoder.bias)
self.device_mesh = init_device_mesh(
"cuda", (cfg.ddp_size, cfg.tp_size), mesh_dim_names=("ddp", "tp")
)
if cfg.tp_size > 1 or cfg.ddp_size > 1:
self.device_mesh = init_device_mesh(
"cuda", (cfg.ddp_size, cfg.tp_size), mesh_dim_names=("ddp", "tp")
)

if cfg.use_glu_encoder:

Expand Down

0 comments on commit 63f0a13

Please sign in to comment.