diff --git a/server/.env.example b/server/.env.example
index fbf3a73..4746301 100644
--- a/server/.env.example
+++ b/server/.env.example
@@ -1,5 +1,4 @@
 MONGO_URI= # Must fill in
-RESULT_DIR= # Must fill in
 DICTIONARY_SERIES= # Must fill in
 DICTIONARY_CKPT_NAME=final.pt
 
diff --git a/server/app.py b/server/app.py
index 53e5a75..b93c3f4 100644
--- a/server/app.py
+++ b/server/app.py
@@ -368,7 +368,9 @@ def model_generate(request: ModelGenerateRequest):
     assert all(steering.sae in request.saes for steering in request.steerings), "Steering SAE not found"
 
     def generate_steering_hook(steering: SteeringConfig):
-        def steering_hook(tensor: torch.Tensor, hook: HookPoint):
+        feature_acts = None
+
+        def steer(tensor: torch.Tensor):
             assert len(tensor.shape) == 3
             tensor = tensor.clone()
             if steering.steering_type == "times":
@@ -384,14 +386,28 @@ def steering_hook(tensor: torch.Tensor, hook: HookPoint):
                 tensor[:, :, steering.feature_index] = steering.steering_value
             return tensor
 
+        def save_feature_acts_hook(tensor: torch.Tensor, hook: HookPoint):
+            nonlocal feature_acts
+            feature_acts = tensor
+            return steer(tensor)
+
+        def steering_hook(tensor: torch.Tensor, hook: HookPoint):
+            assert feature_acts is not None, "Feature acts should be saved before steering"
+            difference = (steer(feature_acts) - feature_acts) @ sae.decoder.weight.T
+            tensor += difference.detach()
+            return tensor
+
         sae = get_sae(steering.sae)
-        return f"{sae.cfg.hook_point_out}.sae.hook_feature_acts", steering_hook
+        return [
+            (f"{sae.cfg.hook_point_out}.sae.hook_feature_acts", save_feature_acts_hook),
+            (f"{sae.cfg.hook_point_out}", steering_hook),
+        ]
 
-    steerings_hooks = [generate_steering_hook(steering) for steering in request.steerings]
+    steering_hooks = sum([generate_steering_hook(steering) for steering in request.steerings], [])
 
     with torch.no_grad():
         with apply_sae(model, [sae for sae, _ in saes]):
-            with model.hooks(steerings_hooks):
+            with model.hooks(steering_hooks):
                 input = (
                     model.to_tokens(request.input_text, prepend_bos=False)
                     if isinstance(request.input_text, str)
@@ -499,7 +515,10 @@ def model_trace(request: ModelTraceRequest):
     ), "Tracing SAE not found"
 
     def generate_steering_hook(steering: SteeringConfig):
-        def steering_hook(tensor: torch.Tensor, hook: HookPoint):
+        feature_acts = None
+        sae = get_sae(steering.sae)
+
+        def steer(tensor: torch.Tensor):
             assert len(tensor.shape) == 3
             tensor = tensor.clone()
             if steering.steering_type == "times":
@@ -515,17 +534,31 @@ def steering_hook(tensor: torch.Tensor, hook: HookPoint):
                 tensor[:, :, steering.feature_index] = steering.steering_value
             return tensor
 
+        def save_feature_acts_hook(tensor: torch.Tensor, hook: HookPoint):
+            nonlocal feature_acts
+            feature_acts = tensor
+            return steer(tensor)
+
+        def steering_hook(tensor: torch.Tensor, hook: HookPoint):
+            assert feature_acts is not None, "Feature acts should be saved before steering"
+            difference = (steer(feature_acts) - feature_acts) @ sae.decoder.weight.T
+            tensor += difference.detach()
+            return tensor
+
         sae = get_sae(steering.sae)
-        return f"{sae.cfg.hook_point_out}.sae.hook_feature_acts", steering_hook
+        return [
+            (f"{sae.cfg.hook_point_out}.sae.hook_feature_acts", save_feature_acts_hook),
+            (f"{sae.cfg.hook_point_out}", steering_hook),
+        ]
 
-    steerings_hooks = [generate_steering_hook(steering) for steering in request.steerings]
+    steering_hooks = sum([generate_steering_hook(steering) for steering in request.steerings], [])
 
     candidates = [f"{sae.cfg.hook_point_out}.sae.hook_feature_acts" for sae, _ in saes]
     if request.detach_at_attn_scores:
         candidates += [f"blocks.{i}.attn.hook_attn_scores" for i in range(model.cfg.n_layers)]
 
     with apply_sae(model, [sae for sae, _ in saes]):
-        with model.hooks(steerings_hooks):
+        with model.hooks(steering_hooks):
             with detach_at(model, candidates):
                 input = (
                     model.to_tokens(request.input_text, prepend_bos=False)
diff --git a/tests/intergration/test_attributor.py b/tests/intergration/test_attributor.py
deleted file mode 100644
index 25790a1..0000000
--- a/tests/intergration/test_attributor.py
+++ /dev/null
@@ -1,63 +0,0 @@
-from math import isclose
-
-import torch
-import torch.nn as nn
-from transformer_lens.hook_points import HookedRootModule, HookPoint
-
-from lm_saes.circuit.attributors import DirectAttributor, HierachicalAttributor, Node
-
-
-class TestModule(HookedRootModule):
-    __test__ = False
-
-    def __init__(self):
-        super().__init__()
-        self.W_1 = nn.Parameter(torch.tensor([[1.0, 2.0]]))
-        self.W_2 = nn.Parameter(torch.tensor([[1.0], [1.0]]))
-        self.W_3 = nn.Parameter(torch.tensor([[1.0, 1.0]]))
-        self.W_4 = nn.Parameter(torch.tensor([[2.0], [1.0]]))
-        self.hook_mid_1 = HookPoint()
-        self.hook_mid_2 = HookPoint()
-        self.setup()
-
-    def forward(self, input):
-        input = input + self.hook_mid_1(input @ self.W_1) @ self.W_2
-        input = input + self.hook_mid_2(input @ self.W_3) @ self.W_4
-        return input
-
-
-def test_direct_attributor():
-    model = TestModule()
-    attributor = DirectAttributor(model)
-    input = torch.tensor([1.0])
-    input = input.requires_grad_()
-    circuit = attributor.attribute(input, Node(None, "0"), [Node("hook_mid_1"), Node("hook_mid_2")])
-    assert len(circuit.nodes) == 5
-    assert isclose(circuit.nodes[Node("hook_mid_1", "0")]["attribution"], 1.0)
-    assert isclose(circuit.nodes[Node("hook_mid_1", "1")]["attribution"], 2.0)
-    assert isclose(circuit.nodes[Node("hook_mid_2", "0")]["attribution"], 8.0)
-    assert isclose(circuit.nodes[Node("hook_mid_2", "1")]["attribution"], 4.0)
-
-
-def test_hierachical_attributor():
-    model = TestModule()
-    attributor = HierachicalAttributor(model)
-    input = torch.tensor([1.0])
-    input = input.requires_grad_()
-    circuit = attributor.attribute(input, Node(None, "0"), [Node("hook_mid_1"), Node("hook_mid_2")])
-    assert len(circuit.nodes) == 5
-    assert isclose(circuit.nodes[Node("hook_mid_1", "0")]["attribution"], 4.0)
-    assert isclose(circuit.nodes[Node("hook_mid_1", "1")]["attribution"], 8.0)
-    assert isclose(circuit.nodes[Node("hook_mid_2", "0")]["attribution"], 8.0)
-    assert isclose(circuit.nodes[Node("hook_mid_2", "1")]["attribution"], 4.0)
-
-
-def test_hierachical_attributor_with_threshold():
-    model = TestModule()
-    attributor = HierachicalAttributor(model)
-    input = torch.tensor([1.0])
-    input = input.requires_grad_()
-    circuit = attributor.attribute(input, Node(None, "0"), [Node("hook_mid_1"), Node("hook_mid_2")], threshold=5.0)
-    assert len(circuit.nodes) == 3
-    assert isclose(circuit.nodes[Node("hook_mid_1", "1")]["attribution"], 6.0)
-    assert isclose(circuit.nodes[Node("hook_mid_2", "0")]["attribution"], 8.0)
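
A minimal standalone sketch of what the new two-hook steering scheme computes: feature activations are cached at the SAE's hook_feature_acts point, and the decoder-projected difference between steered and original features is added to the model's own activation at hook_point_out, rather than overwriting it with the SAE reconstruction. The sizes, the random decoder stand-in, and the feature index below are illustrative only, not values from the server.

    import torch

    d_model, d_sae = 8, 32                        # illustrative sizes
    decoder_weight = torch.randn(d_model, d_sae)  # stands in for sae.decoder.weight

    def steer(feature_acts: torch.Tensor) -> torch.Tensor:
        # e.g. steering_type == "times" on a single (illustrative) feature index
        steered = feature_acts.clone()
        steered[:, :, 3] = steered[:, :, 3] * 2.0
        return steered

    # what save_feature_acts_hook caches at "<hook_point_out>.sae.hook_feature_acts"
    feature_acts = torch.relu(torch.randn(1, 5, d_sae))

    # what steering_hook then does at "<hook_point_out>": map the change in feature
    # space back through the decoder and add it to the model activation
    model_activation = torch.randn(1, 5, d_model)
    difference = (steer(feature_acts) - feature_acts) @ decoder_weight.T
    steered_activation = model_activation + difference.detach()

    assert steered_activation.shape == model_activation.shape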