From d040d2153dc069fc48f9747b63854d91e1a420e7 Mon Sep 17 00:00:00 2001 From: Frankstein <20307140057@fudan.edu.cn> Date: Fri, 19 Jul 2024 01:55:03 +0800 Subject: [PATCH] fix: convert decoder bias to local tensor while using tensor parallel --- src/lm_saes/sae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index a5be1ed..0e0ae40 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -253,7 +253,7 @@ def encode( label = x if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder: - x = x - self.decoder.bias + x = x - self.decoder.bias.to_local() if self.cfg.tp_size > 1 else x - self.decoder.bias x = x * self.compute_norm_factor(x, hook_point="in")