Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sae): transform decoder_norm and encoder_norm to dtensor under tensor parallel settings #42

Merged
merged 3 commits into from
Jul 31, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions src/lm_saes/sae.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from builtins import print
from importlib.metadata import version
import os
from typing import Dict, Literal, Union, overload, List
Expand Down Expand Up @@ -117,7 +118,7 @@ def initialize_parameters(self):
if self.cfg.init_encoder_with_decoder_transpose:
self.encoder.weight.data = self.decoder.weight.data.T.clone().contiguous()
else:
self.set_encoder_norm_to_fixed_norm(self.cfg.init_encoder_norm)
self.set_encoder_norm_to_fixed_norm(self.cfg.init_encoder_norm, during_init=True)

def train_base_parameters(self):
"""Set the base parameters to be trained."""
Expand Down Expand Up @@ -481,6 +482,13 @@ def set_decoder_norm_to_fixed_norm(
decoder_norm = self.decoder_norm(keepdim=True, during_init=during_init)
if force_exact is None:
force_exact = self.cfg.decoder_exactly_fixed_norm


if self.cfg.tp_size > 1 and not during_init:
decoder_norm = distribute_tensor(
decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
)

if force_exact:
self.decoder.weight.data = self.decoder.weight.data * value / decoder_norm
else:
Expand All @@ -490,15 +498,19 @@ def set_decoder_norm_to_fixed_norm(
)

@torch.no_grad()
def set_encoder_norm_to_fixed_norm(self, value: float | None = 1.0):
def set_encoder_norm_to_fixed_norm(self, value: float | None = 1.0, during_init: bool = False):
if self.cfg.use_glu_encoder:
raise NotImplementedError("GLU encoder not supported")
if value is None:
print(
f"Encoder norm is not set to a fixed value, using random initialization."
)
return
encoder_norm = self.encoder_norm(keepdim=True)
encoder_norm = self.encoder_norm(keepdim=True, during_init=during_init)
if self.cfg.tp_size > 1 and not during_init:
encoder_norm = distribute_tensor(
encoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
)
self.encoder.weight.data = self.encoder.weight.data * value / encoder_norm

@torch.no_grad()
Expand All @@ -515,10 +527,25 @@ def transform_to_unit_decoder_norm(self):
raise NotImplementedError("GLU encoder not supported")

decoder_norm = self.decoder_norm() # (d_sae,)
self.encoder.weight.data = self.encoder.weight.data * decoder_norm[:, None]
self.decoder.weight.data = self.decoder.weight.data / decoder_norm

self.encoder.bias.data = self.encoder.bias.data * decoder_norm
if self.cfg.tp_size > 1:
decoder_norm_en = distribute_tensor(
decoder_norm[:, None], device_mesh=self.device_mesh["tp"], placements=[Replicate()]
)
decoder_norm_de = distribute_tensor(
decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
)
dencoder_norm_bias = distribute_tensor(
decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
)
else:
decoder_norm_en = decoder_norm[:, None]
decoder_norm_de = decoder_norm
dencoder_norm_bias = decoder_norm

self.encoder.weight.data = self.encoder.weight.data * decoder_norm_en
self.decoder.weight.data = self.decoder.weight.data / decoder_norm_de

self.encoder.bias.data = self.encoder.bias.data * dencoder_norm_bias

@torch.no_grad()
def remove_gradient_parallel_to_decoder_directions(self):
Expand Down