From 6b0ba6905a8c8698a7def067cb10b322da4e2ab7 Mon Sep 17 00:00:00 2001
From: Frankstein <20307140057@fudan.edu.cn>
Date: Thu, 18 Jul 2024 20:32:56 +0800
Subject: [PATCH] fix: fix bugs encountered during SAE initialization and
 next-token retrieval in a tensor parallel environment

---
 src/lm_saes/activation/activation_store.py |  3 +-
 src/lm_saes/sae.py                         | 79 +++++++++++++---------
 2 files changed, 50 insertions(+), 32 deletions(-)

diff --git a/src/lm_saes/activation/activation_store.py b/src/lm_saes/activation/activation_store.py
index 49317a2..820f003 100644
--- a/src/lm_saes/activation/activation_store.py
+++ b/src/lm_saes/activation/activation_store.py
@@ -113,7 +113,8 @@ def next_tokens(self, batch_size: int) -> torch.Tensor | None:
         if self.tp_size > 1:
             # TODO
             next_tokens = self.act_source.next_tokens(batch_size)
-            funcol.broadcast(next_tokens, src=0, group=self.device_mesh["tp"])
+            # funcol.broadcast(next_tokens, src=0, group=self.device_mesh["tp"])
+            dist.broadcast(next_tokens, src=0)
             return next_tokens
         else:
             return self.act_source.next_tokens(batch_size)
diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py
index d04ad44..a5be1ed 100644
--- a/src/lm_saes/sae.py
+++ b/src/lm_saes/sae.py
@@ -26,6 +26,7 @@
     distribute_tensor,
 )
 
+
 class SparseAutoEncoder(HookedRootModule):
     """Sparse AutoEncoder model.
 
@@ -78,7 +79,7 @@ def __init__(self, cfg: SAEConfig):
             dtype=cfg.dtype,
         )
         torch.nn.init.kaiming_uniform_(self.decoder.weight)
-        self.set_decoder_norm_to_fixed_norm()
+        self.set_decoder_norm_to_fixed_norm(during_init=True)
 
         self.train_base_parameters()
 
@@ -97,7 +98,7 @@ def initialize_parameters(self):
 
         torch.nn.init.kaiming_uniform_(self.decoder.weight)
         self.set_decoder_norm_to_fixed_norm(
-            self.cfg.init_decoder_norm, force_exact=True
+            self.cfg.init_decoder_norm, force_exact=True, during_init=True
        )
 
         if self.cfg.use_decoder_bias:
@@ -356,14 +357,7 @@ def compute_loss(
 
         # l_l1: (batch,)
         if self.cfg.sparsity_include_decoder_norm:
-            # if self.cfg.tp_size > 1:
-            #     decoder_norm = torch.norm(self.decoder.weight.to_local(), p=2, dim=0)
-            #     decoder_norm = DTensor.from_local(decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Shard(0)])
-            #     decoder_norm = (
-            #         decoder_norm.redistribute(placements=[Replicate()], async_op=True).to_local()
-            #     )
-            # else:
-            #     decoder_norm = torch.norm(self.decoder.weight, p=2, dim=0)
+
             l_l1 = torch.norm(
                 feature_acts_normed * self.decoder_norm(),
                 p=self.cfg.lp,
@@ -381,7 +375,9 @@
             and dead_feature_mask.sum() > 0
         ):
             # ghost protocol
-            assert self.cfg.tp_size == 1, "Ghost protocol not supported in tensor parallel training"
+            assert (
+                self.cfg.tp_size == 1
+            ), "Ghost protocol not supported in tensor parallel training"
             # 1.
             residual = label_normed - reconstructed_normed
             residual_centred = residual - residual.mean(dim=0, keepdim=True)
@@ -462,11 +458,14 @@ def update_l1_coefficient(self, training_step):
 
     @torch.no_grad()
     def set_decoder_norm_to_fixed_norm(
-        self, value: float | None = 1.0, force_exact: bool | None = None
+        self,
+        value: float | None = 1.0,
+        force_exact: bool | None = None,
+        during_init: bool = False,
     ):
         if value is None:
             return
-        decoder_norm = self.decoder_norm(keepdim=True)
+        decoder_norm = self.decoder_norm(keepdim=True, during_init=during_init)
         if force_exact is None:
             force_exact = self.cfg.decoder_exactly_fixed_norm
         if force_exact:
@@ -653,14 +652,16 @@ def from_initialization_searching(
 
         test_sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
 
-        assert self.cfg.tp_size == 1, "Search for initial decoder norm not supported in tensor parallel training"
-
         def grid_search_best_init_norm(search_range: List[float]) -> float:
             losses: Dict[float, float] = {}
-            
+
             for norm in search_range:
-                test_sae.set_decoder_norm_to_fixed_norm(norm, force_exact=True)
-                test_sae.encoder.weight.data = test_sae.decoder.weight.data.T.clone().contiguous()
+                test_sae.set_decoder_norm_to_fixed_norm(
+                    norm, force_exact=True, during_init=True
+                )
+                test_sae.encoder.weight.data = (
+                    test_sae.decoder.weight.data.T.clone().contiguous()
+                )
                 mse = test_sae.compute_loss(x=activation_in, label=activation_out)[1][0]["l_rec"].mean().item()  # type: ignore
                 losses[norm] = mse
             best_norm = min(losses, key=losses.get)  # type: ignore
@@ -681,7 +682,9 @@ def grid_search_best_init_norm(search_range: List[float]) -> float:
         test_sae.set_decoder_norm_to_fixed_norm(
             best_norm_fine_grained, force_exact=True
         )
-        test_sae.encoder.weight.data = test_sae.decoder.weight.data.T.clone().contiguous()
+        test_sae.encoder.weight.data = (
+            test_sae.decoder.weight.data.T.clone().contiguous()
+        )
 
         return test_sae
 
@@ -707,25 +710,39 @@ def save_pretrained(self, ckpt_path: str) -> None:
                 f"Invalid checkpoint path {ckpt_path}. Currently only supports .safetensors and .pt formats."
             )
 
-    def decoder_norm(self, keepdim: bool = False):
+    def decoder_norm(self, keepdim: bool = False, during_init: bool = False):
         # We suspect that using torch.norm on dtensor may lead to some bugs during the backward process that are difficult to pinpoint and resolve. Therefore, we first convert the decoder weight from dtensor to tensor for norm calculation, and then redistribute it to different nodes.
-        if self.cfg.tp_size == 1:
+        if self.cfg.tp_size == 1 or during_init:
             return torch.norm(self.decoder.weight, p=2, dim=0, keepdim=keepdim)
         else:
-            decoder_norm = torch.norm(self.decoder.weight.to_local(), p=2, dim=0, keepdim=keepdim)
-            decoder_norm = DTensor.from_local(decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Shard(int(keepdim))])
-            decoder_norm = (
-                decoder_norm.redistribute(placements=[Replicate()], async_op=True).to_local()
-            )
+            decoder_norm = torch.norm(
+                self.decoder.weight.to_local(), p=2, dim=0, keepdim=keepdim
+            )
+            decoder_norm = DTensor.from_local(
+                decoder_norm,
+                device_mesh=self.device_mesh["tp"],
+                placements=[Shard(int(keepdim))],
+            )
+            decoder_norm = decoder_norm.redistribute(
+                placements=[Replicate()], async_op=True
+            ).to_local()
             return decoder_norm
 
-    def encoder_norm(self, keepdim: bool = False):
-        if self.cfg.tp_size == 1:
+    def encoder_norm(
+        self,
+        keepdim: bool = False,
+        during_init: bool = False,
+    ):
+        if self.cfg.tp_size == 1 or during_init:
             return torch.norm(self.encoder.weight, p=2, dim=1, keepdim=keepdim)
         else:
-            encoder_norm = torch.norm(self.encoder.weight.to_local(), p=2, dim=1, keepdim=keepdim)
-            encoder_norm = DTensor.from_local(encoder_norm, device_mesh=self.device_mesh["tp"], placements=[Shard(0)])
-            encoder_norm = (
-                encoder_norm.redistribute(placements=[Replicate()], async_op=True).to_local()
+            encoder_norm = torch.norm(
+                self.encoder.weight.to_local(), p=2, dim=1, keepdim=keepdim
+            )
+            encoder_norm = DTensor.from_local(
+                encoder_norm, device_mesh=self.device_mesh["tp"], placements=[Shard(0)]
             )
+            encoder_norm = encoder_norm.redistribute(
+                placements=[Replicate()], async_op=True
+            ).to_local()
             return encoder_norm
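Note on the decoder_norm/encoder_norm hunks: the in-code comment describes the workaround of computing the norm on the local shard as a plain tensor and then wrapping and redistributing the result, and the during_init flag added throughout the patch simply routes calls through the plain-tensor branch, presumably because the weights have not yet been wrapped as DTensors at initialization time. A minimal, self-contained sketch of that round-trip follows; it is illustrative only (the helper name and device_mesh argument are assumptions, and depending on the torch version the DTensor API lives under torch.distributed._tensor or torch.distributed.tensor):

    import torch
    # Older torch versions expose DTensor under torch.distributed._tensor,
    # newer ones under torch.distributed.tensor; adjust the import as needed.
    from torch.distributed._tensor import DTensor, Replicate, Shard

    def sharded_column_norms(weight: DTensor, device_mesh) -> torch.Tensor:
        # Illustrative helper, not part of the patch.
        # 1) Take the norm of the local shard as a plain tensor, so torch.norm
        #    never runs through DTensor dispatch (the suspected source of the
        #    backward-pass issues mentioned in the comment above).
        local_norm = torch.norm(weight.to_local(), p=2, dim=0)
        # 2) Re-wrap the per-shard result as a DTensor sharded along dim 0 ...
        sharded = DTensor.from_local(
            local_norm, device_mesh=device_mesh, placements=[Shard(0)]
        )
        # 3) ... then replicate it so every rank ends up with the full vector.
        return sharded.redistribute(placements=[Replicate()]).to_local()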
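Note on the activation_store hunk: functional collectives return the communicated result instead of writing into their input, so the old funcol.broadcast call, whose return value was discarded, presumably left next_tokens unchanged on non-source ranks, whereas torch.distributed.broadcast operates in place. A minimal sketch of the in-place pattern the patch switches to (illustrative helper name; assumes a process group has already been initialized, e.g. via torch.distributed.init_process_group):

    import torch
    import torch.distributed as dist

    def broadcast_next_tokens(next_tokens: torch.Tensor) -> torch.Tensor:
        # Illustrative helper, not part of the patch: dist.broadcast writes the
        # source rank's data into `next_tokens` in place on every other rank,
        # so the same tensor can be returned and used directly afterwards.
        # Without a `group` argument it uses the default (global) process group.
        dist.broadcast(next_tokens, src=0)
        return next_tokens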