Skip to content

Commit

Permalink
Merge branch 'swap' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
JadenFiotto-Kaufman committed Dec 15, 2023
2 parents e455389 + 7229cf0 commit 6ea723c
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 77 deletions.
10 changes: 4 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,18 +198,18 @@ model = LanguageModel('gpt2', device_map='cuda')
with model.generate(max_new_tokens=1) as generator:
with generator.invoke('The Eiffel Tower is in the city of') as invoker:

hidden_states_pre = model.transformer.h[-1].output[0].save()
hidden_states_pre = model.transformer.h[-1].mlp.output.save()

noise = (0.001**0.5)*torch.randn(hidden_states_pre.shape)

model.transformer.h[-1].output[0] = hidden_states_pre + noise
model.transformer.h[-1].mlp.output = hidden_states_pre + noise

hidden_states_post = model.transformer.h[-1].output[0].save()
hidden_states_post = model.transformer.h[-1].mlp.output.save()

print(hidden_states_pre.value)
print(hidden_states_post.value)
```
In this example, we create a tensor of noise to add to the hidden states. We then add it and use the assignment `=` operator to update the tensors of `.output[0]` with these new noised values.
In this example, we create a tensor of noise to add to the hidden states. We then add it and use the assignment `=` operator to update the value of `.output` with these new noised activations.

We can see the change in the results:

Expand All @@ -232,8 +232,6 @@ tensor([[[ 0.0674, -0.1741, -0.1771, ..., -0.9811, 0.1972, -1.0645],
device='cuda:0')
```

Note: Only assignment updates of tensors work with this functionality.

---
###### Multiple Token Generation

