
Merge pull request #16 from JadenFiotto-Kaufman/dev
Dev
JadenFiotto-Kaufman authored Dec 15, 2023
2 parents ecbd18d + 6ea723c commit 4bf35b2
Showing 17 changed files with 436 additions and 119 deletions.
10 changes: 4 additions & 6 deletions README.md
@@ -198,18 +198,18 @@ model = LanguageModel('gpt2', device_map='cuda')
with model.generate(max_new_tokens=1) as generator:
with generator.invoke('The Eiffel Tower is in the city of') as invoker:

hidden_states_pre = model.transformer.h[-1].output[0].save()
hidden_states_pre = model.transformer.h[-1].mlp.output.save()

noise = (0.001**0.5)*torch.randn(hidden_states_pre.shape)

model.transformer.h[-1].output[0] = hidden_states_pre + noise
model.transformer.h[-1].mlp.output = hidden_states_pre + noise

hidden_states_post = model.transformer.h[-1].output[0].save()
hidden_states_post = model.transformer.h[-1].mlp.output.save()

print(hidden_states_pre.value)
print(hidden_states_post.value)
```
In this example, we create a tensor of noise to add to the hidden states. We then add it and use the assignment operator `=` to update the tensors of `.output[0]` with these new noised values.
In this example, we create a tensor of noise to add to the hidden states. We then add it and use the assignment operator `=` to update the value of `.output` with these new noised activations.

We can see the change in the results:

@@ -232,8 +232,6 @@ tensor([[[ 0.0674, -0.1741, -0.1771, ..., -0.9811, 0.1972, -1.0645],
device='cuda:0')
```

Note: Only assignment updates of tensors work with this functionality.
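The same assignment pattern can also replace an activation outright rather than just perturb it. Below is a minimal sketch, assuming the same gpt2 setup and proxy arithmetic shown above (the zero-ablation itself is illustrative, not part of this example):

```python
# Illustrative sketch: zero out the last MLP's output using the same
# assignment-based API (assumes the gpt2 model defined above).
with model.generate(max_new_tokens=1) as generator:
    with generator.invoke('The Eiffel Tower is in the city of') as invoker:

        mlp_out = model.transformer.h[-1].mlp.output.save()

        # Assigning a new value swaps it in for this invoke's activations.
        model.transformer.h[-1].mlp.output = mlp_out * 0
```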

---
###### Multiple Token Generation

8 changes: 4 additions & 4 deletions docs/source/notebooks/features/setting.ipynb
@@ -13,7 +13,7 @@
"source": [
"We often not only want to see whats happening during computation, but intervene and edit the flow of information.\n",
"\n",
"In this example, we create a tensor of noise to add to the hidden states. We then add it, use the assigment `=` operator to update the tensors of `.output[0]` with these new noised values."
"In this example, we create a tensor of noise to add to the hidden states. We then add it, use the assigment `=` operator to update the tensors of `.output[0][:]` with these new noised values."
]
},
{
@@ -30,13 +30,13 @@
"with model.generate(max_new_tokens=1) as generator:\n",
" with generator.invoke('The Eiffel Tower is in the city of') as invoker:\n",
"\n",
" hidden_states_pre = model.transformer.h[-1].output[0].save()\n",
" hidden_states_pre = model.transformer.h[-1].output[0][:].save()\n",
"\n",
" noise = (0.001**0.5)*torch.randn(hidden_states_pre.shape)\n",
"\n",
" model.transformer.h[-1].output[0] = hidden_states_pre + noise\n",
" model.transformer.h[-1].output[0][:] = hidden_states_pre + noise\n",
"\n",
" hidden_states_post = model.transformer.h[-1].output[0].save()"
" hidden_states_post = model.transformer.h[-1].output[0][:].save()"
]
},
{
86 changes: 75 additions & 11 deletions docs/source/notebooks/tutorials/attribution_patching.ipynb

Large diffs are not rendered by default.

161 changes: 161 additions & 0 deletions docs/source/notebooks/tutorials/sae.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions docs/source/tutorials.rst
@@ -10,3 +10,8 @@ Tutorials
notebooks/tutorials/future_lens.ipynb
notebooks/tutorials/function_vectors.ipynb
notebooks/tutorials/dictionary_learning.ipynb

.. toctree::
:hidden:

notebooks/tutorials/sae.ipynb
7 changes: 3 additions & 4 deletions src/nnsight/contexts/Invoker.py
@@ -64,10 +64,9 @@ def __enter__(self) -> Invoker:
module._backward_output = None
module._backward_input = None

batched_inputs = self.tracer.model._batched_inputs(self.input)

self.tracer.batch_size = len(batched_inputs)
self.tracer.batched_input.extend(batched_inputs)
self.tracer.batch_start += self.tracer.batch_size

self.tracer.batched_input, self.tracer.batch_size = self.tracer.model._batch_inputs(self.input, self.tracer.batched_input)

return self

6 changes: 4 additions & 2 deletions src/nnsight/contexts/Tracer.py
@@ -22,8 +22,9 @@ class Tracer(AbstractContextManager):
args (List[Any]): Positional arguments to be passed to function that executes the model.
kwargs (Dict[str,Any]): Keyword arguments to be passed to function that executes the model.
batch_size (int): Batch size of the most recent input. Used by Module to create input/output proxies.
batch_start (int): Batch start of the most recent input. Used by Module to create input/output proxies.
generation_idx (int): Current generation idx for multi-iteration generation. Used by Module to create input/output proxies.
batched_input (List[Any]): Batched version of all inputs involved in this Tracer.
batched_input (Any): Batched version of all inputs involved in this Tracer.
output (Any): Output of execution after __exit__
"""

@@ -42,8 +43,9 @@ def __init__(
self.graph = Graph(self.model.meta_model, proxy_class=InterventionProxy, validate=validate)

self.batch_size: int = 0
self.batch_start: int = 0
self.generation_idx: int = 0
self.batched_input: List[Any] = []
self.batched_input: Any = None

self.output = None

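For context, a toy illustration of how `batch_start` and `batch_size` are consumed downstream: each invoke only ever sees its own slice of the batched activations (the numbers below are made up; the real slicing lives in `intervention.py`):

```python
# Illustrative only: an invoke registered with batch_start=2 and batch_size=3
# reads rows 2..4 of the batched activations; other invokes' rows are untouched.
import torch

batched_activations = torch.randn(6, 10)   # rows contributed by several invokes
batch_start, batch_size = 2, 3             # this invoke's slice of the batch
this_invoke = batched_activations.narrow(0, batch_start, batch_size)
```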
53 changes: 51 additions & 2 deletions src/nnsight/intervention.py
@@ -8,15 +8,16 @@
"""
from __future__ import annotations

from contextlib import AbstractContextManager
import inspect
from contextlib import AbstractContextManager
from typing import Any, Callable, Collection, List, Tuple, Union

import torch
from torch.utils.hooks import RemovableHandle

from . import util
from .tracing.Graph import Graph
from .tracing.Node import Node
from .tracing.Proxy import Proxy


@@ -81,7 +82,7 @@ def retain_grad(self):

# We need to set the values of self to values of self to add this into the computation graph so grad flows through it
# This is because in intervene(), we call .narrow on activations which removes it from the grad path
self.node.graph.add(target=Proxy.proxy_update, args=[self.node, self.node])
self[:] = self

@property
def token(self) -> TokenIndexer:
@@ -136,6 +137,49 @@ def value(self) -> Any:
return self.node.value


def check_swap(graph: Graph, activations: Any, batch_start: int, batch_size: int):
if graph.swap is not None:

def concat(values):
if isinstance(values[0], torch.Tensor):
return torch.concatenate(values)
elif isinstance(values[0], list) or isinstance(values[0], tuple):
return [
concat([value[value_idx] for value in values])
for value_idx in range(len(values[0]))
]
elif isinstance(values[0], dict):
return {
key: concat([value[key] for value in values])
for key in values[0].keys()
}

pre = util.apply(
activations, lambda x: x.narrow(0, 0, batch_start), torch.Tensor
)
post_batch_start = batch_start + batch_size
post = util.apply(
activations,
lambda x: x.narrow(0, post_batch_start, x.shape[0] - post_batch_start),
torch.Tensor,
)

def get_value(node: Node):
value = node.value

node.set_value(True)

return value

value = util.apply(graph.swap, get_value, Node)

activations = concat([pre, value, post])

graph.swap = None

return activations


def intervene(activations: Any, module_path: str, graph: Graph, key: str):
"""Entry to intervention graph. This should be hooked to all modules involved in the intervention graph.
@@ -181,6 +225,11 @@ def intervene(activations: Any, module_path: str, graph: Graph, key: str):
torch.Tensor,
)
)

# Check if through the previous value injection, there was a 'swp' intervention.
# This would mean we want to replace activations for this batch with some other ones.
activations = check_swap(graph, activations, batch_start, batch_size)

return activations


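A toy sketch of the splice `check_swap` performs: only the rows belonging to the current invoke are replaced by the swap value, and everything on either side passes through unchanged (shapes and values below are made up for illustration):

```python
# Made-up shapes: 3 invokes of 2 rows each; the second invoke (rows 2-3) gets
# its activations replaced by the value produced by the swap node.
import torch

activations = torch.arange(6.0).reshape(6, 1)
batch_start, batch_size = 2, 2
swap_value = torch.full((2, 1), -1.0)

pre = activations.narrow(0, 0, batch_start)
post_start = batch_start + batch_size
post = activations.narrow(0, post_start, activations.shape[0] - post_start)

spliced = torch.concatenate([pre, swap_value, post])
# spliced.squeeze() -> tensor([ 0.,  1., -1., -1.,  4.,  5.])
```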
14 changes: 9 additions & 5 deletions src/nnsight/models/AbstractModel.py
@@ -3,7 +3,7 @@
import copy
import gc
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Union
from typing import Any, Callable, Dict, List, Tuple, Union

import accelerate
import torch
@@ -438,14 +438,18 @@ def _example_input(self) -> Any:
raise NotImplementedError()

@abstractmethod
def _batched_inputs(self, prepared_inputs: Any) -> List[Any]:
"""Abstract to return a version of the prepared inputs from ``._prepare_inputs(...)`` that can be batched with others.
For example with a LanguageModel, prepare_inputs returns a dictionary. The implementation for this method just returns the 'input_ids'.
def _batch_inputs(
self, prepared_inputs: Any, batched_inputs: Any
) -> Tuple[Any, int]:
"""Abstract to return a batched version of the prepared inputs from ``._prepare_inputs(...)`` that can be batched with others as well as the size of the batch being added.
Should batch prepared_inputs with batched_inputs and return it with the current batch_size.
Args:
prepared_inputs (Any): Inputs from ``._prepare_inputs(...)``
batched_inputs (Any): Current state of batched_inputs
Returns:
List[Any]: Batched version of prepared_inputs.
Any: prepared_inputs batched with batched_inputs.
int: Batch size of prepared_inputs.
"""
raise NotImplementedError()
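To make the contract concrete, here is a minimal sketch of one possible override for a model whose prepared inputs are a single tensor (the class and its other methods are illustrative assumptions, not part of this change):

```python
# Illustrative sketch only: _batch_inputs for a model whose _prepare_inputs(...)
# returns one tensor of shape (batch, ...).
from typing import Any, Tuple

import torch

from nnsight.models.AbstractModel import AbstractModel


class ToyTensorModel(AbstractModel):
    # ...other abstract methods omitted for brevity...

    def _batch_inputs(
        self, prepared_inputs: torch.Tensor, batched_inputs: Any
    ) -> Tuple[torch.Tensor, int]:
        if batched_inputs is None:
            # First invoke of this Tracer: the batch is just these inputs.
            batched_inputs = prepared_inputs
        else:
            # Append this invoke's inputs to those from earlier invokes.
            batched_inputs = torch.cat([batched_inputs, prepared_inputs], dim=0)

        return batched_inputs, len(prepared_inputs)
```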
2 changes: 1 addition & 1 deletion src/nnsight/models/DiffuserModel.py
@@ -182,7 +182,7 @@ def _generation(

return self.local_model.pipeline(inputs, *args, **kwargs)

def _batched_inputs(self, prepared_inputs: BatchEncoding) -> torch.Tensor:
def _batch_inputs(self, prepared_inputs: BatchEncoding) -> torch.Tensor:
return prepared_inputs if not isinstance(prepared_inputs, str) else [prepared_inputs]

def _example_input(self) -> Dict[str, torch.Tensor]:
81 changes: 66 additions & 15 deletions src/nnsight/models/LanguageModel.py
@@ -8,12 +8,12 @@
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
BatchEncoding, PretrainedConfig, PreTrainedModel,
PreTrainedTokenizer)

from transformers.models.auto import modeling_auto
from .AbstractModel import AbstractModel


class LanguageModel(AbstractModel):
"""LanguageModels are nnsight wrappers around AutoModelForCausalLM models.
"""LanguageModels are nnsight wrappers around transformer auto models.
Inputs can be in the form of:
Prompt: (str)
@@ -25,21 +25,23 @@ class LanguageModel(AbstractModel):
If using a custom model, you also need to provide the tokenizer like ``LanguageModel(custom_model, tokenizer=tokenizer)``
Calls to generate pass arguments downstream to :func:`AutoModelForCausalLM.generate`
Calls to generate pass arguments downstream to :func:`GenerationMixin.generate`
Attributes:
config (PretrainedConfig): Huggingface config file loaded from repository or checkpoint.
tokenizer (PreTrainedTokenizer): Tokenizer for LMs.
meta_model (PreTrainedModel): Meta version of underlying AutoModelForCausalLM model.
local_model (PreTrainedModel): Local version of underlying AutoModelForCausalLM model.
automodel (type): AutoModel type from transformer auto models.
meta_model (PreTrainedModel): Meta version of underlying auto model.
local_model (PreTrainedModel): Local version of underlying auto model.
"""

def __init__(self, *args, tokenizer=None, **kwargs) -> None:
def __init__(self, *args, tokenizer=None, automodel=AutoModelForCausalLM, **kwargs) -> None:
self.config: PretrainedConfig = None
self.tokenizer: PreTrainedTokenizer = tokenizer
self.meta_model: PreTrainedModel = None
self.local_model: PreTrainedModel = None
self.automodel = automodel if not isinstance(automodel, str) else getattr(modeling_auto, automodel)

super().__init__(*args, **kwargs)
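A possible use of the new `automodel` argument; the checkpoint and auto class below are illustrative assumptions, not taken from this change:

```python
# Illustrative sketch: wrapping a non-causal-LM architecture via automodel.
from transformers import AutoModelForSeq2SeqLM

from nnsight import LanguageModel

model = LanguageModel("t5-small", automodel=AutoModelForSeq2SeqLM)

# Equivalent string form, resolved against transformers.models.auto.modeling_auto:
model = LanguageModel("t5-small", automodel="AutoModelForSeq2SeqLM")
```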

@@ -54,14 +56,14 @@ def _load_meta(self, repoid_or_path, *args, **kwargs) -> PreTrainedModel:
)
self.tokenizer.pad_token = self.tokenizer.eos_token

return AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)
return self.automodel.from_config(self.config, trust_remote_code=True)

def _load_local(self, repoid_or_path, *args, **kwargs) -> PreTrainedModel:
return AutoModelForCausalLM.from_pretrained(
return self.automodel.from_pretrained(
repoid_or_path, *args, config=self.config, **kwargs
)

def _prepare_inputs(
def _tokenize(
self,
inputs: Union[
str,
@@ -72,9 +74,9 @@ def _prepare_inputs(
torch.Tensor,
Dict[str, Any],
],
) -> BatchEncoding:
if isinstance(inputs, collections.abc.Mapping):
return BatchEncoding(inputs)
):
if isinstance(inputs, BatchEncoding):
return inputs

if isinstance(inputs, str) or (
isinstance(inputs, list) and isinstance(inputs[0], int)
@@ -91,11 +93,60 @@

return self.tokenizer(inputs, return_tensors="pt", padding=True)

def _batched_inputs(self, prepared_inputs: BatchEncoding) -> torch.Tensor:
return prepared_inputs["input_ids"]
def _prepare_inputs(
self,
inputs: Union[
str,
List[str],
List[List[str]],
List[int],
List[List[int]],
torch.Tensor,
Dict[str, Any],
BatchEncoding,
],
labels: Any = None,
**kwargs,
) -> BatchEncoding:
if isinstance(inputs, dict):
_inputs = self._tokenize(inputs["input_ids"])

_inputs = self._tokenize(_inputs)

if "labels" in inputs:
labels = self._tokenize(inputs["labels"])
labels = self._tokenize(labels)
_inputs["labels"] = labels["input_ids"]

return _inputs

inputs = self._tokenize(inputs)

if labels is not None:
labels = self._tokenize(labels)

inputs["labels"] = labels["input_ids"]

return inputs

def _batch_inputs(
self, prepared_inputs: BatchEncoding, batched_inputs: Dict
) -> torch.Tensor:
if batched_inputs is None:
batched_inputs = {"input_ids": []}

if "labels" in prepared_inputs:
batched_inputs["labels"] = []

batched_inputs["input_ids"].extend(prepared_inputs["input_ids"])

if "labels" in prepared_inputs:
batched_inputs["labels"].extend(prepared_inputs["labels"])

return batched_inputs, len(prepared_inputs["input_ids"])
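In user terms, the batching implemented above is what lets several invokes share one forward pass. A sketch, using the gpt2 setup from the README (the second prompt is illustrative):

```python
# Illustrative sketch: two invokes whose tokenized prompts are batched together
# by _batch_inputs and run in a single generation pass.
with model.generate(max_new_tokens=1) as generator:
    with generator.invoke('The Eiffel Tower is in the city of') as invoker:
        paris_hidden = model.transformer.h[-1].output[0].save()

    with generator.invoke('The Colosseum is in the city of') as invoker:
        rome_hidden = model.transformer.h[-1].output[0].save()
```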

def _example_input(self) -> Dict[str, torch.Tensor]:
return {"input_ids": torch.tensor([[0]])}
return BatchEncoding({"input_ids": torch.tensor([[0]]), "labels": torch.tensor([[0]])})

def _scan(self, prepared_inputs, *args, **kwargs) -> None:
# TODO

