Merge pull request #23 from OpenMOSS/circuit

Add attribution support

dest1n1s authored Jun 20, 2024
2 parents 997e308 + eb95797 commit 74ef9dd
Showing 9 changed files with 530 additions and 9 deletions.
50 changes: 50 additions & 0 deletions TransformerLens/tests/integration/test_mount_hooked_modules.py
@@ -0,0 +1,50 @@
from transformer_lens.hook_points import HookedRootModule, HookPoint
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.subblock1 = nn.Linear(10, 10)
        self.subblock2 = nn.Linear(10, 10)
        self.activation = nn.ReLU()
        self.hook_pre = HookPoint()
        self.hook_mid = HookPoint()

    def forward(self, x):
        return self.subblock2(self.hook_mid(self.activation(self.subblock1(self.hook_pre(x)))))

class TestModule(HookedRootModule):
    __test__ = False
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.blocks = nn.ModuleList([Block() for _ in range(3)])
        self.embed = nn.Linear(1, 10)
        self.unembed = nn.Linear(10, 1)
        self.setup()

    def forward(self, x):
        x = self.embed(x)
        for block in self.blocks:
            x = block(x)
        return self.unembed(x)

class TestMountModule(HookedRootModule):
    __test__ = False
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hook_mid = HookPoint()
        self.setup()

    def forward(self, x):
        return self.hook_mid(x) * 2

def test_apply_hooked_modules():
    model = TestModule()
    model_to_mount = TestMountModule()
    with model.mount_hooked_modules([("blocks.0.hook_mid", "m", model_to_mount)]):
        assert model.blocks[0].hook_mid.m == model_to_mount
        assert model_to_mount.hook_mid.name == "blocks.0.hook_mid.m.hook_mid"
        assert "blocks.0.hook_mid.m.hook_mid" in model.hook_dict
    assert not hasattr(model.blocks[0].hook_mid, "m")
    assert model_to_mount.hook_mid.name == "hook_mid"
58 changes: 49 additions & 9 deletions TransformerLens/transformer_lens/hook_points.py
@@ -116,7 +116,7 @@ def full_hook(
             _internal_hooks = self._forward_hooks
             visible_hooks = self.fwd_hooks
         elif dir == "bwd":
-            pt_handle = self.register_backward_hook(full_hook)
+            pt_handle = self.register_full_backward_hook(full_hook)
             _internal_hooks = self._backward_hooks
             visible_hooks = self.bwd_hooks
         else:
@@ -696,7 +696,7 @@ def get_ref_caching_hooks(
         names_filter: NamesFilter = None,
         retain_grad: bool = False,
         cache: Optional[dict] = None,
-    ) -> Tuple[dict, list, list]:
+    ) -> Tuple[dict, list]:
         """Creates hooks to keep references to activations. Note: It does not add the hooks to the model.

         Args:
@@ -788,13 +788,13 @@ def offload_params_after(self, last_hook: str, *model_args, **model_kwargs):
             **model_kwargs: Keyword arguments for the model.
         """
         pass_module_list: List[nn.Module] = []
-        fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
         hook_handles: List[hooks.RemovableHandle] = []

         def pass_hook(module: nn.Module, module_input: Any, module_output: Any):
             pass_module_list.append(module)

         def convert_hook(tensor: torch.Tensor, hook: HookPoint):
+            assert hook.name is not None  # Make mypy happy
             pass_param_set = set()
             hook_ancestors = [module for module in self.modules() if module == self or hook.name.startswith(module.name)]
             for module in pass_module_list + hook_ancestors:
@@ -813,12 +813,11 @@ def convert_hook(tensor: torch.Tensor, hook: HookPoint):
         for _, module in self.named_modules():
             hook_handles.append(module.register_forward_hook(pass_hook))

-        with fake_mode:
-            with self.hooks(fwd_hooks=[(last_hook, convert_hook)]):
-                try:
-                    self(*model_args, **model_kwargs)
-                except StopIteration:
-                    pass
+        with self.hooks(fwd_hooks=[(last_hook, convert_hook)]):
+            try:
+                self(*model_args, **model_kwargs)
+            except StopIteration:
+                pass

         for handle in hook_handles:
             handle.remove()
@@ -895,6 +894,47 @@ def stop_hook(tensor: torch.Tensor, hook: HookPoint):
             model_out = e.tensor

         return model_out, cache_dict

+    @contextmanager
+    def mount_hooked_modules(
+        self,
+        hooked_modules: List[Tuple[str, str, "HookedRootModule"]],
+    ):
+        """
+        A context manager for adding child hooked modules at specified hook points.
+
+        Args:
+            hooked_modules: List[Tuple[name, module_name, module]], where name is the name of a
+                hook point, module_name is the name of the module to add (which will be used to
+                mount the module inside the hook point), and module is the module instance to add.
+                A filter function as name is not allowed, since unexpected behavior may occur when
+                the same module is mounted at multiple hook points.
+        """
+
+        for name, module_name, module in hooked_modules:
+            hook_point = self.mod_dict[name]
+            assert isinstance(
+                hook_point, HookPoint
+            )
+            hook_point.add_module(module_name, module)
+
+        self.setup()
+
+        try:
+            yield self
+        finally:
+            for name, module_name, module in hooked_modules:
+                if isinstance(name, str):
+                    hook_point = self.mod_dict[name]
+                    delattr(hook_point, module_name)
+                else:
+                    for hook_point_name, hp in self.hook_dict.items():
+                        if name(hook_point_name):
+                            delattr(hp, module_name)
+                module.setup()
+            self.setup()
+


 # %%
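For orientation, here is a minimal usage sketch of the new mount_hooked_modules context manager, reusing the TestModule and TestMountModule toy classes from the integration test above. Mounting only registers the child's hook points under the parent's namespace; the call_mounted forward hook below is purely illustrative of how the child's forward actually gets invoked (apply_sae further down in this commit plays that role for real SAEs).

import torch

model = TestModule()
mounted = TestMountModule()

with model.mount_hooked_modules([("blocks.0.hook_mid", "m", mounted)]):
    # While mounted, the child's hook points are registered under the parent's namespace.
    assert "blocks.0.hook_mid.m.hook_mid" in model.hook_dict

    # Mounting alone never calls the child's forward; something has to invoke it,
    # e.g. a forward hook at the mount point.
    def call_mounted(tensor, hook):
        return mounted(tensor)

    cache = {}
    def save_hook(tensor, hook):
        cache[hook.name] = tensor

    with model.hooks(fwd_hooks=[
        ("blocks.0.hook_mid", call_mounted),
        ("blocks.0.hook_mid.m.hook_mid", save_hook),
    ]):
        model(torch.randn(4, 1))

    assert cache["blocks.0.hook_mid.m.hook_mid"].shape == (4, 10)

# On exit the extra hook point is unregistered again.
assert "blocks.0.hook_mid.m.hook_mid" not in model.hook_dict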
Empty file added src/lm_saes/circuit/__init__.py
Empty file.
188 changes: 188 additions & 0 deletions src/lm_saes/circuit/attributors.py
@@ -0,0 +1,188 @@
from abc import ABC
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Tuple, Union
import networkx as nx

import torch
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint, HookedRootModule

from lm_saes.utils.hooks import compose_hooks, detach_hook, retain_grad_hook
from lm_saes.circuit.graph import Node

class Cache:
    def __init__(self, output, cache: dict[str, torch.Tensor]):
        self.cache = cache
        self.output = output

    def tensor(self, node: Node) -> torch.Tensor:
        return node.reduce(self[node.hook_point])

    def grad(self, node: Node) -> torch.Tensor | None:
        grad = self[node.hook_point].grad
        return node.reduce(grad) if grad is not None else None

    def __getitem__(self, key: Node | str | None) -> torch.Tensor:
        if isinstance(key, Node):
            return self.tensor(key)
        return self.cache[key] if key is not None else self.output

class Attributor(ABC):
    def __init__(
        self,
        model: HookedRootModule
    ):
        self.model = model

    def attribute(
        self,
        input: Any,
        target: Node,
        candidates: list[Node],
        **kwargs
    ) -> nx.MultiDiGraph:
        """
        Attribute the target hook point of the model to given candidates in the model, w.r.t. the given input.

        Args:
            input (Any): The input to the model.
            target: The target node to attribute.
            candidates: The intermediate nodes to attribute to.
            **kwargs: Additional keyword arguments.

        Returns:
            nx.MultiDiGraph: The attributed graph, i.e. the circuit. Each node and edge should have an attribute
                "attribution", showing its "importance" w.r.t. the target.
        """
        raise NotImplementedError

    def cache_nodes(
        self,
        input: Any,
        nodes: list[Node],
    ):
        """
        Cache the activations of the given nodes in the model forward pass.

        Args:
            input (Any): The input to the model.
            nodes (list[Node]): The nodes to cache.
        """
        output, cache = self.model.run_with_ref_cache(input, names_filter=[node.hook_point for node in nodes])
        return Cache(output, cache)


class DirectAttributor(Attributor):
    def attribute(
        self,
        input: Any,
        target: Node,
        candidates: list[Node],
        **kwargs
    ) -> nx.MultiDiGraph:
        """
        Attribute the target node of the model to given candidates in the model, w.r.t. the given input.

        Args:
            input (Any): The input to the model.
            target: The target node to attribute.
            candidates: The intermediate nodes to attribute to.
            **kwargs: Additional keyword arguments.

        Returns:
            nx.MultiDiGraph: The attributed graph, i.e. the circuit. Each node and edge should have an attribute
                "attribution", showing its "importance" w.r.t. the target.
        """

        threshold: float = kwargs.get("threshold", 0.1)

        fwd_hooks = [(candidate.hook_point, detach_hook) for candidate in candidates if candidate.hook_point != target.hook_point]
        with self.model.hooks(fwd_hooks):
            cache = self.cache_nodes(input, candidates + [target])
            cache[target].backward()

        # Construct the circuit
        circuit = nx.MultiDiGraph()
        circuit.add_node(target, attribution=cache[target].item(), activation=cache[target].item())
        for candidate in candidates:
            if candidate.hook_point == target.hook_point:
                continue
            grad = cache.grad(candidate)
            if grad is None:
                continue
            attributions = grad * cache[candidate]
            if len(attributions.shape) == 0:
                if attributions > threshold:
                    circuit.add_node(candidate, attribution=attributions.item(), activation=cache[candidate].item())
                    circuit.add_edge(candidate, target, attribution=attributions.item(), direct_attribution=attributions.item())
            else:
                for index in (attributions > threshold).nonzero():
                    index = tuple(index.tolist())
                    circuit.add_node(candidate.append_reduction(*index), attribution=attributions[index].item(), activation=cache[candidate][index].item())
                    circuit.add_edge(candidate.append_reduction(*index), target, attribution=attributions[index].item(), direct_attribution=attributions[index].item())
        return circuit


class HierachicalAttributor(Attributor):
    def attribute(
        self,
        input: Any,
        target: Node,
        candidates: list[Node],
        **kwargs
    ) -> nx.MultiDiGraph:
        """
        Attribute the target node of the model to given candidates in the model, w.r.t. the given input.

        Args:
            input (Any): The input to the model.
            target: The target node to attribute.
            candidates: The intermediate nodes to attribute to.
            **kwargs: Additional keyword arguments.

        Returns:
            nx.MultiDiGraph: The attributed graph, i.e. the circuit. Each node and edge should have an attribute
                "attribution", showing its "importance" w.r.t. the target.
        """

        threshold: float = kwargs.get("threshold", 0.1)

        def generate_attribution_score_filter_hook():
            v = None
            def fwd_hook(tensor: torch.Tensor, hook: HookPoint):
                nonlocal v
                v = tensor
                return tensor
            def attribution_score_filter_hook(grad: torch.Tensor, hook: HookPoint):
                assert v is not None, "fwd_hook must be called before attribution_score_filter_hook."
                return (torch.where(v * grad > threshold, grad, torch.zeros_like(grad)),)
            return fwd_hook, attribution_score_filter_hook
        attribution_score_filter_hooks = {candidate: generate_attribution_score_filter_hook() for candidate in candidates}
        fwd_hooks = [(candidate.hook_point, compose_hooks(attribution_score_filter_hooks[candidate][0], retain_grad_hook)) for candidate in candidates]
        with self.model.hooks(
            fwd_hooks=fwd_hooks,
            bwd_hooks=[(candidate.hook_point, attribution_score_filter_hooks[candidate][1]) for candidate in candidates]
        ):
            cache = self.cache_nodes(input, candidates + [target])
            cache[target].backward()

        # Construct the circuit
        circuit = nx.MultiDiGraph()
        circuit.add_node(target, attribution=cache[target].item(), activation=cache[target].item())
        for candidate in candidates:
            grad = cache.grad(candidate)
            if grad is None:
                continue
            attributions = grad * cache[candidate]
            if len(attributions.shape) == 0:
                if attributions > threshold:
                    circuit.add_node(candidate, attribution=attributions.item(), activation=cache[candidate].item())
            else:
                for index in (attributions > threshold).nonzero():
                    index = tuple(index.tolist())
                    circuit.add_node(candidate.append_reduction(*index), attribution=attributions[index].item(), activation=cache[candidate][index].item())
                    circuit.add_edge(candidate.append_reduction(*index), target, attribution=attributions[index].item())

        return circuit
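A rough sketch of how these attributors are driven. Node comes from lm_saes.circuit.graph, which is not part of this diff, so the constructor and reduction argument below are assumptions; what is taken from the code above is the attribute(...) calling convention, the threshold kwarg, and the fact that nodes and edges of the returned graph carry an "attribution" score computed as grad * activation.

# Illustrative sketch; Node(...) arguments are assumed (lm_saes.circuit.graph is not in this diff).
from transformer_lens import HookedTransformer

from lm_saes.circuit.attributors import HierachicalAttributor
from lm_saes.circuit.graph import Node

model = HookedTransformer.from_pretrained("gpt2")
attributor = HierachicalAttributor(model)

# The target must reduce to a scalar, since cache[target].backward() is called on it.
target = Node("blocks.8.hook_resid_post", reduction="sum")         # assumed signature
candidates = [Node(f"blocks.{l}.hook_mlp_out") for l in range(8)]  # assumed signature

circuit = attributor.attribute("The circuit to attribute", target, candidates, threshold=0.1)

# Every surviving node and edge carries an "attribution" score above the threshold.
top_nodes = sorted(circuit.nodes(data=True), key=lambda kv: -kv[1]["attribution"])[:10]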

41 changes: 41 additions & 0 deletions src/lm_saes/circuit/context.py
@@ -0,0 +1,41 @@
from contextlib import contextmanager
from typing import Callable, Tuple, Union
import torch
from transformer_lens.hook_points import HookPoint, HookedRootModule

from lm_saes.sae import SparseAutoEncoder

@contextmanager
def apply_sae(
    model: HookedRootModule,
    saes: list[SparseAutoEncoder]
):
    """
    Apply the sparse autoencoders to the model.
    """
    fwd_hooks: list[Tuple[Union[str, Callable], Callable]] = []
    def get_fwd_hooks(sae: SparseAutoEncoder) -> list[Tuple[Union[str, Callable], Callable]]:
        if sae.cfg.hook_point_in == sae.cfg.hook_point_out:
            def hook(tensor: torch.Tensor, hook: HookPoint):
                reconstructed = sae.forward(tensor)
                return reconstructed + (tensor - reconstructed).detach()
            return [(sae.cfg.hook_point_in, hook)]
        else:
            x = None
            def hook_in(tensor: torch.Tensor, hook: HookPoint):
                nonlocal x
                x = tensor
                return tensor
            def hook_out(tensor: torch.Tensor, hook: HookPoint):
                nonlocal x
                assert x is not None, "hook_in must be called before hook_out."
                reconstructed = sae.forward(x, label=tensor)
                x = None
                return reconstructed + (tensor - reconstructed).detach()
            return [(sae.cfg.hook_point_in, hook_in), (sae.cfg.hook_point_out, hook_out)]
    for sae in saes:
        hooks = get_fwd_hooks(sae)
        fwd_hooks.extend(hooks)
    with model.mount_hooked_modules([(sae.cfg.hook_point_out, "sae", sae) for sae in saes]):
        with model.hooks(fwd_hooks):
            yield model
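Note that the hooks return reconstructed + (tensor - reconstructed).detach(), so the forward pass stays numerically identical to the original activation while gradients are routed through the SAE reconstruction; together with mount_hooked_modules, this exposes SAE-internal hook points to the attributors without perturbing the model. A rough end-to-end sketch follows; here sae is an already-trained SparseAutoEncoder (its construction, its internal hook-point name assumed to be hook_feature_acts, and the Node arguments are outside this diff and are assumptions).

# End-to-end sketch; `sae` is an already-trained lm_saes SparseAutoEncoder (not shown here),
# and "hook_feature_acts" / the Node arguments are assumptions about APIs outside this diff.
from transformer_lens import HookedTransformer

from lm_saes.circuit.attributors import DirectAttributor
from lm_saes.circuit.context import apply_sae
from lm_saes.circuit.graph import Node

model = HookedTransformer.from_pretrained("gpt2")

with apply_sae(model, [sae]):
    # The SAE is mounted as "sae" under its output hook point, so its feature activations
    # are addressable as "<hook_point_out>.sae.<hook name>" while the context is open.
    feature_hook = f"{sae.cfg.hook_point_out}.sae.hook_feature_acts"  # assumed hook name
    attributor = DirectAttributor(model)
    target = Node(feature_hook, reduction="max")                      # assumed signature
    candidates = [Node(f"blocks.{l}.hook_resid_post") for l in range(6)]
    circuit = attributor.attribute("Attribution test prompt", target, candidates, threshold=0.1)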