Expand Down
8 changes: 4 additions & 4 deletions docs/source/notebooks/features/setting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"source": [
"We often not only want to see whats happening during computation, but intervene and edit the flow of information.\n",
"\n",
"In this example, we create a tensor of noise to add to the hidden states. We then add it, use the assigment `=` operator to update the tensors of `.output[0]` with these new noised values."
"In this example, we create a tensor of noise to add to the hidden states. We then add it, use the assigment `=` operator to update the tensors of `.output[0][:]` with these new noised values."
]
},
{
Expand All @@ -30,13 +30,13 @@
"with model.generate(max_new_tokens=1) as generator:\n",
" with generator.invoke('The Eiffel Tower is in the city of') as invoker:\n",
"\n",
" hidden_states_pre = model.transformer.h[-1].output[0].save()\n",
" hidden_states_pre = model.transformer.h[-1].output[0][:].save()\n",
"\n",
" noise = (0.001**0.5)*torch.randn(hidden_states_pre.shape)\n",
"\n",
" model.transformer.h[-1].output[0] = hidden_states_pre + noise\n",
" model.transformer.h[-1].output[0][:] = hidden_states_pre + noise\n",
"\n",
" hidden_states_post = model.transformer.h[-1].output[0].save()"
" hidden_states_post = model.transformer.h[-1].output[0][:].save()"
]
},
{
Expand Down
53 changes: 51 additions & 2 deletions src/nnsight/intervention.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@
"""
from __future__ import annotations

from contextlib import AbstractContextManager
import inspect
from contextlib import AbstractContextManager
from typing import Any, Callable, Collection, List, Tuple, Union

import torch
from torch.utils.hooks import RemovableHandle

from . import util
from .tracing.Graph import Graph
from .tracing.Node import Node
from .tracing.Proxy import Proxy


Expand Down Expand Up @@ -81,7 +82,7 @@ def retain_grad(self):

# We need to set the values of self to values of self to add this into the computation graph so grad flows through it
# This is because in intervene(), we call .narrow on activations which removes it from the grad path
self.node.graph.add(target=Proxy.proxy_update, args=[self.node, self.node])
self[:] = self

@property
def token(self) -> TokenIndexer:
Expand Down Expand Up @@ -136,6 +137,49 @@ def value(self) -> Any:
return self.node.value


def check_swap(graph: Graph, activations: Any, batch_start: int, batch_size: int):
    """Splice a pending swap value into this batch's window of ``activations``.

    When ``graph.swap`` is ``None`` the activations are returned untouched.
    Otherwise the rows ``[batch_start, batch_start + batch_size)`` along
    dim 0 are replaced by the value(s) held by the node(s) in ``graph.swap``,
    the rows before and after that window are kept, and ``graph.swap`` is
    reset to ``None`` so the swap is applied only once.

    Args:
        graph (Graph): Graph whose ``swap`` attribute is inspected and cleared.
        activations (Any): Module activations; a tensor or a nested
            list/tuple/dict of tensors.
        batch_start (int): Index along dim 0 where the current batch begins.
        batch_size (int): Number of rows along dim 0 belonging to the batch.

    Returns:
        Any: ``activations`` itself if no swap is pending, otherwise a new
        structure with the same nesting (lists stay lists, tuples stay
        tuples, dicts stay dicts) holding the concatenated tensors.
    """
    if graph.swap is None:
        return activations

    def concat(values):
        # `values` is [pre, swapped, post]; the first entry dictates the structure.
        first = values[0]

        if isinstance(first, torch.Tensor):
            return torch.cat(values)

        if isinstance(first, (list, tuple)):
            merged = [
                concat([value[value_idx] for value in values])
                for value_idx in range(len(first))
            ]
            # Preserve the container type so tuple-valued module outputs stay tuples.
            return tuple(merged) if isinstance(first, tuple) else merged

        if isinstance(first, dict):
            return {key: concat([value[key] for value in values]) for key in first}

        # Leaves that are neither tensors nor containers cannot be concatenated
        # along a batch dimension; pass the existing leaf through instead of
        # silently replacing it with None.
        return first

    # NOTE(review): util.apply is assumed to map the callable over every
    # torch.Tensor found in a nested structure and leave other leaves as-is —
    # confirm against util.apply.
    pre = util.apply(
        activations, lambda x: x.narrow(0, 0, batch_start), torch.Tensor
    )

    post_batch_start = batch_start + batch_size
    post = util.apply(
        activations,
        lambda x: x.narrow(0, post_batch_start, x.shape[0] - post_batch_start),
        torch.Tensor,
    )

    def get_value(node: Node):
        value = node.value

        # Once read, the node's value is overwritten with True.
        # NOTE(review): presumably this marks the node as consumed so the
        # graph can proceed — confirm against Node.set_value.
        node.set_value(True)

        return value

    value = util.apply(graph.swap, get_value, Node)

    activations = concat([pre, value, post])

    # Clear the pending swap so later hooks do not re-apply it.
    graph.swap = None

    return activations


def intervene(activations: Any, module_path: str, graph: Graph, key: str):
"""Entry to intervention graph. This should be hooked to all modules involved in the intervention graph.
Expand Down Expand Up @@ -181,6 +225,11 @@ def intervene(activations: Any, module_path: str, graph: Graph, key: str):
torch.Tensor,
)
)

# Check if through the previous value injection, there was a 'swp' intervention.
# This would mean we want to replace activations for this batch with some other ones.
activations = check_swap(graph, activations, batch_start, batch_size)

return activations


