Merge pull request #23 from OpenMOSS/circuit
Add attribution support
Showing 9 changed files with 530 additions and 9 deletions.
TransformerLens/tests/integration/test_mount_hooked_modules.py (50 additions, 0 deletions)
@@ -0,0 +1,50 @@
from transformer_lens.hook_points import HookedRootModule, HookPoint
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.subblock1 = nn.Linear(10, 10)
        self.subblock2 = nn.Linear(10, 10)
        self.activation = nn.ReLU()
        self.hook_pre = HookPoint()
        self.hook_mid = HookPoint()

    def forward(self, x):
        return self.subblock2(self.hook_mid(self.activation(self.subblock1(self.hook_pre(x)))))


class TestModule(HookedRootModule):
    __test__ = False
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.blocks = nn.ModuleList([Block() for _ in range(3)])
        self.embed = nn.Linear(1, 10)
        self.unembed = nn.Linear(10, 1)
        self.setup()

    def forward(self, x):
        x = self.embed(x)
        for block in self.blocks:
            x = block(x)
        return self.unembed(x)


class TestMountModule(HookedRootModule):
    __test__ = False
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hook_mid = HookPoint()
        self.setup()

    def forward(self, x):
        return self.hook_mid(x) * 2


def test_apply_hooked_modules():
    model = TestModule()
    model_to_mount = TestMountModule()
    with model.mount_hooked_modules([("blocks.0.hook_mid", "m", model_to_mount)]):
        # While mounted, the module is attached under the target hook point and its
        # own hook points are renamed into the parent model's namespace.
        assert model.blocks[0].hook_mid.m == model_to_mount
        assert model_to_mount.hook_mid.name == "blocks.0.hook_mid.m.hook_mid"
        assert "blocks.0.hook_mid.m.hook_mid" in model.hook_dict
    # After the context exits, the module is detached and its hook names revert.
    assert not hasattr(model.blocks[0].hook_mid, "m")
    assert model_to_mount.hook_mid.name == "hook_mid"
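For orientation, here is a minimal usage sketch (not part of this commit) building on the test above. Mounting only registers a module's hook points under the parent's namespace; the call_mounted hook below, a hypothetical helper, is what actually routes the activation through the mounted module, mirroring the pattern apply_sae uses further down.

# Hypothetical usage sketch (not part of this commit): mounting registers the
# module's hook points under the parent's namespace, but does not call the module;
# a forward hook does that.
model = TestModule()
mounted = TestMountModule()

def call_mounted(tensor, hook):
    # Route the activation at blocks.0.hook_mid through the mounted module.
    return mounted(tensor)

with model.mount_hooked_modules([("blocks.0.hook_mid", "m", mounted)]):
    with model.hooks(fwd_hooks=[("blocks.0.hook_mid", call_mounted)]):
        _, cache = model.run_with_cache(torch.ones(4, 1))
        # The mounted module's internal hook is now addressable via the parent model.
        print(cache["blocks.0.hook_mid.m.hook_mid"].shape)  # torch.Size([4, 10])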
Empty file.
@@ -0,0 +1,188 @@
from abc import ABC
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Tuple, Union
import networkx as nx

import torch
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint, HookedRootModule

from lm_saes.utils.hooks import compose_hooks, detach_hook, retain_grad_hook
from lm_saes.circuit.graph import Node


class Cache:
    """Bundles a model output with a dict of cached activations, indexable by hook name or Node."""

    def __init__(self, output, cache: dict[str, torch.Tensor]):
        self.cache = cache
        self.output = output

    def tensor(self, node: Node) -> torch.Tensor:
        return node.reduce(self[node.hook_point])

    def grad(self, node: Node) -> torch.Tensor | None:
        grad = self[node.hook_point].grad
        return node.reduce(grad) if grad is not None else None

    def __getitem__(self, key: Node | str | None) -> torch.Tensor:
        if isinstance(key, Node):
            return self.tensor(key)
        return self.cache[key] if key is not None else self.output

class Attributor(ABC):
    def __init__(
        self,
        model: HookedRootModule
    ):
        self.model = model

    def attribute(
        self,
        input: Any,
        target: Node,
        candidates: list[Node],
        **kwargs
    ) -> nx.MultiDiGraph:
        """
        Attribute the target node of the model to the given candidate nodes, w.r.t. the given input.

        Args:
            input (Any): The input to the model.
            target: The target node to attribute.
            candidates: The intermediate nodes to attribute to.
            **kwargs: Additional keyword arguments.

        Returns:
            nx.MultiDiGraph: The attributed graph, i.e. the circuit. Each node and edge has an
                attribute "attribution", showing its "importance" w.r.t. the target.
        """
        raise NotImplementedError

    def cache_nodes(
        self,
        input: Any,
        nodes: list[Node],
    ):
        """
        Cache the activations of the given nodes during the model forward pass.

        Args:
            input (Any): The input to the model.
            nodes (list[Node]): The nodes to cache.
        """
        output, cache = self.model.run_with_ref_cache(input, names_filter=[node.hook_point for node in nodes])
        return Cache(output, cache)


class DirectAttributor(Attributor):
    def attribute(
        self,
        input: Any,
        target: Node,
        candidates: list[Node],
        **kwargs
    ) -> nx.MultiDiGraph:
        """
        Attribute the target node of the model to the given candidate nodes, w.r.t. the given input.

        Args:
            input (Any): The input to the model.
            target: The target node to attribute.
            candidates: The intermediate nodes to attribute to.
            **kwargs: Additional keyword arguments.

        Returns:
            nx.MultiDiGraph: The attributed graph, i.e. the circuit. Each node and edge has an
                attribute "attribution", showing its "importance" w.r.t. the target.
        """
        threshold: float = kwargs.get("threshold", 0.1)

        # Detach every candidate other than the target, so gradients only flow along
        # direct paths from the candidates to the target.
        fwd_hooks = [(candidate.hook_point, detach_hook) for candidate in candidates if candidate.hook_point != target.hook_point]
        with self.model.hooks(fwd_hooks):
            cache = self.cache_nodes(input, candidates + [target])
            cache[target].backward()

        # Construct the circuit
        circuit = nx.MultiDiGraph()
        circuit.add_node(target, attribution=cache[target].item(), activation=cache[target].item())
        for candidate in candidates:
            if candidate.hook_point == target.hook_point:
                continue
            grad = cache.grad(candidate)
            if grad is None:
                continue
            attributions = grad * cache[candidate]
            if len(attributions.shape) == 0:
                if attributions > threshold:
                    circuit.add_node(candidate, attribution=attributions.item(), activation=cache[candidate].item())
                    circuit.add_edge(candidate, target, attribution=attributions.item(), direct_attribution=attributions.item())
            else:
                for index in (attributions > threshold).nonzero():
                    index = tuple(index.tolist())
                    circuit.add_node(candidate.append_reduction(*index), attribution=attributions[index].item(), activation=cache[candidate][index].item())
                    circuit.add_edge(candidate.append_reduction(*index), target, attribution=attributions[index].item(), direct_attribution=attributions[index].item())
        return circuit


class HierachicalAttributor(Attributor):
    def attribute(
        self,
        input: Any,
        target: Node,
        candidates: list[Node],
        **kwargs
    ) -> nx.MultiDiGraph:
        """
        Attribute the target node of the model to the given candidate nodes, w.r.t. the given input.

        Args:
            input (Any): The input to the model.
            target: The target node to attribute.
            candidates: The intermediate nodes to attribute to.
            **kwargs: Additional keyword arguments.

        Returns:
            nx.MultiDiGraph: The attributed graph, i.e. the circuit. Each node and edge has an
                attribute "attribution", showing its "importance" w.r.t. the target.
        """
        threshold: float = kwargs.get("threshold", 0.1)

        def generate_attribution_score_filter_hook():
            v = None
            def fwd_hook(tensor: torch.Tensor, hook: HookPoint):
                nonlocal v
                v = tensor
                return tensor
            def attribution_score_filter_hook(grad: torch.Tensor, hook: HookPoint):
                assert v is not None, "fwd_hook must be called before attribution_score_filter_hook."
                return (torch.where(v * grad > threshold, grad, torch.zeros_like(grad)),)
            return fwd_hook, attribution_score_filter_hook

        # Each candidate records its activation in the forward pass and, in the backward
        # pass, only propagates gradients whose attribution score exceeds the threshold.
        attribution_score_filter_hooks = {candidate: generate_attribution_score_filter_hook() for candidate in candidates}
        fwd_hooks = [(candidate.hook_point, compose_hooks(attribution_score_filter_hooks[candidate][0], retain_grad_hook)) for candidate in candidates]
        with self.model.hooks(
            fwd_hooks=fwd_hooks,
            bwd_hooks=[(candidate.hook_point, attribution_score_filter_hooks[candidate][1]) for candidate in candidates]
        ):
            cache = self.cache_nodes(input, candidates + [target])
            cache[target].backward()

        # Construct the circuit
        circuit = nx.MultiDiGraph()
        circuit.add_node(target, attribution=cache[target].item(), activation=cache[target].item())
        for candidate in candidates:
            grad = cache.grad(candidate)
            if grad is None:
                continue
            attributions = grad * cache[candidate]
            if len(attributions.shape) == 0:
                if attributions > threshold:
                    circuit.add_node(candidate, attribution=attributions.item(), activation=cache[candidate].item())
            else:
                for index in (attributions > threshold).nonzero():
                    index = tuple(index.tolist())
                    circuit.add_node(candidate.append_reduction(*index), attribution=attributions[index].item(), activation=cache[candidate][index].item())
                    circuit.add_edge(candidate.append_reduction(*index), target, attribution=attributions[index].item())

        return circuit
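A hedged usage sketch of the attributors above (not part of this commit). The Node constructor and the hook names below are assumptions for illustration; the real interface lives in lm_saes.circuit.graph, which this diff does not include.

# Hypothetical usage sketch (not part of this commit). `model` is a HookedRootModule
# and `tokens` a batch of model inputs, both assumed to be defined elsewhere; the
# Node(...) calls are assumptions about lm_saes.circuit.graph.
from lm_saes.circuit.graph import Node

target = Node("blocks.11.hook_resid_post")  # assumed constructor; the target must reduce to a scalar for .backward()
candidates = [Node(f"blocks.{layer}.hook_resid_post") for layer in range(11)]  # assumed hook names

attributor = HierachicalAttributor(model)
circuit = attributor.attribute(tokens, target=target, candidates=candidates, threshold=0.05)

# `circuit` is a networkx.MultiDiGraph whose nodes and edges carry "attribution" scores.
for u, v, data in circuit.edges(data=True):
    print(u, "->", v, data["attribution"])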
@@ -0,0 +1,41 @@
from contextlib import contextmanager
from typing import Callable, Tuple, Union
import torch
from transformer_lens.hook_points import HookPoint, HookedRootModule

from lm_saes.sae import SparseAutoEncoder


@contextmanager
def apply_sae(
    model: HookedRootModule,
    saes: list[SparseAutoEncoder]
):
    """
    Apply the sparse autoencoders to the model. While the context is active, each SAE is
    mounted at its hook_point_out, and the activations there are replaced by a straight-through
    reconstruction: the forward value is numerically unchanged, but gradients flow through the SAE.
    """
    fwd_hooks: list[Tuple[Union[str, Callable], Callable]] = []
    def get_fwd_hooks(sae: SparseAutoEncoder) -> list[Tuple[Union[str, Callable], Callable]]:
        if sae.cfg.hook_point_in == sae.cfg.hook_point_out:
            def hook(tensor: torch.Tensor, hook: HookPoint):
                reconstructed = sae.forward(tensor)
                # Numerically equal to `tensor`; gradients are routed through `reconstructed`.
                return reconstructed + (tensor - reconstructed).detach()
            return [(sae.cfg.hook_point_in, hook)]
        else:
            x = None
            def hook_in(tensor: torch.Tensor, hook: HookPoint):
                nonlocal x
                x = tensor
                return tensor
            def hook_out(tensor: torch.Tensor, hook: HookPoint):
                nonlocal x
                assert x is not None, "hook_in must be called before hook_out."
                reconstructed = sae.forward(x, label=tensor)
                x = None
                # Numerically equal to `tensor`; gradients are routed through `reconstructed`.
                return reconstructed + (tensor - reconstructed).detach()
            return [(sae.cfg.hook_point_in, hook_in), (sae.cfg.hook_point_out, hook_out)]
    for sae in saes:
        hooks = get_fwd_hooks(sae)
        fwd_hooks.extend(hooks)
    with model.mount_hooked_modules([(sae.cfg.hook_point_out, "sae", sae) for sae in saes]):
        with model.hooks(fwd_hooks):
            yield model
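Finally, a hedged sketch (not in this commit) of how apply_sae could compose with the attributors: inside the context each SAE is mounted under its hook_point_out as "sae", so its internal hook points become candidates for attribution. The hook name hook_feature_acts and the surrounding names are assumptions.

# Hypothetical composition sketch (not part of this commit). `model`, `saes`, `tokens`
# and `target` are assumed to exist; "sae.hook_feature_acts" is an assumed internal
# hook name of SparseAutoEncoder.
from lm_saes.circuit.graph import Node

with apply_sae(model, saes):
    # Each SAE is mounted at its hook_point_out as ".sae", so its feature activations
    # are reachable through the parent model's hook namespace.
    candidates = [Node(f"{sae.cfg.hook_point_out}.sae.hook_feature_acts") for sae in saes]
    circuit = DirectAttributor(model).attribute(tokens, target=target, candidates=candidates)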