Commit

Merge pull request #42 from OpenMOSS/dev
fix(sae): transform decoder_norm and encoder_norm to dtensor under tensor parallel settings
Hzfinfdu authored Jul 31, 2024
2 parents 63f0a13 + 0c0cc2d commit 476076f
Showing 1 changed file with 34 additions and 7 deletions.
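
Under tensor parallelism the SAE weights are DTensors on a device mesh, while the norms are computed as plain local tensors, so the diff below re-wraps those norms as replicated DTensors before using them to rescale the weights. A minimal sketch of the underlying constraint, not code from this repository (the mesh layout, shapes, and script name are assumptions):

# Hedged sketch, not code from this repository: it only illustrates that a plain
# torch.Tensor cannot be mixed with a DTensor in elementwise ops, and the fix of
# wrapping the local tensor with a Replicate() placement on the same mesh.
# Launch with e.g.: torchrun --nproc_per_node=2 dtensor_norm_sketch.py
import torch
import torch.distributed as dist
from torch.distributed._tensor import Replicate, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh


def main() -> None:
    dist.init_process_group("gloo")  # "nccl" on GPU nodes
    mesh = init_device_mesh("cpu", (dist.get_world_size(),), mesh_dim_names=("tp",))

    # Stand-in decoder weight of (assumed) shape (d_sae, d_model) = (8, 4),
    # sharded over the feature dimension across TP ranks.
    torch.manual_seed(0)
    weight = distribute_tensor(torch.randn(8, 4), mesh, placements=[Shard(0)])

    # A per-feature norm computed as a plain local tensor (keepdim-style shape (8, 1)).
    norm = torch.full((8, 1), 2.0)

    # `weight / norm` would fail with a mixed torch.Tensor / DTensor error; wrapping
    # the norm as a replicated DTensor on the same mesh makes the division valid.
    norm_dt = distribute_tensor(norm, mesh, placements=[Replicate()])
    weight = weight / norm_dt

    if dist.get_rank() == 0:
        print(type(weight).__name__, tuple(weight.shape))
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Replicate() fits here because every rank needs the same full norm to rescale its local shard consistently; the during_init guards in the diff appear to skip the wrapping while the parameters are still plain tensors at initialization time.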
41 changes: 34 additions & 7 deletions src/lm_saes/sae.py
@@ -1,3 +1,4 @@
+from builtins import print
 from importlib.metadata import version
 import os
 from typing import Dict, Literal, Union, overload, List
@@ -117,7 +118,7 @@ def initialize_parameters(self):
         if self.cfg.init_encoder_with_decoder_transpose:
             self.encoder.weight.data = self.decoder.weight.data.T.clone().contiguous()
         else:
-            self.set_encoder_norm_to_fixed_norm(self.cfg.init_encoder_norm)
+            self.set_encoder_norm_to_fixed_norm(self.cfg.init_encoder_norm, during_init=True)

     def train_base_parameters(self):
         """Set the base parameters to be trained."""
@@ -481,6 +482,13 @@ def set_decoder_norm_to_fixed_norm(
         decoder_norm = self.decoder_norm(keepdim=True, during_init=during_init)
         if force_exact is None:
             force_exact = self.cfg.decoder_exactly_fixed_norm
+
+
+        if self.cfg.tp_size > 1 and not during_init:
+            decoder_norm = distribute_tensor(
+                decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
+            )
+
         if force_exact:
             self.decoder.weight.data = self.decoder.weight.data * value / decoder_norm
         else:
@@ -490,15 +498,19 @@ def set_decoder_norm_to_fixed_norm(
             )

     @torch.no_grad()
-    def set_encoder_norm_to_fixed_norm(self, value: float | None = 1.0):
+    def set_encoder_norm_to_fixed_norm(self, value: float | None = 1.0, during_init: bool = False):
         if self.cfg.use_glu_encoder:
             raise NotImplementedError("GLU encoder not supported")
         if value is None:
             print(
                 f"Encoder norm is not set to a fixed value, using random initialization."
             )
             return
-        encoder_norm = self.encoder_norm(keepdim=True)
+        encoder_norm = self.encoder_norm(keepdim=True, during_init=during_init)
+        if self.cfg.tp_size > 1 and not during_init:
+            encoder_norm = distribute_tensor(
+                encoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
+            )
         self.encoder.weight.data = self.encoder.weight.data * value / encoder_norm

     @torch.no_grad()
@@ -515,10 +527,25 @@ def transform_to_unit_decoder_norm(self):
             raise NotImplementedError("GLU encoder not supported")

         decoder_norm = self.decoder_norm()  # (d_sae,)
-        self.encoder.weight.data = self.encoder.weight.data * decoder_norm[:, None]
-        self.decoder.weight.data = self.decoder.weight.data / decoder_norm
-
-        self.encoder.bias.data = self.encoder.bias.data * decoder_norm
+        if self.cfg.tp_size > 1:
+            decoder_norm_en = distribute_tensor(
+                decoder_norm[:, None], device_mesh=self.device_mesh["tp"], placements=[Replicate()]
+            )
+            decoder_norm_de = distribute_tensor(
+                decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
+            )
+            dencoder_norm_bias = distribute_tensor(
+                decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
+            )
+        else:
+            decoder_norm_en = decoder_norm[:, None]
+            decoder_norm_de = decoder_norm
+            dencoder_norm_bias = decoder_norm
+
+        self.encoder.weight.data = self.encoder.weight.data * decoder_norm_en
+        self.decoder.weight.data = self.decoder.weight.data / decoder_norm_de
+
+        self.encoder.bias.data = self.encoder.bias.data * dencoder_norm_bias

     @torch.no_grad()
     def remove_gradient_parallel_to_decoder_directions(self):
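For context, a hedged sketch of where a plain (non-DTensor) norm could come from in the first place; the helper name, weight layout, and during_init behaviour are assumptions rather than this repository's actual decoder_norm()/encoder_norm() implementation:

# Hedged sketch (assumptions, not the repository's code): if the norm is taken over
# the gathered full tensor, it comes back as a regular local torch.Tensor even when
# the underlying weight is a DTensor, which is why the methods in the diff above
# re-distribute it before dividing DTensor weights by it.
import torch
from torch.distributed._tensor import DTensor


def weight_norm(weight: torch.Tensor, keepdim: bool = False, during_init: bool = False) -> torch.Tensor:
    """Per-feature L2 norm of a weight assumed to have shape (d_model, d_sae)."""
    if isinstance(weight, DTensor) and not during_init:
        # full_tensor() gathers the sharded DTensor and returns a plain torch.Tensor.
        return torch.norm(weight.full_tensor(), p=2, dim=0, keepdim=keepdim)
    # During initialization (or without TP) the weight is a plain tensor already.
    return torch.norm(weight, p=2, dim=0, keepdim=keepdim)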
