
Commit

fix: fix some bugs encountered during the initialization of SAE and the retrieval of next token in a tensor parallel environment.

Frankstein73 committed Jul 18, 2024
1 parent d555c98 commit 6b0ba69
Showing 2 changed files with 50 additions and 32 deletions.
3 changes: 2 additions & 1 deletion src/lm_saes/activation/activation_store.py
@@ -113,7 +113,8 @@ def next_tokens(self, batch_size: int) -> torch.Tensor | None:
if self.tp_size > 1:
# TODO
next_tokens = self.act_source.next_tokens(batch_size)
funcol.broadcast(next_tokens, src=0, group=self.device_mesh["tp"])
# funcol.broadcast(next_tokens, src=0, group=self.device_mesh["tp"])
dist.broadcast(next_tokens, src=0)
return next_tokens
else:
return self.act_source.next_tokens(batch_size)
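A note on the hunk above: torch.distributed.broadcast is an in-place collective, so after the call every rank holds rank 0's tensor, whereas the functional collective funcol.broadcast leaves its input unchanged and returns the broadcast result as a new tensor. Discarding that return value, as the replaced line did, plausibly explains why non-source ranks never received the tokens. A minimal sketch of the two call patterns (the sync_next_tokens helper and the group handle are illustrative, not from this repository):

import torch
import torch.distributed as dist

def sync_next_tokens(next_tokens: torch.Tensor, group) -> torch.Tensor:
    # In-place collective: next_tokens on every rank now matches rank 0's copy.
    dist.broadcast(next_tokens, src=0)
    # The functional variant only carries the result in its return value,
    # so the assignment would be required for it to have any effect:
    #   next_tokens = funcol.broadcast(next_tokens, src=0, group=group)
    return next_tokens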
79 changes: 48 additions & 31 deletions src/lm_saes/sae.py
@@ -26,6 +26,7 @@
distribute_tensor,
)


class SparseAutoEncoder(HookedRootModule):
"""Sparse AutoEncoder model.
@@ -78,7 +79,7 @@ def __init__(self, cfg: SAEConfig):
dtype=cfg.dtype,
)
torch.nn.init.kaiming_uniform_(self.decoder.weight)
self.set_decoder_norm_to_fixed_norm()
self.set_decoder_norm_to_fixed_norm(during_init=True)

self.train_base_parameters()

@@ -97,7 +98,7 @@ def initialize_parameters(self):

torch.nn.init.kaiming_uniform_(self.decoder.weight)
self.set_decoder_norm_to_fixed_norm(
self.cfg.init_decoder_norm, force_exact=True
self.cfg.init_decoder_norm, force_exact=True, during_init=True
)

if self.cfg.use_decoder_bias:
@@ -356,14 +357,7 @@ def compute_loss(

# l_l1: (batch,)
if self.cfg.sparsity_include_decoder_norm:
# if self.cfg.tp_size > 1:
# decoder_norm = torch.norm(self.decoder.weight.to_local(), p=2, dim=0)
# decoder_norm = DTensor.from_local(decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Shard(0)])
# decoder_norm = (
# decoder_norm.redistribute(placements=[Replicate()], async_op=True).to_local()
# )
# else:
# decoder_norm = torch.norm(self.decoder.weight, p=2, dim=0)

l_l1 = torch.norm(
feature_acts_normed * self.decoder_norm(),
p=self.cfg.lp,
@@ -381,7 +375,9 @@
and dead_feature_mask.sum() > 0
):
# ghost protocol
assert self.cfg.tp_size == 1, "Ghost protocol not supported in tensor parallel training"
assert (
self.cfg.tp_size == 1
), "Ghost protocol not supported in tensor parallel training"
# 1.
residual = label_normed - reconstructed_normed
residual_centred = residual - residual.mean(dim=0, keepdim=True)
@@ -462,11 +458,14 @@ def update_l1_coefficient(self, training_step):

@torch.no_grad()
def set_decoder_norm_to_fixed_norm(
self, value: float | None = 1.0, force_exact: bool | None = None
self,
value: float | None = 1.0,
force_exact: bool | None = None,
during_init: bool = False,
):
if value is None:
return
decoder_norm = self.decoder_norm(keepdim=True)
decoder_norm = self.decoder_norm(keepdim=True, during_init=during_init)
if force_exact is None:
force_exact = self.cfg.decoder_exactly_fixed_norm
if force_exact:
@@ -653,14 +652,16 @@ def from_initialization_searching(

test_sae = SparseAutoEncoder.from_config(cfg=cfg.sae)

assert self.cfg.tp_size == 1, "Search for initial decoder norm not supported in tensor parallel training"

def grid_search_best_init_norm(search_range: List[float]) -> float:
losses: Dict[float, float] = {}

for norm in search_range:
test_sae.set_decoder_norm_to_fixed_norm(norm, force_exact=True)
test_sae.encoder.weight.data = test_sae.decoder.weight.data.T.clone().contiguous()
test_sae.set_decoder_norm_to_fixed_norm(
norm, force_exact=True, during_init=True
)
test_sae.encoder.weight.data = (
test_sae.decoder.weight.data.T.clone().contiguous()
)
mse = test_sae.compute_loss(x=activation_in, label=activation_out)[1][0]["l_rec"].mean().item() # type: ignore
losses[norm] = mse
best_norm = min(losses, key=losses.get) # type: ignore
@@ -681,7 +682,9 @@ def grid_search_best_init_norm(search_range: List[float]) -> float:
test_sae.set_decoder_norm_to_fixed_norm(
best_norm_fine_grained, force_exact=True
)
test_sae.encoder.weight.data = test_sae.decoder.weight.data.T.clone().contiguous()
test_sae.encoder.weight.data = (
test_sae.decoder.weight.data.T.clone().contiguous()
)

return test_sae
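For context, grid_search_best_init_norm rescales the decoder columns to each candidate norm, ties the encoder weight to the decoder transpose, and keeps the norm with the lowest reconstruction loss; passing during_init=True lets this run before the weights are distributed. A rough sketch of the rescale-and-tie step under assumed shapes (d_model, d_sae and rescale_decoder_columns are illustrative names, not from the repository):

import torch

d_model, d_sae = 768, 4096                        # hypothetical sizes
decoder = torch.nn.Linear(d_sae, d_model, bias=False)
torch.nn.init.kaiming_uniform_(decoder.weight)

def rescale_decoder_columns(norm: float) -> None:
    # Scale every decoder column (one column per SAE feature) to the candidate L2 norm.
    col_norms = decoder.weight.data.norm(p=2, dim=0, keepdim=True)
    decoder.weight.data = decoder.weight.data * norm / col_norms

rescale_decoder_columns(0.5)
# Tied initialisation: W_enc = W_dec^T, matching the .T.clone().contiguous() above.
encoder_weight = decoder.weight.data.T.clone().contiguous()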

@@ -707,25 +710,39 @@ def save_pretrained(self, ckpt_path: str) -> None:
f"Invalid checkpoint path {ckpt_path}. Currently only supports .safetensors and .pt formats."
)

def decoder_norm(self, keepdim: bool = False):
def decoder_norm(self, keepdim: bool = False, during_init: bool = False):
# We suspect that using torch.norm on dtensor may lead to some bugs during the backward process that are difficult to pinpoint and resolve. Therefore, we first convert the decoder weight from dtensor to tensor for norm calculation, and then redistribute it to different nodes.
if self.cfg.tp_size == 1:
if self.cfg.tp_size == 1 or during_init:
return torch.norm(self.decoder.weight, p=2, dim=0, keepdim=keepdim)
else:
decoder_norm = torch.norm(self.decoder.weight.to_local(), p=2, dim=0, keepdim=keepdim)
decoder_norm = DTensor.from_local(decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Shard(int(keepdim))])
decoder_norm = (
decoder_norm.redistribute(placements=[Replicate()], async_op=True).to_local()
decoder_norm = torch.norm(
self.decoder.weight.to_local(), p=2, dim=0, keepdim=keepdim
)
decoder_norm = DTensor.from_local(
decoder_norm,
device_mesh=self.device_mesh["tp"],
placements=[Shard(int(keepdim))],
)
decoder_norm = decoder_norm.redistribute(
placements=[Replicate()], async_op=True
).to_local()
return decoder_norm
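The new during_init flag presumably exists because at construction time the encoder and decoder weights are still plain tensors; they are only wrapped as DTensors once the model is parallelised, so the tensor-parallel branch above would fail on .to_local() inside __init__ even when cfg.tp_size > 1. A small illustration (the shape is hypothetical):

import torch

w = torch.empty(768, 4096)    # plain weight tensor, as it is during __init__
torch.norm(w, p=2, dim=0)     # fine: the tp_size == 1 / during_init path
# w.to_local()                # would raise AttributeError: plain tensors have no .to_local()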

def encoder_norm(self, keepdim: bool = False):
if self.cfg.tp_size == 1:
def encoder_norm(
self,
keepdim: bool = False,
during_init: bool = False,
):
if self.cfg.tp_size == 1 or during_init:
return torch.norm(self.encoder.weight, p=2, dim=1, keepdim=keepdim)
else:
encoder_norm = torch.norm(self.encoder.weight.to_local(), p=2, dim=1, keepdim=keepdim)
encoder_norm = DTensor.from_local(encoder_norm, device_mesh=self.device_mesh["tp"], placements=[Shard(0)])
encoder_norm = (
encoder_norm.redistribute(placements=[Replicate()], async_op=True).to_local()
encoder_norm = torch.norm(
self.encoder.weight.to_local(), p=2, dim=1, keepdim=keepdim
)
encoder_norm = DTensor.from_local(
encoder_norm, device_mesh=self.device_mesh["tp"], placements=[Shard(0)]
)
encoder_norm = encoder_norm.redistribute(
placements=[Replicate()], async_op=True
).to_local()
return encoder_norm
