Add attribution support #23

Merged: 6 commits, Jun 20, 2024
50 changes: 50 additions & 0 deletions TransformerLens/tests/integration/test_mount_hooked_modules.py
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn

from transformer_lens.hook_points import HookedRootModule, HookPoint


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.subblock1 = nn.Linear(10, 10)
        self.subblock2 = nn.Linear(10, 10)
        self.activation = nn.ReLU()
        self.hook_pre = HookPoint()
        self.hook_mid = HookPoint()

    def forward(self, x):
        return self.subblock2(self.hook_mid(self.activation(self.subblock1(self.hook_pre(x)))))


class TestModule(HookedRootModule):
    __test__ = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.blocks = nn.ModuleList([Block() for _ in range(3)])
        self.embed = nn.Linear(1, 10)
        self.unembed = nn.Linear(10, 1)
        self.setup()

    def forward(self, x):
        x = self.embed(x)
        for block in self.blocks:
            x = block(x)
        return self.unembed(x)


class TestMountModule(HookedRootModule):
    __test__ = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hook_mid = HookPoint()
        self.setup()

    def forward(self, x):
        return self.hook_mid(x) * 2


def test_apply_hooked_modules():
    model = TestModule()
    model_to_mount = TestMountModule()
    with model.mount_hooked_modules([("blocks.0.hook_mid", "m", model_to_mount)]):
        # While mounted, the child is reachable from the hook point, and its own
        # hook points are renamed relative to the mount location.
        assert model.blocks[0].hook_mid.m == model_to_mount
        assert model_to_mount.hook_mid.name == "blocks.0.hook_mid.m.hook_mid"
        assert "blocks.0.hook_mid.m.hook_mid" in model.hook_dict
    # After the context exits, the mount is removed and hook names are restored.
    assert not hasattr(model.blocks[0].hook_mid, "m")
    assert model_to_mount.hook_mid.name == "hook_mid"
58 changes: 49 additions & 9 deletions TransformerLens/transformer_lens/hook_points.py
@@ -116,7 +116,7 @@ def full_hook(
            _internal_hooks = self._forward_hooks
            visible_hooks = self.fwd_hooks
        elif dir == "bwd":
-            pt_handle = self.register_backward_hook(full_hook)
+            pt_handle = self.register_full_backward_hook(full_hook)
            _internal_hooks = self._backward_hooks
            visible_hooks = self.bwd_hooks
        else:
@@ -696,7 +696,7 @@ def get_ref_caching_hooks(
        names_filter: NamesFilter = None,
        retain_grad: bool = False,
        cache: Optional[dict] = None,
-    ) -> Tuple[dict, list, list]:
+    ) -> Tuple[dict, list]:
        """Creates hooks to keep references to activations. Note: It does not add the hooks to the model.

        Args:
@@ -788,13 +788,13 @@ def offload_params_after(self, last_hook: str, *model_args, **model_kwargs):
            **model_kwargs: Keyword arguments for the model.
        """
        pass_module_list: List[nn.Module] = []
-        fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
        hook_handles: List[hooks.RemovableHandle] = []

        def pass_hook(module: nn.Module, module_input: Any, module_output: Any):
            pass_module_list.append(module)

        def convert_hook(tensor: torch.Tensor, hook: HookPoint):
+            assert hook.name is not None  # Make mypy happy
            pass_param_set = set()
            hook_ancestors = [module for module in self.modules() if module == self or hook.name.startswith(module.name)]
            for module in pass_module_list + hook_ancestors:
@@ -813,12 +813,11 @@ def convert_hook(tensor: torch.Tensor, hook: HookPoint):
        for _, module in self.named_modules():
            hook_handles.append(module.register_forward_hook(pass_hook))

-        with fake_mode:
-            with self.hooks(fwd_hooks=[(last_hook, convert_hook)]):
-                try:
-                    self(*model_args, **model_kwargs)
-                except StopIteration:
-                    pass
+        with self.hooks(fwd_hooks=[(last_hook, convert_hook)]):
+            try:
+                self(*model_args, **model_kwargs)
+            except StopIteration:
+                pass

        for handle in hook_handles:
            handle.remove()
@@ -895,6 +894,47 @@ def stop_hook(tensor: torch.Tensor, hook: HookPoint):
                model_out = e.tensor

        return model_out, cache_dict

+    @contextmanager
+    def mount_hooked_modules(
+        self,
+        hooked_modules: List[Tuple[str, str, "HookedRootModule"]],
+    ):
+        """A context manager for mounting child hooked modules at specified hook points.
+
+        Args:
+            hooked_modules: List[Tuple[name, module_name, module]], where name is the name
+                of a hook point, module_name is the attribute name under which the module
+                is mounted inside the hook point, and module is the module instance to add.
+                A filter function is not allowed as name, since unexpected behavior may
+                occur when the same module is mounted at multiple hook points.
+        """
+        for name, module_name, module in hooked_modules:
+            hook_point = self.mod_dict[name]
+            assert isinstance(hook_point, HookPoint)
+            hook_point.add_module(module_name, module)
+
+        # Rebuild the hook dict so the mounted module's hook points are registered
+        # under names relative to their mount location.
+        self.setup()
+
+        try:
+            yield self
+        finally:
+            for name, module_name, module in hooked_modules:
+                hook_point = self.mod_dict[name]
+                delattr(hook_point, module_name)
+                # Restore the mounted module's own (unprefixed) hook point names.
+                module.setup()
+            self.setup()


# %%
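For context, a minimal usage sketch of mount_hooked_modules, reusing the TestModule and TestMountModule classes from the integration test above (not part of the diff). Mounting only registers the child's hook points under parent-relative names; the child must still be invoked, e.g. from a forward hook, for those hooks to fire:

import torch

model = TestModule()
child = TestMountModule()

def call_child(tensor, hook):
    # Run the mounted module on the activation at this hook point.
    return child(tensor)

with model.mount_hooked_modules([("blocks.0.hook_mid", "m", child)]):
    assert "blocks.0.hook_mid.m.hook_mid" in model.hook_dict
    with model.hooks(fwd_hooks=[("blocks.0.hook_mid", call_child)]):
        _, cache = model.run_with_cache(
            torch.randn(2, 1), names_filter="blocks.0.hook_mid.m.hook_mid"
        )
    # The child's internal activation was captured under its mounted name.
    assert "blocks.0.hook_mid.m.hook_mid" in cache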
Empty file added src/lm_saes/circuit/__init__.py
188 changes: 188 additions & 0 deletions src/lm_saes/circuit/attributors.py
@@ -0,0 +1,188 @@
from abc import ABC
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Tuple, Union

import networkx as nx
import torch
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint, HookedRootModule

from lm_saes.utils.hooks import compose_hooks, detach_hook, retain_grad_hook
from lm_saes.circuit.graph import Node


class Cache:
    """Wraps a forward-pass cache and resolves Node objects to their (reduced) tensors."""

    def __init__(self, output, cache: dict[str, torch.Tensor]):
        self.cache = cache
        self.output = output

    def tensor(self, node: Node) -> torch.Tensor:
        return node.reduce(self[node.hook_point])

    def grad(self, node: Node) -> torch.Tensor | None:
        grad = self[node.hook_point].grad
        return node.reduce(grad) if grad is not None else None

    def __getitem__(self, key: Node | str | None) -> torch.Tensor:
        if isinstance(key, Node):
            return self.tensor(key)
        return self.cache[key] if key is not None else self.output


class Attributor(ABC):
    def __init__(self, model: HookedRootModule):
        self.model = model

    def attribute(
        self,
        input: Any,
        target: Node,
        candidates: list[Node],
        **kwargs,
    ) -> nx.MultiDiGraph:
        """Attribute the target node of the model to the given candidate nodes, w.r.t. the given input.

        Args:
            input (Any): The input to the model.
            target: The target node to attribute.
            candidates: The intermediate nodes to attribute to.
            **kwargs: Additional keyword arguments.

        Returns:
            nx.MultiDiGraph: The attributed graph, i.e. the circuit. Each node and edge has an
                attribute "attribution", showing its importance w.r.t. the target.
        """
        raise NotImplementedError

    def cache_nodes(
        self,
        input: Any,
        nodes: list[Node],
    ):
        """Cache the activations of the given nodes during the model forward pass.

        Args:
            input (Any): The input to the model.
            nodes (list[Node]): The nodes to cache.
        """
        output, cache = self.model.run_with_ref_cache(input, names_filter=[node.hook_point for node in nodes])
        return Cache(output, cache)


class DirectAttributor(Attributor):
    def attribute(
        self,
        input: Any,
        target: Node,
        candidates: list[Node],
        **kwargs,
    ) -> nx.MultiDiGraph:
        """Attribute the target node to the candidates, measuring only direct effects.

        Candidate activations are detached in the forward pass, so gradients do not flow
        through them and each candidate's score reflects only its direct contribution to
        the target.
        """

        threshold: float = kwargs.get("threshold", 0.1)

        fwd_hooks = [
            (candidate.hook_point, detach_hook)
            for candidate in candidates
            if candidate.hook_point != target.hook_point
        ]
        with self.model.hooks(fwd_hooks):
            cache = self.cache_nodes(input, candidates + [target])
            cache[target].backward()

        # Construct the circuit, keeping only attributions above the threshold.
        # Attribution score = gradient of the target w.r.t. the candidate * candidate activation.
        circuit = nx.MultiDiGraph()
        circuit.add_node(target, attribution=cache[target].item(), activation=cache[target].item())
        for candidate in candidates:
            if candidate.hook_point == target.hook_point:
                continue
            grad = cache.grad(candidate)
            if grad is None:
                continue
            attributions = grad * cache[candidate]
            if len(attributions.shape) == 0:
                if attributions > threshold:
                    circuit.add_node(candidate, attribution=attributions.item(), activation=cache[candidate].item())
                    circuit.add_edge(candidate, target, attribution=attributions.item(), direct_attribution=attributions.item())
            else:
                for index in (attributions > threshold).nonzero():
                    index = tuple(index.tolist())
                    circuit.add_node(candidate.append_reduction(*index), attribution=attributions[index].item(), activation=cache[candidate][index].item())
                    circuit.add_edge(candidate.append_reduction(*index), target, attribution=attributions[index].item(), direct_attribution=attributions[index].item())
        return circuit


class HierachicalAttributor(Attributor):
    def attribute(
        self,
        input: Any,
        target: Node,
        candidates: list[Node],
        **kwargs,
    ) -> nx.MultiDiGraph:
        """Attribute the target node to the candidates, including indirect effects.

        Gradients flow through candidates, but gradient entries whose attribution score
        falls below the threshold are zeroed in the backward pass, pruning low-attribution
        paths from the circuit.
        """

        threshold: float = kwargs.get("threshold", 0.1)

        def generate_attribution_score_filter_hook():
            # The forward hook records the activation; the paired backward hook zeroes
            # gradient entries whose attribution score (activation * grad) is below threshold.
            v = None

            def fwd_hook(tensor: torch.Tensor, hook: HookPoint):
                nonlocal v
                v = tensor
                return tensor

            def attribution_score_filter_hook(grad: torch.Tensor, hook: HookPoint):
                assert v is not None, "fwd_hook must be called before attribution_score_filter_hook."
                return (torch.where(v * grad > threshold, grad, torch.zeros_like(grad)),)

            return fwd_hook, attribution_score_filter_hook

        attribution_score_filter_hooks = {candidate: generate_attribution_score_filter_hook() for candidate in candidates}
        fwd_hooks = [
            (candidate.hook_point, compose_hooks(attribution_score_filter_hooks[candidate][0], retain_grad_hook))
            for candidate in candidates
        ]
        with self.model.hooks(
            fwd_hooks=fwd_hooks,
            bwd_hooks=[(candidate.hook_point, attribution_score_filter_hooks[candidate][1]) for candidate in candidates],
        ):
            cache = self.cache_nodes(input, candidates + [target])
            cache[target].backward()

        # Construct the circuit, keeping only attributions above the threshold.
        circuit = nx.MultiDiGraph()
        circuit.add_node(target, attribution=cache[target].item(), activation=cache[target].item())
        for candidate in candidates:
            grad = cache.grad(candidate)
            if grad is None:
                continue
            attributions = grad * cache[candidate]
            if len(attributions.shape) == 0:
                if attributions > threshold:
                    circuit.add_node(candidate, attribution=attributions.item(), activation=cache[candidate].item())
            else:
                for index in (attributions > threshold).nonzero():
                    index = tuple(index.tolist())
                    circuit.add_node(candidate.append_reduction(*index), attribution=attributions[index].item(), activation=cache[candidate][index].item())
                    circuit.add_edge(candidate.append_reduction(*index), target, attribution=attributions[index].item())

        return circuit
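For orientation, a hedged sketch of how an attributor might be driven. The Node constructor signature, hook point names, and model construction below are assumptions, not part of this file; only attribute(), the threshold kwarg, and the "attribution" graph attribute come from the code above. Both attributors score a candidate as gradient times activation; the hierarchical variant additionally filters sub-threshold gradients during backprop so indirect paths are pruned.

# Hypothetical driver, assuming `model` is a HookedRootModule and that Node can be
# built from a hook point name (assumption about lm_saes.circuit.graph.Node).
attributor = DirectAttributor(model)
target = Node("blocks.11.hook_resid_post")    # assumed name; must reduce to a scalar
candidates = [Node("blocks.3.hook_mlp_out")]  # assumed name

circuit = attributor.attribute(input, target=target, candidates=candidates, threshold=0.1)
for node, data in circuit.nodes(data=True):
    print(node, data["attribution"])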

41 changes: 41 additions & 0 deletions src/lm_saes/circuit/context.py
@@ -0,0 +1,41 @@
from contextlib import contextmanager
from typing import Callable, Tuple, Union

import torch
from transformer_lens.hook_points import HookPoint, HookedRootModule

from lm_saes.sae import SparseAutoEncoder


@contextmanager
def apply_sae(
    model: HookedRootModule,
    saes: list[SparseAutoEncoder],
):
    """Apply the sparse autoencoders to the model.

    Forward-pass values are kept numerically unchanged, while gradients are routed
    through the SAE reconstructions.
    """
    fwd_hooks: list[Tuple[Union[str, Callable], Callable]] = []

    def get_fwd_hooks(sae: SparseAutoEncoder) -> list[Tuple[Union[str, Callable], Callable]]:
        if sae.cfg.hook_point_in == sae.cfg.hook_point_out:
            def hook(tensor: torch.Tensor, hook: HookPoint):
                reconstructed = sae.forward(tensor)
                # Numerically equal to `tensor`, but gradients flow through `reconstructed`.
                return reconstructed + (tensor - reconstructed).detach()
            return [(sae.cfg.hook_point_in, hook)]
        else:
            x = None

            def hook_in(tensor: torch.Tensor, hook: HookPoint):
                nonlocal x
                x = tensor
                return tensor

            def hook_out(tensor: torch.Tensor, hook: HookPoint):
                nonlocal x
                assert x is not None, "hook_in must be called before hook_out."
                reconstructed = sae.forward(x, label=tensor)
                x = None
                return reconstructed + (tensor - reconstructed).detach()

            return [(sae.cfg.hook_point_in, hook_in), (sae.cfg.hook_point_out, hook_out)]

    for sae in saes:
        hooks = get_fwd_hooks(sae)
        fwd_hooks.extend(hooks)
    # Mount each SAE at its output hook point so the SAE's own hook points become
    # addressable through the model's hook dict.
    with model.mount_hooked_modules([(sae.cfg.hook_point_out, "sae", sae) for sae in saes]):
        with model.hooks(fwd_hooks):
            yield model
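Taken together, apply_sae and the attributors support circuit analysis over SAE features: the straight-through trick leaves the forward pass intact while gradients pass through the SAE, and mounting exposes the SAE's hook points for use as attribution nodes. A hedged sketch under those assumptions (the model, SAE, tokens, and node definitions below are hypothetical, not part of this PR):

# Hypothetical end-to-end use: `model`, `sae`, `tokens`, `target_node`, and
# `candidate_nodes` are assumed to be constructed elsewhere. Mounted SAE hook
# points follow the "<hook_point_out>.sae.<hook>" pattern produced above.
with apply_sae(model, [sae]) as hooked_model:
    attributor = HierachicalAttributor(hooked_model)
    circuit = attributor.attribute(
        tokens,
        target=target_node,
        candidates=candidate_nodes,
        threshold=0.1,
    )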