diff --git a/TransformerLens/tests/integration/test_mount_hooked_modules.py b/TransformerLens/tests/integration/test_mount_hooked_modules.py
new file mode 100644
index 0000000..e097cfd
--- /dev/null
+++ b/TransformerLens/tests/integration/test_mount_hooked_modules.py
@@ -0,0 +1,50 @@
+from transformer_lens.hook_points import HookedRootModule, HookPoint
+import torch
+import torch.nn as nn
+
+class Block(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.subblock1 = nn.Linear(10, 10)
+        self.subblock2 = nn.Linear(10, 10)
+        self.activation = nn.ReLU()
+        self.hook_pre = HookPoint()
+        self.hook_mid = HookPoint()
+
+    def forward(self, x):
+        return self.subblock2(self.hook_mid(self.activation(self.subblock1(self.hook_pre(x)))))
+
+class TestModule(HookedRootModule):
+    __test__ = False
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.blocks = nn.ModuleList([Block() for _ in range(3)])
+        self.embed = nn.Linear(1, 10)
+        self.unembed = nn.Linear(10, 1)
+        self.setup()
+
+    def forward(self, x):
+        x = self.embed(x)
+        for block in self.blocks:
+            x = block(x)
+        return self.unembed(x)
+
+class TestMountModule(HookedRootModule):
+    __test__ = False
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.hook_mid = HookPoint()
+        self.setup()
+
+    def forward(self, x):
+        return self.hook_mid(x) * 2
+
+def test_apply_hooked_modules():
+    model = TestModule()
+    model_to_mount = TestMountModule()
+    with model.mount_hooked_modules([("blocks.0.hook_mid", "m", model_to_mount)]):
+        assert model.blocks[0].hook_mid.m == model_to_mount
+        assert model_to_mount.hook_mid.name == "blocks.0.hook_mid.m.hook_mid"
+        assert "blocks.0.hook_mid.m.hook_mid" in model.hook_dict
+    assert not hasattr(model.blocks[0].hook_mid, "m")
+    assert model_to_mount.hook_mid.name == "hook_mid"
\ No newline at end of file
diff --git a/TransformerLens/transformer_lens/hook_points.py b/TransformerLens/transformer_lens/hook_points.py
index 8992028..d6d097a 100644
--- a/TransformerLens/transformer_lens/hook_points.py
+++ b/TransformerLens/transformer_lens/hook_points.py
@@ -116,7 +116,7 @@ def full_hook(
             _internal_hooks = self._forward_hooks
             visible_hooks = self.fwd_hooks
         elif dir == "bwd":
-            pt_handle = self.register_backward_hook(full_hook)
+            pt_handle = self.register_full_backward_hook(full_hook)
             _internal_hooks = self._backward_hooks
             visible_hooks = self.bwd_hooks
         else:
@@ -696,7 +696,7 @@ def get_ref_caching_hooks(
         names_filter: NamesFilter = None,
         retain_grad: bool = False,
         cache: Optional[dict] = None,
-    ) -> Tuple[dict, list, list]:
+    ) -> Tuple[dict, list]:
         """Creates hooks to keep references to activations. Note: It does not add the hooks to the model.
 
         Args:
@@ -788,13 +788,13 @@ def offload_params_after(self, last_hook: str, *model_args, **model_kwargs):
             **model_kwargs: Keyword arguments for the model.
""" pass_module_list: List[nn.Module] = [] - fake_mode = FakeTensorMode(allow_non_fake_inputs=True) hook_handles: List[hooks.RemovableHandle] = [] def pass_hook(module: nn.Module, module_input: Any, module_output: Any): pass_module_list.append(module) def convert_hook(tensor: torch.Tensor, hook: HookPoint): + assert hook.name is not None # Make mypy happy pass_param_set = set() hook_ancestors = [module for module in self.modules() if module == self or hook.name.startswith(module.name)] for module in pass_module_list + hook_ancestors: @@ -813,12 +813,11 @@ def convert_hook(tensor: torch.Tensor, hook: HookPoint): for _, module in self.named_modules(): hook_handles.append(module.register_forward_hook(pass_hook)) - with fake_mode: - with self.hooks(fwd_hooks=[(last_hook, convert_hook)]): - try: - self(*model_args, **model_kwargs) - except StopIteration: - pass + with self.hooks(fwd_hooks=[(last_hook, convert_hook)]): + try: + self(*model_args, **model_kwargs) + except StopIteration: + pass for handle in hook_handles: handle.remove() @@ -895,6 +894,47 @@ def stop_hook(tensor: torch.Tensor, hook: HookPoint): model_out = e.tensor return model_out, cache_dict + + @contextmanager + def mount_hooked_modules( + self, + hooked_modules: List[Tuple[str, str, "HookedRootModule"]], + ): + """ + A context manager for adding child hooked modules at specified hook points. + + Args: + hooked_modules: List[Tuple[name, module_name, module]], where name is the name of a + hook point, module_name is the name of the module to add (which will be used to + mount the module inside the hook point), and module is the module instance to add. + A filter function as name is not allowed, since unexpected behavior may occur when + the same module is mounted at multiple hook points. + """ + + for name, module_name, module in hooked_modules: + hook_point = self.mod_dict[name] + assert isinstance( + hook_point, HookPoint + ) + hook_point.add_module(module_name, module) + + self.setup() + + try: + yield self + finally: + for name, module_name, module in hooked_modules: + if isinstance(name, str): + hook_point = self.mod_dict[name] + delattr(hook_point, module_name) + else: + for hook_point_name, hp in self.hook_dict.items(): + if name(hook_point_name): + delattr(hp, module_name) + module.setup() + self.setup() + + # %% diff --git a/src/lm_saes/circuit/__init__.py b/src/lm_saes/circuit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/lm_saes/circuit/attributors.py b/src/lm_saes/circuit/attributors.py new file mode 100644 index 0000000..e56cf97 --- /dev/null +++ b/src/lm_saes/circuit/attributors.py @@ -0,0 +1,188 @@ +from abc import ABC +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Callable, Tuple, Union +import networkx as nx + +import torch +from transformer_lens import HookedTransformer +from transformer_lens.hook_points import HookPoint, HookedRootModule + +from lm_saes.utils.hooks import compose_hooks, detach_hook, retain_grad_hook +from lm_saes.circuit.graph import Node + +class Cache: + def __init__(self, output, cache: dict[str, torch.Tensor]): + self.cache = cache + self.output = output + + def tensor(self, node: Node) -> torch.Tensor: + return node.reduce(self[node.hook_point]) + + def grad(self, node: Node) -> torch.Tensor | None: + grad = self[node.hook_point].grad + return node.reduce(grad) if grad is not None else None + + def __getitem__(self, key: Node | str | None) -> torch.Tensor: + if isinstance(key, Node): + return self.tensor(key) + 
+        return self.cache[key] if key is not None else self.output
+
+class Attributor(ABC):
+    def __init__(
+        self,
+        model: HookedRootModule
+    ):
+        self.model = model
+
+    def attribute(
+        self,
+        input: Any,
+        target: Node,
+        candidates: list[Node],
+        **kwargs
+    ) -> nx.MultiDiGraph:
+        """
+        Attribute the target hook point of the model to given candidates in the model, w.r.t. the given input.
+
+        Args:
+            input (Any): The input to the model.
+            target: The target node to attribute.
+            candidates: The intermediate nodes to attribute to.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            nx.MultiDiGraph: The attributed graph, i.e. the circuit. Each node and edge should have an attribute "attribution",
+                showing its "importance" w.r.t. the target.
+        """
+        raise NotImplementedError
+
+    def cache_nodes(
+        self,
+        input: Any,
+        nodes: list[Node],
+    ):
+        """
+        Cache the activations of the given nodes in the model forward pass.
+
+        Args:
+            input (Any): The input to the model.
+            nodes (list[Node]): The nodes to cache.
+        """
+        output, cache = self.model.run_with_ref_cache(input, names_filter=[node.hook_point for node in nodes])
+        return Cache(output, cache)
+
+
+class DirectAttributor(Attributor):
+    def attribute(
+        self,
+        input: Any,
+        target: Node,
+        candidates: list[Node],
+        **kwargs
+    ) -> nx.MultiDiGraph:
+        """
+        Attribute the target node of the model to given candidates in the model, w.r.t. the given input.
+
+        Args:
+            input (Any): The input to the model.
+            target: The target node to attribute.
+            candidates: The intermediate nodes to attribute to.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            nx.MultiDiGraph: The attributed graph, i.e. the circuit. Each node and edge should have an attribute "attribution",
+                showing its "importance" w.r.t. the target.
+        """
+
+        threshold: float = kwargs.get("threshold", 0.1)
+
+        fwd_hooks = [(candidate.hook_point, detach_hook) for candidate in candidates if candidate.hook_point != target.hook_point]
+        with self.model.hooks(fwd_hooks):
+            cache = self.cache_nodes(input, candidates + [target])
+            cache[target].backward()
+
+        # Construct the circuit
+        circuit = nx.MultiDiGraph()
+        circuit.add_node(target, attribution=cache[target].item(), activation=cache[target].item())
+        for candidate in candidates:
+            if candidate.hook_point == target.hook_point:
+                continue
+            grad = cache.grad(candidate)
+            if grad is None:
+                continue
+            attributions = grad * cache[candidate]
+            if len(attributions.shape) == 0:
+                if attributions > threshold:
+                    circuit.add_node(candidate, attribution=attributions.item(), activation=cache[candidate].item())
+                    circuit.add_edge(candidate, target, attribution=attributions.item(), direct_attribution=attributions.item())
+            else:
+                for index in (attributions > threshold).nonzero():
+                    index = tuple(index.tolist())
+                    circuit.add_node(candidate.append_reduction(*index), attribution=attributions[index].item(), activation=cache[candidate][index].item())
+                    circuit.add_edge(candidate.append_reduction(*index), target, attribution=attributions[index].item(), direct_attribution=attributions[index].item())
+        return circuit
+
+
+class HierachicalAttributor(Attributor):
+    def attribute(
+        self,
+        input: Any,
+        target: Node,
+        candidates: list[Node],
+        **kwargs
+    ) -> nx.MultiDiGraph:
+        """
+        Attribute the target node of the model to given candidates in the model, w.r.t. the given input.
+
+        Args:
+            input (Any): The input to the model.
+            target: The target node to attribute.
+            candidates: The intermediate nodes to attribute to.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            nx.MultiDiGraph: The attributed graph, i.e. the circuit. Each node and edge should have an attribute "attribution",
+                showing its "importance" w.r.t. the target.
+        """
+
+        threshold: float = kwargs.get("threshold", 0.1)
+
+        def generate_attribution_score_filter_hook():
+            v = None
+            def fwd_hook(tensor: torch.Tensor, hook: HookPoint):
+                nonlocal v
+                v = tensor
+                return tensor
+            def attribution_score_filter_hook(grad: torch.Tensor, hook: HookPoint):
+                assert v is not None, "fwd_hook must be called before attribution_score_filter_hook."
+                return (torch.where(v * grad > threshold, grad, torch.zeros_like(grad)),)
+            return fwd_hook, attribution_score_filter_hook
+        attribution_score_filter_hooks = {candidate: generate_attribution_score_filter_hook() for candidate in candidates}
+        fwd_hooks = [(candidate.hook_point, compose_hooks(attribution_score_filter_hooks[candidate][0], retain_grad_hook)) for candidate in candidates]
+        with self.model.hooks(
+            fwd_hooks=fwd_hooks,
+            bwd_hooks=[(candidate.hook_point, attribution_score_filter_hooks[candidate][1]) for candidate in candidates]
+        ):
+            cache = self.cache_nodes(input, candidates + [target])
+            cache[target].backward()
+
+        # Construct the circuit
+        circuit = nx.MultiDiGraph()
+        circuit.add_node(target, attribution=cache[target].item(), activation=cache[target].item())
+        for candidate in candidates:
+            grad = cache.grad(candidate)
+            if grad is None:
+                continue
+            attributions = grad * cache[candidate]
+            if len(attributions.shape) == 0:
+                if attributions > threshold:
+                    circuit.add_node(candidate, attribution=attributions.item(), activation=cache[candidate].item())
+                    circuit.add_edge(candidate, target, attribution=attributions.item())
+            else:
+                for index in (attributions > threshold).nonzero():
+                    index = tuple(index.tolist())
+                    circuit.add_node(candidate.append_reduction(*index), attribution=attributions[index].item(), activation=cache[candidate][index].item())
+                    circuit.add_edge(candidate.append_reduction(*index), target, attribution=attributions[index].item())
+
+        return circuit
+
diff --git a/src/lm_saes/circuit/context.py b/src/lm_saes/circuit/context.py
new file mode 100644
index 0000000..d526487
--- /dev/null
+++ b/src/lm_saes/circuit/context.py
@@ -0,0 +1,41 @@
+from contextlib import contextmanager
+from typing import Callable, Tuple, Union
+import torch
+from transformer_lens.hook_points import HookPoint, HookedRootModule
+
+from lm_saes.sae import SparseAutoEncoder
+
+@contextmanager
+def apply_sae(
+    model: HookedRootModule,
+    saes: list[SparseAutoEncoder]
+):
+    """
+    Apply the sparse autoencoders to the model.
+    """
+    fwd_hooks: list[Tuple[Union[str, Callable], Callable]] = []
+    def get_fwd_hooks(sae: SparseAutoEncoder) -> list[Tuple[Union[str, Callable], Callable]]:
+        if sae.cfg.hook_point_in == sae.cfg.hook_point_out:
+            def hook(tensor: torch.Tensor, hook: HookPoint):
+                reconstructed = sae.forward(tensor)
+                return reconstructed + (tensor - reconstructed).detach()
+            return [(sae.cfg.hook_point_in, hook)]
+        else:
+            x = None
+            def hook_in(tensor: torch.Tensor, hook: HookPoint):
+                nonlocal x
+                x = tensor
+                return tensor
+            def hook_out(tensor: torch.Tensor, hook: HookPoint):
+                nonlocal x
+                assert x is not None, "hook_in must be called before hook_out."
+                reconstructed = sae.forward(x, label=tensor)
+                x = None
+                return reconstructed + (tensor - reconstructed).detach()
+            return [(sae.cfg.hook_point_in, hook_in), (sae.cfg.hook_point_out, hook_out)]
+    for sae in saes:
+        hooks = get_fwd_hooks(sae)
+        fwd_hooks.extend(hooks)
+    with model.mount_hooked_modules([(sae.cfg.hook_point_out, "sae", sae) for sae in saes]):
+        with model.hooks(fwd_hooks):
+            yield model
\ No newline at end of file
diff --git a/src/lm_saes/circuit/graph.py b/src/lm_saes/circuit/graph.py
new file mode 100644
index 0000000..e27a78d
--- /dev/null
+++ b/src/lm_saes/circuit/graph.py
@@ -0,0 +1,48 @@
+from dataclasses import dataclass
+
+import torch
+
+
+@dataclass(frozen=True)
+class Node:
+    """
+    A node in the circuit.
+    """
+
+    hook_point: str | None
+    """ The hook point of the node. None means the node is the output of the model. """
+    reduction: str | None = None
+    """ The reduction function to apply to the node. """
+
+    def reduce(self, tensor: torch.Tensor) -> torch.Tensor:
+        reductions = self.reduction.split(".") if self.reduction is not None else []
+        for reduction in reductions:
+            if reduction == "max":
+                tensor = tensor.max()
+            elif reduction == "mean":
+                tensor = tensor.mean()
+            elif reduction == "sum":
+                tensor = tensor.sum()
+            else:
+                try:
+                    index = int(reduction)
+                    tensor = tensor[index]
+                except ValueError:
+                    raise ValueError(f"Unknown reduction function: {reduction} in {self.reduction}.")
+        return tensor
+
+    def append_reduction(self, *reduction: str | int) -> "Node":
+        reduction_str = ".".join(map(str, reduction))
+        return Node(self.hook_point, f"{self.reduction}.{reduction_str}" if self.reduction is not None else reduction_str)
+
+    def __hash__(self):
+        return hash((self.hook_point, self.reduction))
+
+    def __eq__(self, other):
+        if not isinstance(other, Node):
+            return False
+        return self.hook_point == other.hook_point and self.reduction == other.reduction
+
+    def __str__(self) -> str:
+        hook_point = self.hook_point if self.hook_point is not None else "output"
+        return f"{hook_point}.{self.reduction}" if self.reduction is not None else hook_point
\ No newline at end of file
diff --git a/src/lm_saes/circuit/transformer.py b/src/lm_saes/circuit/transformer.py
new file mode 100644
index 0000000..e065e77
--- /dev/null
+++ b/src/lm_saes/circuit/transformer.py
@@ -0,0 +1,74 @@
+from typing import Any
+from transformer_lens import HookedTransformer
+from lm_saes.circuit.attributors import DirectAttributor, HierachicalAttributor
+from lm_saes.circuit.context import apply_sae
+from lm_saes.circuit.graph import Node
+from lm_saes.sae import SparseAutoEncoder
+from lm_saes.utils.hooks import detach_hook
+
+
+def direct_attribute_transformer_with_saes(
+    model: HookedTransformer,
+    saes: list[SparseAutoEncoder],
+    input: Any,
+    target: Node,
+    candidates: list[Node] | None = None,
+    threshold: float = 0.1,
+):
+    """
+    Attribute the target hook point of the model to given candidates in the model, w.r.t. the given input.
+    This attribution will only consider the direct connections between the target and the candidates, but not
+    indirect effects through intermediate candidates.
+
+    Args:
+        model (HookedTransformer): The model to attribute.
+        saes (list[SparseAutoEncoder]): The sparse autoencoders to apply.
+        input (Any): The input to the model.
+        target: The target node to attribute.
+        candidates: The intermediate nodes to attribute to. If None, default to all SAE feature activations and all attention scores.
+        threshold (float): The threshold to prune the circuit.
+
+    Returns:
+        nx.MultiDiGraph: The attributed graph, i.e. the circuit. Each node and edge should have an attribute "attribution",
+            showing its "importance" w.r.t. the target.
+    """
+
+    with apply_sae(model, saes):
+        with model.hooks([(f"blocks.{i}.attn.hook_attn_scores", detach_hook) for i in range(model.cfg.n_layers)]):
+            attributor = DirectAttributor(model)
+            if candidates is None:
+                candidates = [Node(f"{sae.cfg.hook_point_out}.sae.hook_feature_acts") for sae in saes] + [Node(f"blocks.{i}.attn.hook_attn_scores") for i in range(model.cfg.n_layers)]
+            return attributor.attribute(input=input, target=target, candidates=candidates, threshold=threshold)
+
+def hierarchical_attribute_transformer_with_saes(
+    model: HookedTransformer,
+    saes: list[SparseAutoEncoder],
+    input: Any,
+    target: Node,
+    candidates: list[Node] | None = None,
+    threshold: float = 0.1,
+):
+    """
+    Attribute the target hook point of the model to given candidates in the model, w.r.t. the given input.
+    This attribution will consider both the direct connections between the target and the candidates, and
+    indirect effects through intermediate candidates.
+
+    Args:
+        model (HookedTransformer): The model to attribute.
+        saes (list[SparseAutoEncoder]): The sparse autoencoders to apply.
+        input (Any): The input to the model.
+        target: The target node to attribute.
+        candidates: The intermediate nodes to attribute to. If None, default to all SAE feature activations and all attention scores.
+        threshold (float): The threshold to prune the circuit.
+
+    Returns:
+        nx.MultiDiGraph: The attributed graph, i.e. the circuit. Each node and edge should have an attribute "attribution",
+            showing its "importance" w.r.t. the target.
+    """
+
+    with apply_sae(model, saes):
+        with model.hooks([(f"blocks.{i}.attn.hook_attn_scores", detach_hook) for i in range(model.cfg.n_layers)]):
+            attributor = HierachicalAttributor(model)
+            if candidates is None:
+                candidates = [Node(f"{sae.cfg.hook_point_out}.sae.hook_feature_acts") for sae in saes] + [Node(f"blocks.{i}.attn.hook_attn_scores") for i in range(model.cfg.n_layers)]
+            return attributor.attribute(input=input, target=target, candidates=candidates, threshold=threshold)
\ No newline at end of file
diff --git a/src/lm_saes/utils/hooks.py b/src/lm_saes/utils/hooks.py
new file mode 100644
index 0000000..9c71a6e
--- /dev/null
+++ b/src/lm_saes/utils/hooks.py
@@ -0,0 +1,25 @@
+import torch
+from transformer_lens.hook_points import HookPoint
+
+def compose_hooks(*hooks):
+    """
+    Compose multiple hooks into a single hook by executing them in order.
+    """
+    def composed_hook(tensor: torch.Tensor, hook: HookPoint):
+        for hook_fn in hooks:
+            tensor = hook_fn(tensor, hook)
+        return tensor
+    return composed_hook
+
+def retain_grad_hook(tensor: torch.Tensor, hook: HookPoint):
+    """
+    Retain the gradient of the tensor at the given hook point.
+    """
+    tensor.retain_grad()
+    return tensor
+
+def detach_hook(tensor: torch.Tensor, hook: HookPoint):
+    """
+    Detach the tensor at the given hook point.
+ """ + return tensor.detach().requires_grad_(True) \ No newline at end of file diff --git a/tests/intergration/test_attributor.py b/tests/intergration/test_attributor.py new file mode 100644 index 0000000..1b102e5 --- /dev/null +++ b/tests/intergration/test_attributor.py @@ -0,0 +1,55 @@ +from math import isclose +from transformer_lens.hook_points import HookPoint, HookedRootModule +import torch +import torch.nn as nn +from lm_saes.circuit.attributors import DirectAttributor, HierachicalAttributor, Node + +class TestModule(HookedRootModule): + def __init__(self): + super().__init__() + self.W_1 = nn.Parameter(torch.tensor([[1., 2.]])) + self.W_2 = nn.Parameter(torch.tensor([[1.], [1.]])) + self.W_3 = nn.Parameter(torch.tensor([[1., 1.]])) + self.W_4 = nn.Parameter(torch.tensor([[2.], [1.]])) + self.hook_mid_1 = HookPoint() + self.hook_mid_2 = HookPoint() + self.setup() + + def forward(self, input): + input = input + self.hook_mid_1(input @ self.W_1) @ self.W_2 + input = input + self.hook_mid_2(input @ self.W_3) @ self.W_4 + return input + +def test_direct_attributor(): + model = TestModule() + attributor = DirectAttributor(model) + input = torch.tensor([1.]) + input = input.requires_grad_() + circuit = attributor.attribute(input, Node(None, "0"), [Node("hook_mid_1"), Node("hook_mid_2")]) + assert len(circuit.nodes) == 5 + assert isclose(circuit.nodes[Node("hook_mid_1", "0")]["attribution"], 1.) + assert isclose(circuit.nodes[Node("hook_mid_1", "1")]["attribution"], 2.) + assert isclose(circuit.nodes[Node("hook_mid_2", "0")]["attribution"], 8.) + assert isclose(circuit.nodes[Node("hook_mid_2", "1")]["attribution"], 4.) + +def test_hierachical_attributor(): + model = TestModule() + attributor = HierachicalAttributor(model) + input = torch.tensor([1.]) + input = input.requires_grad_() + circuit = attributor.attribute(input, Node(None, "0"), [Node("hook_mid_1"), Node("hook_mid_2")]) + assert len(circuit.nodes) == 5 + assert isclose(circuit.nodes[Node("hook_mid_1", "0")]["attribution"], 4.) + assert isclose(circuit.nodes[Node("hook_mid_1", "1")]["attribution"], 8.) + assert isclose(circuit.nodes[Node("hook_mid_2", "0")]["attribution"], 8.) + assert isclose(circuit.nodes[Node("hook_mid_2", "1")]["attribution"], 4.) + +def test_hierachical_attributor_with_threshold(): + model = TestModule() + attributor = HierachicalAttributor(model) + input = torch.tensor([1.]) + input = input.requires_grad_() + circuit = attributor.attribute(input, Node(None, "0"), [Node("hook_mid_1"), Node("hook_mid_2")], threshold=5.) + assert len(circuit.nodes) == 3 + assert isclose(circuit.nodes[Node("hook_mid_1", "1")]["attribution"], 6.) + assert isclose(circuit.nodes[Node("hook_mid_2", "0")]["attribution"], 8.) \ No newline at end of file