From 556f6cdc6e34e0fbf45824413ae954c804b4fe59 Mon Sep 17 00:00:00 2001 From: Frankstein <20307140057@fudan.edu.cn> Date: Sat, 20 Jul 2024 13:04:03 +0800 Subject: [PATCH] feat(sae): Implement ckpt saving in tensor parallel environment. --- src/lm_saes/sae.py | 50 ++++++++++++++++++++++++++----------- src/lm_saes/sae_training.py | 3 +-- 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index 0e0ae40..edca8e6 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -25,6 +25,14 @@ distribute_module, distribute_tensor, ) +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + RowwiseParallel, + parallelize_module, + loss_parallel, +) + +from lm_saes.utils.misc import is_master class SparseAutoEncoder(HookedRootModule): @@ -253,7 +261,11 @@ def encode( label = x if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder: - x = x - self.decoder.bias.to_local() if self.cfg.tp_size > 1 else x - self.decoder.bias + x = ( + x - self.decoder.bias.to_local() + if self.cfg.tp_size > 1 + else x - self.decoder.bias + ) x = x * self.compute_norm_factor(x, hook_point="in") @@ -688,27 +700,37 @@ def grid_search_best_init_norm(search_range: List[float]) -> float: return test_sae + def get_full_state_dict(self) -> dict: + state_dict = self.state_dict() + if self.cfg.tp_size > 1: + state_dict = { + k: v.full_tensor() if isinstance(v, DTensor) else v + for k, v in state_dict.items() + } + return state_dict + def save_pretrained(self, ckpt_path: str) -> None: """Save the model to the checkpoint path. Args: ckpt_path (str): The path to save the model. If a directory, the model will be saved to the directory with the default filename `sae_weights.safetensors`. """ - if os.path.isdir(ckpt_path): ckpt_path = os.path.join(ckpt_path, "sae_weights.safetensors") - if ckpt_path.endswith(".safetensors"): - safe.save_file( - self.state_dict(), ckpt_path, {"version": version("lm-saes")} - ) - elif ckpt_path.endswith(".pt"): - torch.save( - {"sae": self.state_dict(), "version": version("lm-saes")}, ckpt_path - ) - else: - raise ValueError( - f"Invalid checkpoint path {ckpt_path}. Currently only supports .safetensors and .pt formats." - ) + state_dict = self.get_full_state_dict() + if is_master(): + if ckpt_path.endswith(".safetensors"): + safe.save_file( + state_dict, ckpt_path, {"version": version("lm-saes")} + ) + elif ckpt_path.endswith(".pt"): + torch.save( + {"sae": state_dict, "version": version("lm-saes")}, ckpt_path + ) + else: + raise ValueError( + f"Invalid checkpoint path {ckpt_path}. Currently only supports .safetensors and .pt formats." + ) def decoder_norm(self, keepdim: bool = False, during_init: bool = False): # We suspect that using torch.norm on dtensor may lead to some bugs during the backward process that are difficult to pinpoint and resolve. Therefore, we first convert the decoder weight from dtensor to tensor for norm calculation, and then redistribute it to different nodes. diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py index 7982be7..3d990c8 100644 --- a/src/lm_saes/sae_training.py +++ b/src/lm_saes/sae_training.py @@ -82,6 +82,7 @@ def train_sae( sae = parallelize_module( sae, device_mesh=sae.device_mesh["tp"], parallelize_plan=plan ) + sae.parallelize_plan = plan elif cfg.sae.ddp_size > 1: _ = DDP(sae, device_mesh=sae.device_mesh["ddp"]) @@ -315,7 +316,6 @@ def train_sae( if ( len(checkpoint_thresholds) > 0 and n_training_tokens >= checkpoint_thresholds[0] - and is_master() ): # Save the model and optimizer state path = os.path.join( @@ -327,7 +327,6 @@ def train_sae( if not cfg.sae.sparsity_include_decoder_norm: sae.set_decoder_norm_to_fixed_norm(1) sae.save_pretrained(path) - checkpoint_thresholds.pop(0) n_training_steps += 1