Expand Down
20 changes: 12 additions & 8 deletions src/nnsight/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,11 @@ def output(self, value: Union[InterventionProxy, Any]) -> None:
"""

self.output.node.graph.add(
target=Proxy.proxy_update,
args=[self.output.node, value],
target="swp", args=[self.output.node, value], value=True
)

self._output = None

@property
def input(self) -> InterventionProxy:
"""
Expand Down Expand Up @@ -166,10 +167,11 @@ def input(self, value: Union[InterventionProxy, Any]) -> None:
"""

self.input.node.graph.add(
target=Proxy.proxy_update,
args=[self.input.node, value],
target="swp", args=[self.input.node, value], value=True
)

self._input = None

@property
def backward_output(self) -> InterventionProxy:
"""
Expand Down Expand Up @@ -206,10 +208,11 @@ def backward_output(self, value: Union[InterventionProxy, Any]) -> None:
"""

self.backward_output.node.graph.add(
target=Proxy.proxy_update,
args=[self.backward_output.node, value],
target="swp", args=[self.backward_output.node, value], value=True
)

self._backward_output = None

@property
def backward_input(self) -> InterventionProxy:
"""
Expand Down Expand Up @@ -246,10 +249,11 @@ def backward_input(self, value: Union[InterventionProxy, Any]) -> None:
"""

self.backward_input.node.graph.add(
target=Proxy.proxy_update,
args=[self.backward_input.node, value],
target="swp", args=[self.backward_input.node, value], value=True
)

self._backward_input = None

@property
def graph(self) -> Graph:
if self._graph is None:
Expand Down
42 changes: 10 additions & 32 deletions src/nnsight/tracing/Graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Graph:
* 'module' : There should only be the single root module as a node in the graph for tracing. Added on __init__ and when compiling, the node's value is set to be whatever module that is being interleaved with this computation graph.
* 'argument' : There can be multiple argument nodes. Their first argument needs to be the argument name which acts as a key in graph.argument_node_names which maps to a list of names for nodes that depend on its value. These nodes' values need to be set outside of the computation graph as entry points to kick off the execution of the graph.
* 'rtn' : Should only be one 'rtn' target named node as this is what is used.
* 'swp' : swp nodes indicate populating the graph's swap attribute. When executed, its value is not set. Logic involving the swap value should set its value after using it.
* 'null' : Null nodes never get executed and therefore their listeners never get destroyed.
Attributes:
Expand All @@ -30,6 +30,7 @@ class Graph:
module_proxy (Proxy): Proxy for given root meta module.
argument_node_names (Dict[str, List[str]]): Map of name of argument to name of nodes that depend on it.
generation_idx (int): Current generation index.
swap (Any): Attribute to store swap values from 'swp' nodes.
"""

@staticmethod
Expand Down Expand Up @@ -104,36 +105,13 @@ def get_argument_value(param: inspect.Parameter, idx: int):
# Run forward with root module proxy and arguments
output: Proxy = forward(graph.module_proxy, *arguments)

# Get proxy_value for return
value = util.apply(output, lambda x: x.node.proxy_value, Proxy)

# Create the 'rtn_0' return proxy
# Create the 'swap' return proxy
return_proxy = graph.add(
graph=graph, value=value, target=Graph.rtn, args=output
)

# This is how we tell the graph not to destroy a proxy after it's listeners are completed.
# Create a 'null' proxy. The return proxy listens to the 'null' proxy with args=[return_proxy.node] but 'null' will never be completed.
graph.add(
graph=graph,
value=None,
target="null",
args=[return_proxy.node],
graph=graph, value=True, target="swp", args=[output.node, output.node]
)

return graph

@staticmethod
def rtn(*args, **kwargs):
    """
    Pass positional arguments through unchanged so a graph forward method can return them.
    Keyword arguments are accepted but not used.
    Returns:
        tuple: The positional arguments exactly as they were received.
    """

    return args

def __init__(
self,
module: torch.nn.Module,
Expand All @@ -151,6 +129,8 @@ def __init__(

self.generation_idx = 0

self.swap: Any = None

def increment(self) -> None:
"""Increments the generation_idx by one. Should be called by a forward hook on the model being used for generation."""
self.generation_idx += 1
Expand Down Expand Up @@ -199,7 +179,7 @@ def add(
Proxy: Proxy for the added node.
Raises:
ValueError: If more than one reserved "rtn" or "module" nodes are added to the graph.
ValueError: If more than one reserved "module" nodes are added to the graph.
"""

# If we're validating and the user did not provide a value, execute the given target with meta proxy values to compute new proxy_value.
Expand All @@ -217,8 +197,6 @@ def add(
if target_name not in self.name_idx:
self.name_idx[target_name] = 0
else:
if target_name == "rtn":
raise ValueError("Can only have one return ('rtn') node.")
if target_name == "module":
raise ValueError("Can only have one module node.")

Expand Down Expand Up @@ -293,9 +271,9 @@ def forward(*args, **kwargs):
if key in self.argument_node_names:
self.nodes[self.argument_node_names[key][0]].set_value(arg)

# 'rtn_0' should have the value we need to return.
return_value = self.nodes["rtn_0"].value
self.nodes["rtn_0"].destroy()
# should have the value we need to return.
return_value = self.swap
self.swap.set_value(True)
return return_value

# Replace forward method with custom graph execution method.
Expand Down
4 changes: 4 additions & 0 deletions src/nnsight/tracing/Node.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ def execute(self) -> None:
# We set a node's target to 'null' if we don't want it to be executed and therefore never done
if self.target == "null":
return
elif self.target == "swp":
self.graph.swap = self.args[1]

return

# Prepare arguments.
args, kwargs = self.prepare_inputs()
Expand Down
25 changes: 3 additions & 22 deletions src/nnsight/tracing/Proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,6 @@ class Proxy:
node (Node): This proxy's node.
"""

@staticmethod
def proxy_update(value1: Any, value2: Any) -> None:
"""Updates Tensor values with other Tensor values.
Args:
value1 (Any): Collection with Tensors to update.
value2 (Any): Collection with Tensors to pull values from.
"""
if isinstance(value1, torch.Tensor):
value1[:] = value2
elif isinstance(value1, list) or isinstance(value1, tuple):
for value_idx in range(len(value1)):
Proxy.proxy_update(value1[value_idx], value2[value_idx])
elif isinstance(value1, dict):
for key in value1:
Proxy.proxy_update(value1[key], value2[key])

@staticmethod
def proxy_call(callable: Callable, *args, **kwargs) -> None:
return callable(*args, **kwargs)
Expand Down Expand Up @@ -76,11 +59,9 @@ def __getitem__(self, key: Union[Proxy, Any]) -> Proxy:
)

def __setitem__(self, key: Union[Proxy, Any], value: Union[Proxy, Any]) -> None:
item_proxy = self[key]

item_proxy.node.graph.add(
target=Proxy.proxy_update,
args=[item_proxy.node, value],
self.node.graph.add(
target=operator.setitem,
args=[self.node, key, value],
)

def __getattr__(self, key: Union[Proxy, Any]) -> Proxy:
Expand Down
22 changes: 19 additions & 3 deletions tests/test_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ def test_save(gpt2: nnsight.LanguageModel):
assert hs_input.value.ndim == 3


def test_set(gpt2: nnsight.LanguageModel):
def test_set1(gpt2: nnsight.LanguageModel, MSG_prompt: str):
with gpt2.generate(max_new_tokens=1) as generator:
with generator.invoke("Hello world") as invoker:
with generator.invoke(MSG_prompt) as invoker:
pre = gpt2.transformer.h[-1].output[0].clone().save()

gpt2.transformer.h[-1].output[0] = 0
gpt2.transformer.h[-1].output[0][:] = 0

post = gpt2.transformer.h[-1].output[0].save()

Expand All @@ -55,6 +55,22 @@ def test_set(gpt2: nnsight.LanguageModel):
assert output != "Madison Square Garden is located in the city of New"


def test_set2(gpt2: nnsight.LanguageModel, MSG_prompt: str):
    """Assigning a zeroed value to ``wte.output`` must zero the embeddings and alter the generation."""
    with gpt2.generate(max_new_tokens=1) as generator:
        with generator.invoke(MSG_prompt) as invoker:
            # Snapshot a copy of the embedding output before intervening.
            embeddings_before = gpt2.transformer.wte.output.clone().save()

            # Whole-value assignment (no slicing) of the module output.
            gpt2.transformer.wte.output = gpt2.transformer.wte.output * 0

            embeddings_after = gpt2.transformer.wte.output.save()

    decoded = gpt2.tokenizer.decode(generator.output[0])

    assert not (embeddings_before.value == 0).all().item()
    assert (embeddings_after.value == 0).all().item()
    assert decoded != "Madison Square Garden is located in the city of New"


def test_adhoc_module(gpt2: nnsight.LanguageModel):
with gpt2.generate() as generator:
with generator.invoke("The Eiffel Tower is in the city of") as invoker:
Expand Down

0 comments on commit 6ea723c

Please sign in to comment.