diff --git a/README.md b/README.md
deleted file mode 100644
index 311cc50..0000000
--- a/README.md
+++ /dev/null
@@ -1,443 +0,0 @@
-# NDIF Repository
-## engine API
-
-The `engine/` directory contains the engine package for interpreting and manipulating the internals of large language models.
-
-- `engine/model_checkpoints` is set as the default huggingface hub cache directory. By default it contains the models found on the NDIF server, with their respective configurations.
-
-#### Installation
-
-Install this package through pip by running:
-
-`pip install git+https://github.com/JadenFiotto-Kaufman/ndif`
-
-#### Examples
-
-Here is a simple example where we run the engine API locally on gpt2 and save the hidden states of the last layer:
-
-```python
-from engine import LanguageModel
-
-model = LanguageModel('gpt2', device_map='cuda')
-
-with model.generate(max_new_tokens=1) as generator:
-    with generator.invoke('The Eiffel Tower is in the city of') as invoker:
-
-        hidden_states = model.transformer.h[-1].output[0].save()
-
-output = generator.output
-hidden_states = hidden_states.value
-```
-
-Let's go over this piece by piece.
-
-We import the `LanguageModel` object from the `engine` module and create a gpt2 model using the huggingface repo ID for gpt2, `'gpt2'`. This accepts arguments used to create the model, including `device_map` to specify which device to run on.
-
-```python
-from engine import LanguageModel
-
-model = LanguageModel('gpt2', device_map='cuda')
-```
-
-Then, we create a generation context block by calling `.generate(...)` on the model object. This denotes that we wish to actually generate tokens given some prompts.
-
-Keyword arguments are passed downstream to [AutoModelForCausalLM.generate(...)](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate). Refer to the linked docs for details.
-
-
-```python
-with model.generate(max_new_tokens=1) as generator:
-```
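-
-Since these keyword arguments are simply forwarded, the usual `transformers` sampling options can be passed here as well. A minimal sketch (using standard `generate()` arguments such as `do_sample` and `top_p`, which come from `transformers` rather than this API), reusing the `model` defined above:
-
-```python
-with model.generate(max_new_tokens=5, do_sample=True, top_p=0.9) as generator:
-    with generator.invoke('The Eiffel Tower is in the city of') as invoker:
-
-        # Interventions are declared exactly as before; only the sampling behavior changes.
-        hidden_states = model.transformer.h[-1].output[0].save()
-```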
-
-Now calling `.generate(...)` does not actually initialize or run the model. Only after the `with generator` block is exited is the model actually loaded and run. All operations in the block are "proxies", which essentially build a graph of the operations we wish to carry out later.
-
-
-Within the generation context, we create invocation contexts to specify the actual prompts we want to run:
-
-
-```python
-with generator.invoke('The Eiffel Tower is in the city of') as invoker:
-```
-
-Within this context, all operations/interventions will be applied to the processing of this prompt.
-
-```python
-hidden_states = model.transformer.h[-1].output[0].save()
-```
-
-On this line we're saying: access the last layer of the transformer `model.transformer.h[-1]`, access its output `.output`, index it at 0 `.output[0]`, and save it `.save()`.
-
-A few things to note: we can see the module tree of the model by printing the model. This allows us to know what attributes to access to get to the module we need.
-Running `print(model)` results in:
-
-```
-GPT2LMHeadModel(
-  (transformer): GPT2Model(
-    (wte): Embedding(50257, 768)
-    (wpe): Embedding(1024, 768)
-    (drop): Dropout(p=0.1, inplace=False)
-    (h): ModuleList(
-      (0-11): 12 x GPT2Block(
-        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
-        (attn): GPT2Attention(
-          (c_attn): Conv1D()
-          (c_proj): Conv1D()
-          (attn_dropout): Dropout(p=0.1, inplace=False)
-          (resid_dropout): Dropout(p=0.1, inplace=False)
-        )
-        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
-        (mlp): GPT2MLP(
-          (c_fc): Conv1D()
-          (c_proj): Conv1D()
-          (act): NewGELUActivation()
-          (dropout): Dropout(p=0.1, inplace=False)
-        )
-      )
-    )
-    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
-  )
-  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
-)
-```
-
-`.output` returns a proxy for the output of this module. This essentially means we're saying: when we get to the output of this module during inference, grab it and perform any operations we define on it (which also become proxies). There are two operational proxies here: one for getting the 0th index of the output, and one for saving it. We take the 0th index because the output of gpt2 transformer layers is a tuple, where the first index holds the actual hidden states (the last two indices come from attention). We can call `.shape` on any proxy to get what shape the value will eventually be.
-Running `print(model.transformer.h[-1].output.shape)` returns `(torch.Size([1, 10, 768]), (torch.Size([1, 12, 10, 64]), torch.Size([1, 12, 10, 64])))`
-
-During processing of the intervention computation graph we are building, when the value of a proxy is no longer needed, it is dereferenced and destroyed. However, calling `.save()` on a proxy informs the computation graph to clone the value of this proxy and never destroy it, allowing us to access the value after generation.
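-
-As a small sketch of the difference (reusing the gpt2 setup from the first example), only the proxy marked with `.save()` keeps its value once generation finishes:
-
-```python
-from engine import LanguageModel
-
-model = LanguageModel('gpt2', device_map='cuda')
-
-with model.generate(max_new_tokens=1) as generator:
-    with generator.invoke('The Eiffel Tower is in the city of') as invoker:
-
-        saved = model.transformer.h[-1].output[0].save()  # cloned and kept
-        unsaved = model.transformer.h[-1].output[0]       # dereferenced once no longer needed
-
-print(saved.value.shape)  # torch.Size([1, 10, 768]), matching the shape printed above
-# `unsaved` was never saved, so its value is not retained after the run.
-```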
-
-After exiting the generator context, the model is run with the specified arguments and intervention graph. `generator.output` is populated with the actual output and `hidden_states.value` will contain the saved value.
-
-```python
-output = generator.output
-hidden_states = hidden_states.value
-
-print(output)
-print(hidden_states)
-```
-
-returns:
-
-```
-tensor([[ 464, 412, 733, 417, 8765, 318, 287, 262, 1748, 286, 6342]],
-       device='cuda:0')
-tensor([[[ 0.0505, -0.1728, -0.1690, ..., -1.0096, 0.1280, -1.0687],
-         [ 8.7494, 2.9057, 5.3024, ..., -8.0418, 1.2964, -2.8677],
-         [ 0.2960, 4.6686, -3.6642, ..., 0.2391, -2.6064, 3.2263],
-         ...,
-         [ 2.1537, 6.8917, 3.8651, ..., 0.0588, -1.9866, 5.9188],
-         [-0.4460, 7.4285, -9.3065, ..., 2.0528, -2.7946, 0.5556],
-         [ 6.6286, 1.7258, 4.7969, ..., 7.6714, 3.0682, 2.0481]]],
-       device='cuda:0')
-```
-
-
----
-
-###### Operations
-
-Most* basic operations and torch operations work on proxies and are added to the computation graph.
-
-```python
-from engine import LanguageModel
-import torch
-
-model = LanguageModel('gpt2', device_map='cuda')
-
-with model.generate(max_new_tokens=1) as generator:
-    with generator.invoke('The Eiffel Tower is in the city of') as invoker:
-
-        hidden_states_pre = model.transformer.h[-1].output[0].save()
-
-        hs_sum = torch.sum(hidden_states_pre).save()
-
-        hs_edited = hidden_states_pre + hs_sum
-
-        hs_edited = hs_edited.save()
-
-print(hidden_states_pre.value)
-print(hs_sum.value)
-print(hs_edited.value)
-```
-
-In this example we take the sum of the hidden states and add it to the hidden states themselves (for whatever reason). By saving the various steps, we can see how the values change.
-
-```
-tensor([[[ 0.0505, -0.1728, -0.1690, ..., -1.0096, 0.1280, -1.0687],
-         [ 8.7494, 2.9057, 5.3024, ..., -8.0418, 1.2964, -2.8677],
-         [ 0.2960, 4.6686, -3.6642, ..., 0.2391, -2.6064, 3.2263],
-         ...,
-         [ 2.1537, 6.8917, 3.8651, ..., 0.0588, -1.9866, 5.9188],
-         [-0.4460, 7.4285, -9.3065, ..., 2.0528, -2.7946, 0.5556],
-         [ 6.6286, 1.7258, 4.7969, ..., 7.6714, 3.0682, 2.0481]]],
-       device='cuda:0')
-tensor(501.2957, device='cuda:0')
-tensor([[[501.3461, 501.1229, 501.1267, ..., 500.2860, 501.4237, 500.2270],
-         [510.0451, 504.2014, 506.5981, ..., 493.2538, 502.5920, 498.4279],
-         [501.5916, 505.9643, 497.6315, ..., 501.5348, 498.6892, 504.5219],
-         ...,
-         [503.4493, 508.1874, 505.1607, ..., 501.3545, 499.3091, 507.2145],
-         [500.8496, 508.7242, 491.9892, ..., 503.3485, 498.5010, 501.8512],
-         [507.9242, 503.0215, 506.0926, ..., 508.9671, 504.3639, 503.3438]]],
-       device='cuda:0')
-
-```
-
----
-###### Setting
-
-We often not only want to see what's happening during computation, but also intervene and edit the flow of information.
-
-```python
-from engine import LanguageModel
-import torch
-
-model = LanguageModel('gpt2', device_map='cuda')
-
-with model.generate(max_new_tokens=1) as generator:
-    with generator.invoke('The Eiffel Tower is in the city of') as invoker:
-
-        hidden_states_pre = model.transformer.h[-1].output[0].save()
-
-        noise = (0.001**0.5)*torch.randn(hidden_states_pre.shape)
-
-        model.transformer.h[-1].output[0] = hidden_states_pre + noise
-
-        hidden_states_post = model.transformer.h[-1].output[0].save()
-
-print(hidden_states_pre.value)
-print(hidden_states_post.value)
-```
-In this example, we create a tensor of noise to add to the hidden states. We then add it and use the assignment `=` operator to update the tensor at `.output[0]` with these new noised values.
-
-We can see the change in the results:
-
-```
-tensor([[[ 0.0505, -0.1728, -0.1690, ..., -1.0096, 0.1280, -1.0687],
-         [ 8.7494, 2.9057, 5.3024, ..., -8.0418, 1.2964, -2.8677],
-         [ 0.2960, 4.6686, -3.6642, ..., 0.2391, -2.6064, 3.2263],
-         ...,
-         [ 2.1537, 6.8917, 3.8651, ..., 0.0588, -1.9866, 5.9188],
-         [-0.4460, 7.4285, -9.3065, ..., 2.0528, -2.7946, 0.5556],
-         [ 6.6286, 1.7258, 4.7969, ..., 7.6714, 3.0682, 2.0481]]],
-       device='cuda:0')
-tensor([[[ 0.0674, -0.1741, -0.1771, ..., -0.9811, 0.1972, -1.0645],
-         [ 8.7080, 2.9067, 5.2924, ..., -8.0253, 1.2729, -2.8419],
-         [ 0.2611, 4.6911, -3.6434, ..., 0.2295, -2.6007, 3.2635],
-         ...,
-         [ 2.1859, 6.9242, 3.8666, ..., 0.0556, -2.0282, 5.8863],
-         [-0.4568, 7.4101, -9.3698, ..., 2.0630, -2.7971, 0.5522],
-         [ 6.6764, 1.7416, 4.8027, ..., 7.6507, 3.0754, 2.0218]]],
-       device='cuda:0')
-```
-
-Note: Only assignment updates of tensors work with this functionality.
-
----
-###### Multiple Token Generation
-
-When generating more than one token, use `invoker.next()` to denote that the interventions which follow should be applied to the next generation step.
-
-Here we again generate using gpt2, but generate three tokens and save the hidden states of the last layer for each one:
-
-```python
-from engine import LanguageModel
-
-model = LanguageModel('gpt2', device_map='cuda')
-
-with model.generate(max_new_tokens=3) as generator:
-    with generator.invoke('The Eiffel Tower is in the city of') as invoker:
-
-        hidden_states1 = model.transformer.h[-1].output[0].save()
-
-        invoker.next()
-
-        hidden_states2 = model.transformer.h[-1].output[0].save()
-
-        invoker.next()
-
-        hidden_states3 = model.transformer.h[-1].output[0].save()
-
-
-output = generator.output
-hidden_states1 = hidden_states1.value
-hidden_states2 = hidden_states2.value
-hidden_states3 = hidden_states3.value
-```
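-
-Because everything inside the block only builds a graph, the same pattern can also be written as an ordinary Python loop. A sketch (variable names here are illustrative, same gpt2 setup as above):
-
-```python
-from engine import LanguageModel
-
-model = LanguageModel('gpt2', device_map='cuda')
-
-n_new_tokens = 3
-saved_states = []
-
-with model.generate(max_new_tokens=n_new_tokens) as generator:
-    with generator.invoke('The Eiffel Tower is in the city of') as invoker:
-
-        for i in range(n_new_tokens):
-            # Save the last layer's hidden states for generation step i.
-            saved_states.append(model.transformer.h[-1].output[0].save())
-
-            # Move the following interventions to the next generated token.
-            if i < n_new_tokens - 1:
-                invoker.next()
-
-values = [proxy.value for proxy in saved_states]
-```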
----
-
-###### Token Based Indexing
-
-
-When indexing hidden states for specific tokens, use `.token[]` or `.t[]`.
-This is because, when there are multiple invocations, padding is performed on the left side, so these helper functions index from the back.
-
-Here we just get the hidden states of the first token:
-
-```python
-from engine import LanguageModel
-
-model = LanguageModel('gpt2', device_map='cuda')
-
-with model.generate(max_new_tokens=1) as generator:
-    with generator.invoke('The Eiffel Tower is in the city of') as invoker:
-
-        hidden_states = model.transformer.h[-1].output[0].t[0].save()
-
-output = generator.output
-hidden_states = hidden_states.value
-```
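-
-These helpers matter most when several invocations of different lengths share one generation block: the shorter prompt is left-padded, but `.t[0]` still refers to each prompt's own first token. A sketch (prompts chosen arbitrarily):
-
-```python
-from engine import LanguageModel
-
-model = LanguageModel('gpt2', device_map='cuda')
-
-with model.generate(max_new_tokens=1) as generator:
-
-    with generator.invoke('Madison square garden is located in the city of') as invoker:
-
-        first_long = model.transformer.h[-1].output[0].t[0].save()
-
-    with generator.invoke('Paris is in') as invoker:
-
-        first_short = model.transformer.h[-1].output[0].t[0].save()
-
-print(first_long.value.shape)
-print(first_short.value.shape)
-```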
-
----
-
-###### Cross Prompt Intervention
-
-
-Intervention operations work across prompts! Use two invocations within the same generation block, and operations can work between them.
-
-In this case, we grab the token embeddings coming from the first prompt, `"Madison square garden is located in the city of New"`, and replace the embeddings of the second prompt with them.
-
-```python
-from engine import LanguageModel
-
-model = LanguageModel('gpt2', device_map='cuda')
-
-with model.generate(max_new_tokens=3) as generator:
-
-    with generator.invoke("Madison square garden is located in the city of New") as invoker:
-
-        embeddings = model.transformer.wte.output
-
-    with generator.invoke("_ _ _ _ _ _ _ _ _ _") as invoker:
-
-        model.transformer.wte.output = embeddings
-
-print(model.tokenizer.decode(generator.output[0]))
-print(model.tokenizer.decode(generator.output[1]))
-```
-
-This results in:
-
-```
-Madison square garden is located in the city of New York City.
-_ _ _ _ _ _ _ _ _ _ York City.
-```
-
-We could also have passed in a pre-saved embedding tensor, as shown here:
-
-```python
-from engine import LanguageModel
-
-model = LanguageModel('gpt2', device_map='cuda')
-
-with model.generate(max_new_tokens=3) as generator:
-
-    with generator.invoke("Madison square garden is located in the city of New") as invoker:
-
-        embeddings = model.transformer.wte.output.save()
-
-print(model.tokenizer.decode(generator.output[0]))
-print(embeddings.value)
-
-with model.generate(max_new_tokens=3) as generator:
-
-    with generator.invoke("_ _ _ _ _ _ _ _ _ _") as invoker:
-
-        model.transformer.wte.output = embeddings.value
-
-print(model.tokenizer.decode(generator.output[0]))
-
-```
----
-
-###### Ad-hoc Module
-
-Another thing we can do is apply modules from the model's module tree at any point during computation, even if they are out of order.
-
-```python
-from engine import LanguageModel
-import torch
-
-model = LanguageModel("gpt2", device_map='cuda')
-
-with model.generate() as generator:
-    with generator.invoke('The Eiffel Tower is in the city of') as invoker:
-
-        hidden_states = model.transformer.h[-1].output[0]
-        hidden_states = model.lm_head(model.transformer.ln_f(hidden_states)).save()
-        tokens = torch.softmax(hidden_states, dim=2).argmax(dim=2).save()
-
-print(hidden_states.value)
-print(tokens.value)
-print(model.tokenizer.decode(tokens.value[0]))
-
-```
-
-Here we get the hidden states of the last layer as usual. We also chain-apply `model.transformer.ln_f` and `model.lm_head` in order to "decode" the hidden states into vocabulary space.
-Applying softmax and then argmax then lets us transform the vocabulary-space hidden states into actual tokens, which we can decode with the tokenizer.
-
-The output looks like:
-
-```
-tensor([[[ -36.2874, -35.0114, -38.0793, ..., -40.5163, -41.3759,
-           -34.9193],
-         [ -68.8886, -70.1562, -71.8408, ..., -80.4195, -78.2552,
-           -71.1206],
-         [ -82.2950, -81.6519, -83.9941, ..., -94.4878, -94.5194,
-           -85.6998],
-         ...,
-         [-113.8675, -111.8628, -113.6634, ..., -116.7652, -114.8267,
-          -112.3621],
-         [ -81.8531, -83.3006, -91.8192, ..., -92.9943, -89.8382,
-           -85.6898],
-         [-103.9307, -102.5054, -105.1563, ..., -109.3099, -110.4195,
-          -103.1395]]], device='cuda:0')
-tensor([[ 198, 12, 417, 8765, 318, 257, 262, 3504, 7372, 6342]],
-       device='cuda:0')
-
--el Tower is a the middle centre Paris
-```
-
----
-
-###### Running Remotely
-
-
-Running the engine API remotely on LLaMA 65b and saving the hidden states of the last layer:
-
-```python
-from engine import LanguageModel
-
-model = LanguageModel('decapoda-research/llama-65b-hf')
-with model.generate(server=True, max_new_tokens=1) as generator:
-    with generator.invoke('The Eiffel Tower is in the city of') as invoker:
-
-        hidden_states = model.model.layers[-1].output[0].save()
-
-output = generator.output
-hidden_states = hidden_states.value
-```
-
-More examples can be found in `engine/examples/`.
-
-## Inference Server
-
-Source for the NDIF server is found in the `server/` directory.
-
-- Edit `server/config.yaml` for your requirements.
-    - `PORT` : Flask port
-    - `RESPONSE_PATH` : Where to store disk-offloaded response data
-
-#### Installation
-
-Clone this repository and create the `ndif` conda environment:
-
-```bash
-cd ndif
-conda env create -f server/environment.yaml
-```
-
-Start the server with:
-
-```bash
-python -m server
-```
diff --git a/server/ResponseDict.py b/ResponseDict.py
similarity index 100%
rename from server/ResponseDict.py
rename to ResponseDict.py
diff --git a/server/__init__.py b/__init__.py
similarity index 100%
rename from server/__init__.py
rename to __init__.py
diff --git a/server/__main__.py b/__main__.py
similarity index 100%
rename from server/__main__.py
rename to __main__.py
diff --git a/server/app.py b/app.py
similarity index 100%
rename from server/app.py
rename to app.py
diff --git a/server/config.yaml b/config.yaml
similarity index 100%
rename from server/config.yaml
rename to config.yaml
diff --git a/server/download_models.py b/download_models.py
similarity index 100%
rename from server/download_models.py
rename to download_models.py
diff --git a/engine/Module.py b/engine/Module.py
deleted file mode 100644
index 7a9d111..0000000
--- a/engine/Module.py
+++ /dev/null
@@ -1,193 +0,0 @@
-from __future__ import annotations
-
-from typing import Any, Type, Union
-
-import torch
-
-from . 
import util -from .contexts.Generator import Generator -from .fx.Graph import Graph -from .fx.Node import Node -from .intervention import InterventionProxy - - -class Module(torch.nn.Module): - """_summary_ - - Attributes: - generator (Generator): _description_ - module_path (str): _description_ - output_shape (torch.Size): _description_ - output_type (Type): _description_ - _output (Proxy): _description_ - """ - - def __init__(self) -> None: - self.module_path: str = None - self.input_shape: torch.Size = None - self.input_type: Type = None - self.output_shape: torch.Size = None - self.output_type: Type = None - - self._output: InterventionProxy = None - self._input: InterventionProxy = None - self._graph: Graph = None - - self.generator: Generator = None - - def __call__(self, *args: Any, **kwds: Any) -> Any: - """Override __call__ to check for InterventionProxy arguments. If there are any, we should return an - InterventionProxy denoting we want to call the given module with arguments. - - Returns: - Any: _description_ - """ - proxy = any( - isinstance(x, InterventionProxy) for x in list(args) + list(kwds.values()) - ) - - if proxy: - module_proxy = getattr(self.generator.graph.module_proxy, self.module_path) - - return module_proxy.forward(*args, **kwds) - - return super().__call__(*args, **kwds) - - @property - def output(self) -> InterventionProxy: - """ - Calling denotes the user wishes to get the output of this module and therefore we create a Proxy of that request. - Only generates a proxy the first time it is references otherwise return the already set one. - - Returns: - Proxy: _description_ - """ - if self._output is None: - self._output = self.generator.graph.add( - graph=self.generator.graph, - value=util.apply( - self.output_shape, - lambda x: torch.empty(x, device="meta"), - torch.Size, - ), - target="argument", - args=[ - f"{self.module_path}.output.{self.generator.generation_idx}", - self.generator.batch_size, - len(self.generator.prompts) - self.generator.batch_size, - ], - ) - - return self._output - - @output.setter - def output(self, value: Union[InterventionProxy, Any]) -> None: - """ - Calling denotes the user wishes to set the output of this module and therefore we create a Proxy of that request. - - Args: - value (Union[Proxy, Any]): _description_ - """ - - Node.update( - self.output.node.proxy_value, self.output.node.prepare_proxy_values(value) - ) - - self.output.node.graph.add( - graph=self.output.node.graph, - value=self.output.node.proxy_value, - target=Node.update, - args=[self.output.node, value], - ) - - @property - def input(self) -> InterventionProxy: - """ - Calling denotes the user wishes to get the input of this module and therefore we create a Proxy of that request. - Only generates a proxy the first time it is references otherwise return the already set one. - - Returns: - Proxy: _description_ - """ - if self._input is None: - self._input = self.generator.graph.add( - graph=self.generator.graph, - value=util.apply( - self.input_shape, - lambda x: torch.empty(x, device="meta"), - torch.Size, - ), - target="argument", - args=[ - f"{self.module_path}.input.{self.generator.generation_idx}", - self.generator.batch_size, - len(self.generator.prompts) - self.generator.batch_size, - ], - ) - - return self._input - - @input.setter - def input(self, value: Union[InterventionProxy, Any]) -> None: - """ - Calling denotes the user wishes to set the input of this module and therefore we create a Proxy of that request. 
- - Args: - value (Union[Proxy, Any]): _description_ - """ - - Node.update( - self.input.node.proxy_value, self.input.node.prepare_proxy_values(value) - ) - - self.input.node.graph.add( - graph=self.input.node.graph, - value=self.input.node.proxy_value, - target=Node.update, - args=[self.input.node, value], - ) - - @property - def graph(self) -> Graph: - if self._graph is None: - self._graph = Graph.trace( - self, - *util.apply( - self.input_shape, - lambda x: torch.empty(x, device="meta"), - torch.Size, - ), - ) - - return self._graph - - @staticmethod - def wrap(module: torch.nn.Module) -> Module: - """Wraps the torch Module with our Module - - Args: - module (torch.nn.Module): _description_ - - Returns: - Module: _description_ - """ - - def hook(module: Module, input: Any, output: Any): - module._output = None - module._input = None - module.output_shape = util.apply(output, lambda x: x.shape, torch.Tensor) - module.input_shape = util.apply(input, lambda x: x.shape, torch.Tensor) - module.output_type = util.apply(output, lambda x: x.dtype, torch.Tensor) - module.input_type = util.apply(input, lambda x: x.dtype, torch.Tensor) - - for name, _module in module.named_children(): - setattr(module, name, Module.wrap(_module)) - - if isinstance(module, (Module, torch.nn.ModuleList)): - return module - - util.wrap(module, Module) - - module.register_forward_hook(hook) - - return module diff --git a/engine/__init__.py b/engine/__init__.py deleted file mode 100644 index 098074b..0000000 --- a/engine/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -import os - -import yaml - -from .pydantics.Config import ConfigModel -from .patching import * - -PATH = os.path.dirname(os.path.abspath(__file__)) -with open(os.path.join(PATH, "config.yaml"), "r") as file: - CONFIG = ConfigModel(**yaml.safe_load(file)) - -from .models.DiffuserModel import DiffuserModel -from .models.LanguageModel import LanguageModel -from .models.AbstractModel import AbstractModel -from .Module import Module diff --git a/engine/alteration/__init__.py b/engine/alteration/__init__.py deleted file mode 100644 index 307b0b7..0000000 --- a/engine/alteration/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .gpt import GPT2Patcher - -REPOID_TO_ALTERATION = {"gpt2": GPT2Patcher} diff --git a/engine/alteration/gpt.py b/engine/alteration/gpt.py deleted file mode 100644 index 60de416..0000000 --- a/engine/alteration/gpt.py +++ /dev/null @@ -1,90 +0,0 @@ -from __future__ import annotations - -from typing import Optional, Tuple, Union - -import torch -from transformers.models import gpt2 - -from .. 
import util -from ..patching import Patch, Patcher - - -class GPT2AttentionAltered(gpt2.modeling_gpt2.GPT2Attention): - def __init__(self, config, is_cross_attention=False, layer_idx=None): - super().__init__(config, is_cross_attention, layer_idx) - - self.query = util.WrapperModule() - self.key = util.WrapperModule() - self.value = util.WrapperModule() - - def forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - if encoder_hidden_states is not None: - if not hasattr(self, "q_attn"): - raise ValueError( - "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." - ) - - query = self.q_attn(hidden_states) - key, value = self.c_attn(encoder_hidden_states).split( - self.split_size, dim=2 - ) - attention_mask = encoder_attention_mask - else: - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) - - # Altered ------------- - - query = self.query(query) - key = self.key(key) - value = self.value(value) - - # --------------------- - - if layer_past is not None: - past_key, past_value = layer_past - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) - - if use_cache is True: - present = (key, value) - else: - present = None - - if self.reorder_and_upcast_attn: - attn_output, attn_weights = self._upcast_and_reordered_attn( - query, key, value, attention_mask, head_mask - ) - else: - attn_output, attn_weights = self._attn( - query, key, value, attention_mask, head_mask - ) - - attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) - attn_output = self.c_proj(attn_output) - attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, present, (attentions) - - -GPT2Patcher = Patcher([ - Patch(gpt2.modeling_gpt2.GPT2Attention, GPT2AttentionAltered) - ]) diff --git a/engine/config.yaml b/engine/config.yaml deleted file mode 100644 index 005b63d..0000000 --- a/engine/config.yaml +++ /dev/null @@ -1,2 +0,0 @@ -API: - HOST: localhost:5000 \ No newline at end of file diff --git a/engine/contexts/Generator.py b/engine/contexts/Generator.py deleted file mode 100644 index 56da67b..0000000 --- a/engine/contexts/Generator.py +++ /dev/null @@ -1,126 +0,0 @@ -from __future__ import annotations - -import pickle -from typing import TYPE_CHECKING, Dict, List, Union - -import socketio - -from .. import CONFIG, pydantics -from ..fx.Graph import Graph -from ..intervention import InterventionProxy -from .Invoker import Invoker - -if TYPE_CHECKING: - from ..models.AbstractModel import AbstractModel - - -class Generator: - """_summary_ - - Attributes: - model (Model): Model object this is a generator for. - blocking (bool): If when using device_map='server', block and wait form responses. 
Otherwise have to manually - request a response. - args (List[Any]): Arguments for calling the model. - kwargs (Dict[str,Any]): Keyword arguments for calling the model. - generation_idx (int): Keeps track of what iteration of generation to do interventions at. Used by the Module class - to specify generation_idx for interventions and changed by the Invoker class using invoker.next(). - batch_size (int): Current size of invocation batch. To be used by Module node creation - prompts (List[str]): Keeps track of prompts used by invokers. - graph (Graph): Graph of all user intervention operations. - output (??): desc - """ - - def __init__( - self, - model: "AbstractModel", - *args, - blocking: bool = True, - server: bool = False, - **kwargs, - ) -> None: - self.model = model - self.server = server - self.blocking = blocking - self.args = args - self.kwargs = kwargs - - self.graph = Graph(self.model.meta_model, proxy_class=InterventionProxy) - - self.generation_idx: int = 0 - self.batch_size: int = 0 - self.prompts: List[str] = [] - self.output = None - - # Modules need to know about the current generator to create the correct proxies. - for name, module in self.model.named_modules(): - module.generator = self - - def __enter__(self) -> Generator: - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - """On exit, run and generate using the model whether locally or on the server.""" - if self.server: - self.run_server() - else: - self.run_local() - - def run_local(self): - # Run the model and store the output. - self.output = self.model(self.model._generation, self.prompts, self.graph, *self.args, **self.kwargs) - - def run_server(self): - # Create the pydantic class for the request. - request = pydantics.RequestModel( - args=self.args, - kwargs=self.kwargs, - model_name=self.model.model_name_or_path, - prompts=self.prompts, - intervention_graph=self.graph, - ) - - if self.blocking: - self.blocking_request(request) - else: - self.non_blocking_request(request) - - def blocking_request(self, request: pydantics.RequestModel): - # Create a socketio connection to the server. - sio = socketio.Client() - sio.connect(f"ws://{CONFIG.API.HOST}") - - # Called when recieving a response from the server. - @sio.on("blocking_response") - def blocking_response(data): - # Load the data into the ResponseModel pydantic class. - data: pydantics.ResponseModel = pickle.loads(data) - - # Print response for user ( should be logger.info and have an infor handler print to stdout) - print(str(data)) - - # If the status of the response is completed, update the local futures that the user specified to save. - # Then disconnect and continue. - if data.status == pydantics.JobStatus.COMPLETED: - for name, value in data.saves.items(): - self.graph.nodes[name].future.set_result(value) - - self.output = data.output - - sio.disconnect() - # Or if there was some error. 
- elif data.status == pydantics.JobStatus.ERROR: - sio.disconnect() - - sio.emit( - "blocking_request", - request.model_dump(exclude_defaults=True, exclude_none=True), - ) - - sio.wait() - - def non_blocking_request(self, request: pydantics.RequestModel): - pass - - def invoke(self, input, *args, **kwargs) -> Invoker: - return Invoker(self, input, *args, **kwargs) diff --git a/engine/contexts/Invoker.py b/engine/contexts/Invoker.py deleted file mode 100644 index ef79386..0000000 --- a/engine/contexts/Invoker.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Dict - -from ..fx.Proxy import Proxy - -if TYPE_CHECKING: - from .Generator import Generator - - -class Invoker: - def __init__(self, generator: "Generator", input, *args, **kwargs) -> None: - self.generator = generator - self.input = input - self.args = args - self.kwargs = kwargs - self.tokens = None - self.ids = None - - def __enter__(self) -> Invoker: - # Were in a new invocation so set generation_idx to 0, - self.generator.generation_idx = 0 - - # Run graph_mode with meta tensors to collect shape information, - token_ids = self.generator.model._run_meta(self.input, *self.args, **self.kwargs) - - # Decode tokenized inputs for user usage. - self.tokens = [ - [self.generator.model.tokenizer.decode(token) for token in ids] - for ids in token_ids - ] - self.ids = token_ids - - self.generator.batch_size = len(self.ids) - - # Rebuild prompt from tokens (do this becuase if they input ids directly, we still need to pass - # all input data at once to a tokenizer to correctly batch the attention). - self.generator.prompts.extend(["".join(tokens) for tokens in self.tokens]) - - if len(self.tokens) == 1: - self.tokens = self.tokens[0] - self.ids = self.ids[0] - - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - pass - - def next(self) -> None: - # .next() increases which generation idx the interventions happen. - self.generator.generation_idx += 1 - - # Run graph with singe token input. - self.generator.model._run_meta("_", *self.args, **self.kwargs) - - def save_all(self) -> Dict[str, Proxy]: - """Saves the output of all modules and returns a dictionary of [module_path -> save proxy] - - Returns: - Dict[str, Proxy]: _description_ - """ - result = {} - - for name, module in self.generator.model.meta_model.named_modules(): - result[module.module_path] = module.output.save() - - return result diff --git a/engine/contexts/Runner.py b/engine/contexts/Runner.py deleted file mode 100644 index c3c2a6a..0000000 --- a/engine/contexts/Runner.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Dict, List, Union - -from ..fx.Graph import Graph -from ..intervention import InterventionProxy - -if TYPE_CHECKING: - from ..models.AbstractModel import AbstractModel - - -# TODO make parent class for Runner and Generator as Module depends on attributes -class Runner: - def __init__( - self, - model: "AbstractModel", - input, - *args, - inference=False, - **kwargs, - ) -> None: - self.model = model - - self.input = input - self.inference = inference - self.args = args - self.kwargs = kwargs - - self.graph = Graph(self.model.meta_model, proxy_class=InterventionProxy) - - self.batch_size: int = 0 - self.prompts: List[str] = [] - self.generation_idx = 0 - self.output = None - - # Modules need to know about the current generator to create the correct proxies. 
- for name, module in self.model.named_modules(): - module.generator = self - - def __enter__(self) -> Runner: - token_ids = self.model._run_meta(self.input, *self.args, **self.kwargs) - - # Decode tokenized inputs for user usage. - self.tokens = [ - [self.model.tokenizer.decode(token) for token in ids] for ids in token_ids - ] - self.ids = token_ids - - self.batch_size = len(self.ids) - - self.prompts.extend(["".join(tokens) for tokens in self.tokens]) - - if len(self.tokens) == 1: - self.tokens = self.tokens[0] - self.ids = self.ids[0] - - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - self.output = self.model( - self.model._run_local, - self.input, - self.graph, - *self.args, - inference=self.inference, - **self.kwargs, - ) diff --git a/engine/contexts/__init__.py b/engine/contexts/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/engine/editing/Editor.py b/engine/editing/Editor.py deleted file mode 100644 index 0eeb5cf..0000000 --- a/engine/editing/Editor.py +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import annotations -from abc import abstractmethod -from typing import List - -import torch - -class Edit: - - @abstractmethod - def edit(self, obj: torch.nn.Module): - pass - - @abstractmethod - def restore(self, obj: torch.nn.Module): - pass - - -class Editor: - def __init__(self, obj: object, edits: List[Edit]) -> None: - self.obj = obj - self.edits = edits - - def __enter__(self) -> Editor: - for edit in self.edits: - edit.edit(self.obj) - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - for edit in self.edits: - edit.restore(self.obj) diff --git a/engine/editing/GraphEdit.py b/engine/editing/GraphEdit.py deleted file mode 100644 index c430e42..0000000 --- a/engine/editing/GraphEdit.py +++ /dev/null @@ -1,22 +0,0 @@ -from .Editor import Edit -import torch -from .. import util -from ..fx.Graph import Graph -class GraphEdit(Edit): - - def __init__(self, module_path:str, graph:Graph) -> None: - super().__init__() - - self.module_path = module_path - self.graph = graph - - self.forward = None - - def edit(self, model:torch.nn.Module): - module: torch.nn.Module = util.fetch_attr(model, self.module_path) - self.forward = module.forward - self.graph.wrap(module) - - def restore(self, model:torch.nn.Module): - module: torch.nn.Module = util.fetch_attr(model, self.module_path) - setattr(module, 'forward', self.forward) \ No newline at end of file diff --git a/engine/editing/WrapperModuleEdit.py b/engine/editing/WrapperModuleEdit.py deleted file mode 100644 index 991a44e..0000000 --- a/engine/editing/WrapperModuleEdit.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch - -from .. import util -from .Editor import Edit - - -class WrapperModuleEdit(Edit): - def __init__(self, module_path: str, module_name: str) -> None: - super().__init__() - - self.module_path = module_path - self.module_name = module_name - - self.wrapper = util.WrapperModule() - - def edit(self, model: torch.nn.Module): - module: torch.nn.Module = util.fetch_attr(model, self.module_path) - setattr(module, self.module_name, self.wrapper) - - def restore(self, model: torch.nn.Module): - module: torch.nn.Module = util.fetch_attr(model, self.module_path) - delattr(module, self.module_name) diff --git a/engine/editing/__init__.py b/engine/editing/__init__.py deleted file mode 100644 index b974282..0000000 --- a/engine/editing/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . 
import * \ No newline at end of file diff --git a/engine/fx/Graph.py b/engine/fx/Graph.py deleted file mode 100644 index 7fbd886..0000000 --- a/engine/fx/Graph.py +++ /dev/null @@ -1,294 +0,0 @@ -from __future__ import annotations - -import inspect -from typing import Any, Callable, Dict, List, Type, Union - -import torch - -from .. import util -from .Node import Node -from ..patching import Patcher, Patch -from .Proxy import Proxy, proxy_wrapper - - -class Graph: - """Represents a computation graph involving a Module - - Attributes: - proxy_class (Type[Proxy]): Proxy class to use. Defaults to Proxy. - nodes (Dict[str, Node]): Mapping of node name to node. - name_idx (Dict[str, int]): Mapping of node target_name to number of previous names with the same target_name. - Used so names are unique. - module_proxy (Proxy): Proxy for given root module - argument_node_names (Dict[str, List[str]]): _description_ - generation_idx (int): _description_ - - """ - - @staticmethod - def trace( - module: torch.nn.Module, *args: List[Any], **kwargs: Dict[str, Any] - ) -> Graph: - """Given a module and some default (should be meta tensors) arguments, create a graph from the module's - forward method. - - Args: - module (torch.nn.Module): _description_ - args (List[Any]): desc - kwargs (Dict[str, Any]): desc - - Returns: - Graph: _description_ - """ - - # Create a graph with the module as the root module - graph = Graph(module) - - # Get 'unbound' version of forward method so we can pass in proxy of module insead of self - forward = module.__class__.forward - - # Want list not tuple - args = list(args) - - # Inspect forward signature to collect all parameters - signature = inspect.signature(forward) - - def get_argument_value(param: inspect.Parameter, idx: int): - """Gets the correct argument to pass to forward method. - - - Args: - param (_type_): _description_ - idx (_type_): _description_ - - Returns: - _type_: _description_ - """ - - # If idx in range of provided args, create a proxy for that arg instead of default. - if idx < len(args): - return graph.add( - graph=graph, value=args[idx], target="argument", args=[param.name] - ) - # If param name in provided kwargs, create a proxy for that arg instead of default. - if param.name in kwargs: - return graph.add( - graph=graph, - value=kwargs[param.name], - target="argument", - args=[param.name], - ) - # Otherwise just return default - return param.default - - # Create the appropriate proxies/values for the forward method in order to trace. - arguments = [ - get_argument_value(param, i) - for i, param in enumerate(list(signature.parameters.values())[1:]) - ] - - # Some methods cannot be caught because they arent torch functions or dont play nice with __torch_function__. - # So the patcher repalces the methods with something to catch proxies and return proxies. - with Patcher() as patcher: - patcher.add(Patch(torch.full, proxy_wrapper(torch.full))) - patcher.add(Patch(torch.finfo, proxy_wrapper(torch.finfo))) - patcher.add(Patch(torch.arange, proxy_wrapper(torch.arange))) - - # Run forward with root module proxy and arguments - output: Proxy = forward(graph.module_proxy, *arguments) - - # Get proxy_value for return - value = util.apply(output, lambda x: x.node.proxy_value, Proxy) - - # Create the 'rtn_0' return proxy - return_proxy = graph.add( - graph=graph, value=value, target=Graph.rtn, args=output - ) - - # This is how we tell the graph not to destroy a proxy after it's listeners are completed. - # Create a 'null' proxy. 
The return proxy listens to the 'null' proxy with args=[return_proxy.node] but 'null' will never be completed. - graph.add( - graph=graph, - value=None, - target="null", - args=[return_proxy.node], - ) - - return graph - - @staticmethod - def rtn(*args, **kwargs): - """ - Function to just pass through data for returning data in a graph forward method. - - Returns: - _type_: _description_ - """ - - return args - - def __init__( - self, module: torch.nn.Module, proxy_class: Type[Proxy] = Proxy - ) -> None: - """_summary_ - - Args: - module (torch.nn.Module): _description_ - proxy_class (Type[Proxy], optional): _description_. - """ - self.proxy_class = proxy_class - - self.nodes: Dict[str, Node] = dict() - self.name_idx: Dict[str, int] = dict() - - self.module_proxy = self.add(graph=self, value=module, target="module") - self.argument_node_names: Dict[str, List[str]] = dict() - - self.generation_idx = 0 - - def increment(self) -> None: - """Increments the generation_idx by one. Should be called by a forward hook on the model being used for generation.""" - self.generation_idx += 1 - - def compile(self, module: torch.nn.Module) -> None: - """Re-compile graph to prepare for a new execution of the graph. - - Args: - module (torch.nn.Module): Module to be considered the root module of the graph. - """ - - # Remove nodes that have no effect. - self.eliminate_dead_code() - - # Reset all node futures. - for node in self.nodes.values(): - node._future = None - # Compile nodes individually. - for node in self.nodes.values(): - node.compile() - - self.generation_idx = 0 - - # Setting the root module future kicks off the graph execution. - self.nodes["module_0"].future.set_result(module) - - def add( - self, - graph: Graph, - value: Any, - target: Union[Callable, str], - args: List[Any] = None, - kwargs: Dict[str, Any] = None, - name: str = None, - ) -> Proxy: - """Adds a node to the graph and returns it's proxy. - - Args: - graph (Graph): _description_ - value (Any): 'meta' proxy value used for tracing the shapes and values. - target (Union[Callable, str]): Either the function to call for this node, or a string that's the name of a method attribute on the first arg. - args (List[Any], optional): _description_. Defaults to None. - kwargs (Dict[str, Any], optional): _description_. Defaults to None. - name (str, optional): _description_. Defaults to None. - - Returns: - Proxy: _description_ - """ - target_name = Node.target_name(target) - - if target_name not in self.name_idx: - self.name_idx[target_name] = 0 - else: - if target_name == "rtn": - raise ValueError("Can only have one return ('rtn') node.") - if target_name == "module": - raise ValueError("Can only have one module node.") - - if name is None: - name = f"{target_name}_{self.name_idx[target_name]}" - - self.name_idx[target_name] += 1 - - stack = inspect.stack() - proxy_frame = stack[2] - - node = Node( - name=name, - graph=graph, - value=value, - target=target, - args=args, - kwargs=kwargs, - meta={"line": proxy_frame.lineno, "file": proxy_frame.filename}, - ) - - self.nodes[name] = node - - if target_name == "argument": - module_path = args[0] - - if module_path not in self.argument_node_names: - self.argument_node_names[module_path] = [] - - self.argument_node_names[module_path].append(name) - - return self.proxy(node) - - def proxy(self, node: Node) -> Proxy: - """Returns proxy of node with specified proxy_class. 
- - Args: - node (Node): _description_ - - Returns: - Proxy: _description_ - """ - return self.proxy_class(node) - - def eliminate_dead_code(self): - # TODO - pass - - def wrap(self, module: torch.nn.Module) -> torch.nn.Module: - """Replaces the forward method of the given module with an execution of the module's graph. - - Args: - module (torch.nn.Module): _description_ - - Returns: - torch.nn.Module: _description_ - """ - - def forward(*args, **kwargs): - # Compile the graph with the given module as the root module. - self.compile(module) - - # Gets list of all argument nodes for this graph. - argument_nodes_list = list(self.argument_node_names.values()) - - # Sets the result of the argument nodes future for args. - for i, arg in enumerate(args): - self.nodes[argument_nodes_list[i][0]].future.set_result(arg) - - # And then for kwargs. - for key in kwargs: - if key in self.argument_node_names: - self.nodes[self.argument_node_names[key][0]].future.set_result(arg) - - # 'rtn_0' should have the value we need to return. - return_value = self.nodes["rtn_0"].value() - self.nodes["rtn_0"].destroy() - return return_value - - # Repalce forward method with custom graph execution method. - module.forward = forward - - return module - - def __str__(self) -> str: - result = "" - - for name, node in self.nodes.items(): - result += f" %{node}\n" - - return result diff --git a/engine/fx/Node.py b/engine/fx/Node.py deleted file mode 100644 index c74487a..0000000 --- a/engine/fx/Node.py +++ /dev/null @@ -1,277 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union - -import torch.futures - -from .. import util -from ..logger import logger -from .Proxy import Proxy - -if TYPE_CHECKING: - from .Graph import Graph - - -class Node: - """_summary_ - - Attributes: - name (str): _description_ - graph (Graph): _description_ - proxy_value (Any): _description_ - target (Union[Callable, str]): _description_ - args (List[Any], optional): _description_. Defaults to None. - kwargs (Dict[str, Any], optional): _description_. Defaults to None. - meta (Dict[str, Any], optional): _description_. Defaults to None. - listeners (List[Node]): desc - dependencies (List[Node]): desc - _future (torch.futures.Future): desc - _proxy_device (torch.device): desc - """ - - @staticmethod - def update(value1, value2) -> None: - """Updates Tensor values with other Tensor values. 
- - Args: - value1 (_type_): _description_ - value2 (_type_): _description_ - """ - if isinstance(value1, torch.Tensor): - value1[:] = value2 - elif isinstance(value1, list) or isinstance(value1, tuple): - for value_idx in range(len(value1)): - Node.update(value1[value_idx], value2[value_idx]) - elif isinstance(value1, dict): - for key in value1: - Node.update(value1[key], value2[key]) - - @staticmethod - def target_name(target) -> str: - if isinstance(target, str): - name = target - elif callable(target): - name = target.__name__ - - return name - - def __init__( - self, - name: str, - graph: "Graph", - value: Any, - target: Union[Callable, str], - args: List[Any] = None, - kwargs: Dict[str, Any] = None, - meta: Dict[str, Any] = None, - ) -> None: - super().__init__() - - if args is None: - args = list() - if kwargs is None: - kwargs = dict() - if meta is None: - meta = dict() - - self.name = name - self.graph = graph - self.proxy_value = value - self.target = target - self.args = util.apply(args, lambda x: x.node, Proxy) - self.kwargs = util.apply(kwargs, lambda x: x.node, Proxy) - self.meta = meta - - self.listeners: List[Node] = list([self]) - self.dependencies: List[Node] = list() - - # Add all arguments that are nodes to nodes dependencies - util.apply(self.args, lambda x: self.dependencies.append(x), Node) - util.apply(self.kwargs, lambda x: self.dependencies.append(x), Node) - # Add node to all arguments that are nodes' listeners - util.apply(self.args, lambda x: x.listeners.append(self), Node) - util.apply(self.kwargs, lambda x: x.listeners.append(self), Node) - - self._future: torch.futures.Future = None - self._proxy_device: torch.device = None - - @property - def future(self) -> torch.futures.Future: - """Lazy creation of _future attribute. - - Returns: - torch.futures.Future: _description_ - """ - if self._future is None: - self._future = torch.futures.Future() - - return self._future - - @property - def proxy_device(self) -> torch.device: - """Lazy creation of _proxy_device attribute. - - Returns: - torch.Device: _description_ - """ - if self._proxy_device is None: - device = None - - def _device(value): - nonlocal device - device = value.device - - util.apply(self.proxy_value, _device, torch.Tensor) - # TODO - # util.apply(self.proxy_value, _device, torch.nn.Module) - - self._proxy_device = device - - return self._proxy_device - - def prepare_proxy_values(self, values): - def slice_to_value(arg: slice): - return slice( - self.prepare_proxy_values(arg.start), - self.prepare_proxy_values(arg.stop), - self.prepare_proxy_values(arg.step), - ) - - # Convert procies to their proxy_value - values = util.apply(values, lambda x: x.node.proxy_value, Proxy) - # Slices may have proxies as part of their attributes so convert those to their proxy_values - values = util.apply(values, slice_to_value, slice) - # Move tensors to that of the proxy_device (probably 'meta') - values = util.apply(values, lambda x: x.to(self.proxy_device), torch.Tensor) - - return values - - def compile(self) -> None: - # When this future is done, log that event. - self.future.add_done_callback(lambda x: logger.debug(f"=> SET({self.name})")) - - # Nodes tell listeners when to try and be executed. - # This chains futures so after this node's future is done, it goes through - # it's listeners in order and calls their .chain() method. 
- future = self.listeners[0].future - - for listener in self.listeners[1:]: - future = future.then(listener.chain) - - # Collect all listeners futures into a single future that when done, call this - # nodes .destroy() method. - torch.futures.collect_all( - util.apply(self.listeners, lambda x: x.future, Node) - ).add_done_callback(lambda x: self.destroy()) - - def value(self) -> Any: - """Wrapper for this node's future .value() - - Returns: - Any: _description_ - """ - return self.future.value() - - def done(self) -> bool: - """Wrapper for this node's future .done() - - Returns: - bool: _description_ - """ - return self.future.done() - - def fufilled(self) -> bool: - """Returns True if all of this node's dependencies are done. - - Returns: - bool: _description_ - """ - for dependency in self.dependencies: - if not dependency.done(): - return False - - return True - - def prepare_inputs(self) -> Tuple[List[Any], Dict[str, Any]]: - # Turn futures into their value - def _value(value: Node): - return value.value() - - args = util.apply(self.args, _value, Node) - kwargs = util.apply(self.kwargs, _value, Node) - - device = None - - def _device(value): - nonlocal device - device = value.device - - all_args = list(args) + list(kwargs.values()) - - util.apply(list(reversed(all_args)), _device, torch.Tensor) - # util.apply(list(reversed(all_args)), _device, torch.nn.Module) - - # Move tensors to device - def _to(value: torch.Tensor): - return value.to(device) - - args = util.apply(args, _to, torch.Tensor) - kwargs = util.apply(kwargs, _to, torch.Tensor) - - return args, kwargs - - def execute(self) -> None: - """Actually executes this node.""" - - # We se a nodes target to 'null' if we don't want it to be executed and therefore never done - if self.target == "null": - return - - # Prepare arguments. - args, kwargs = self.prepare_inputs() - - # If target is a string, it must be a method attribute on the first argument object. - if isinstance(self.target, str): - obj, *args = args - - target = getattr(obj, self.target) - # Otherwise it must be the function itself. - else: - target = self.target - - # Call the target to get value. - output = target(*args, **kwargs) - - # Set this nodes future value to result. - self.future.set_result(output) - - def destroy(self) -> None: - """Removes the reference to the node's _future and logs it's destruction.""" - logger.debug(f"=> DEL({self.name})") - - self._future = None - - def chain(self, future: torch.futures.Future): - # If all of a nodes dependencies are done, execute it. - # Dont execute if already done. - if self.fufilled() and not self.done(): - try: - self.execute() - except Exception as e: - # TODO - # An exectption is actually never thrown upward to the point it stops the program. Need to find a way. 
- logger.exception(f"Exception in execution of node '{self.name}'.") - - self.future.set_exception(e) - future.set_exception(e) - - raise e - - future.set_result(None) - - def __str__(self) -> str: - args = util.apply(self.args, lambda x: f"'{x}'", str) - args = util.apply(args, lambda x: x.name, Node) - args = [str(arg) for arg in args] - meta = f"{self.meta['file']}({self.meta['line']})" if self.meta else "" - return f"{self.name}:[ {meta} args:({','.join(args)}) l:{len(self.listeners)} d:{len(self.dependencies)}]" diff --git a/engine/fx/Proxy.py b/engine/fx/Proxy.py deleted file mode 100644 index 8c0ecde..0000000 --- a/engine/fx/Proxy.py +++ /dev/null @@ -1,219 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Union - -import torch - -from .. import util - -if TYPE_CHECKING: - from .Node import Node - - -class Proxy: - """_summary_ - - Attributes: - node (Node): desc - """ - - def __init__(self, node: "Node") -> None: - self.node = node - - def __call__(self, *args, **kwargs) -> Proxy: - # If calling a method (not a sub-module) on the main module of this graph, - # we want to trace into that method. - if self.node.args[0] is self.node.graph.module_proxy.node and not isinstance( - self.node.proxy_value, torch.nn.Module - ): - value = self.node.proxy_value.__func__( - self.node.graph.module_proxy, *args, **kwargs - ) - - return value - # Otherwise we just want to add a node saying we wish to call this module. - else: - value = self.node.proxy_value( - *self.node.prepare_proxy_values(args), - **self.node.prepare_proxy_values(kwargs), - ) - - return self.node.graph.add( - graph=self.node.graph, - value=value, - target="__call__", - args=[self.node] + list(args), - kwargs=kwargs, - ) - - def __getitem__(self, key: Union[Proxy, Any]) -> Proxy: - key = self.node.prepare_proxy_values(key) - - value = self.node.proxy_value[key] - - return self.node.graph.add( - graph=self.node.graph, - value=value, - target="__getitem__", - args=[self.node, key], - ) - - def __setitem__(self, key: Union[Proxy, Any], value: Union[Proxy, Any]) -> None: - item_proxy = self[key] - - update = item_proxy.node.__class__.update - - update(item_proxy.node.proxy_value, item_proxy.node.prepare_proxy_values(value)) - - item_proxy.node.graph.add( - graph=item_proxy.node.graph, - value=item_proxy.node.proxy_value, - target=update, - args=[item_proxy.node, value], - ) - - def __getattr__(self, key: Union[Proxy, Any]) -> Proxy: - key = self.node.prepare_proxy_values(key) - - value = util.fetch_attr(self.node.proxy_value, key) - - return self.node.graph.add( - graph=self.node.graph, - value=value, - target=util.fetch_attr, - args=[self.node, key], - ) - - def __len__(self) -> Proxy: - value = len(self.node.proxy_value) - - return self.node.graph.add( - graph=self.node.graph, - value=value, - target=len, - args=[self.node], - ) - - def __add__(self, other: Union[Proxy, Any]) -> Proxy: - value = self.node.proxy_value + self.node.prepare_proxy_values(other) - - return self.node.graph.add( - graph=self.node.graph, - value=value, - target="__add__", - args=[self.node, other], - ) - - def __sub__(self, other: Union[Proxy, Any]) -> Proxy: - value = self.node.proxy_value - self.node.prepare_proxy_values(other) - - return self.node.graph.add( - graph=self.node.graph, - value=value, - target="__sub__", - args=[self.node, other], - ) - - def __pow__(self, other: Union[Proxy, Any]) -> Proxy: - value = self.node.proxy_value ** self.node.prepare_proxy_values(other) - - return self.node.graph.add( - 
graph=self.node.graph, - value=value, - target=pow, - args=[self.node, other], - ) - - def __mul__(self, other: Union[Proxy, Any]) -> Proxy: - value = self.node.proxy_value * self.node.prepare_proxy_values(other) - - return self.node.graph.add( - graph=self.node.graph, - value=value, - target="__mul__", - args=[self.node, other], - ) - - def __truediv__(self, other: Union[Proxy, Any]) -> Proxy: - value = self.node.proxy_value / self.node.prepare_proxy_values(other) - - return self.node.graph.add( - graph=self.node.graph, - value=value, - target="__truediv__", - args=[self.node, other], - ) - - def __bool__(self) -> bool: - return self.node.proxy_value.__bool__() - - def __index__(self) -> int: - return self.node.proxy_value.__index__() - - def __instancecheck__(self, __instance: Any) -> bool: - return self.node.proxy_value.__instancecheck__(__instance) - - @classmethod - def __torch_function__(cls, orig_method, types, args=None, kwargs=None) -> Proxy: - if args is None: - args = list() - if kwargs is None: - kwargs = dict() - - self: Proxy = args[0] - - value = orig_method( - *self.node.prepare_proxy_values(args), - **self.node.prepare_proxy_values(kwargs), - ) - - return self.node.graph.add( - graph=self.node.graph, - value=value, - target=orig_method, - args=args, - kwargs=kwargs, - ) - - -from functools import wraps - - -def proxy_wrapper(fn) -> None: - """Wraps problematic functions (torch functions sometimes). - Checks if anty of its args are proxies. If so we return a proxy of the function. - Otherwise just run the function. - - Args: - fn (function): _description_ - - Returns: - _type_: _description_ - """ - - @wraps(fn) - def patched(*args, **kwargs): - arguments = list(args) + list(kwargs.values()) - - node = None - - for arg in arguments: - if isinstance(arg, Proxy): - node = arg.node - - break - - if node is not None: - value = fn( - *node.prepare_proxy_values(args), - **node.prepare_proxy_values(kwargs), - ) - - return node.graph.add( - graph=node.graph, value=value, target=fn, args=args, kwargs=kwargs - ) - - else: - return fn(*args, **kwargs) - - return patched diff --git a/engine/fx/__init__.py b/engine/fx/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/engine/intervention.py b/engine/intervention.py deleted file mode 100644 index 74eeb80..0000000 --- a/engine/intervention.py +++ /dev/null @@ -1,149 +0,0 @@ -from __future__ import annotations - -from typing import Any, Callable, List, Tuple, Union - -import torch.futures - -from . 
import util -from .fx.Graph import Graph -from .fx.Proxy import Proxy - - -class TokenIndexer: - def __init__(self, proxy: InterventionProxy) -> None: - self.proxy = proxy - - def convert_idx(self, idx: int): - if idx >= 0: - n_tokens = self.proxy.node.proxy_value.shape[1] - idx = -(n_tokens - idx) - - return idx - - def __getitem__(self, key: int) -> Proxy: - key = self.convert_idx(key) - - return self.proxy[:, key] - - def __setitem__(self, key: int, value: Union[Proxy, Any]) -> None: - key = self.convert_idx(key) - - self.proxy[:, key] = value - - -class InterventionProxy(Proxy): - @staticmethod - def proxy_save(value: Any) -> None: - return util.apply(value, lambda x: x.clone(), torch.Tensor) - - def save(self) -> InterventionProxy: - proxy = self.node.graph.add( - graph=self.node.graph, - value=self.node.proxy_value, - target=InterventionProxy.proxy_save, - args=[self.node], - ) - - self.node.graph.add( - graph=self.node.graph, - value=None, - target="null", - args=[proxy.node], - ) - - return proxy - - @property - def token(self) -> TokenIndexer: - return TokenIndexer(self) - - @property - def t(self) -> TokenIndexer: - return self.token - - @property - def shape(self): - return util.apply(self.node.proxy_value, lambda x: x.shape, torch.Tensor) - - @property - def value(self): - return self.node.future.value() - - -def intervene(activations, module_path: str, graph: Graph, key: str): - """Entry to intervention graph. This should be hooked to all modules involved in intervention graph. - - Args: - activations (_type_): _description_ - module_path (str): _description_ - graph (Graph): _description_ - key (str): _description_ - - Returns: - _type_: _description_ - """ - - # Key to module activation argument nodes has format: ... - module_path = f"{module_path}.{key}.{graph.generation_idx}" - - if module_path in graph.argument_node_names: - argument_node_names = graph.argument_node_names[module_path] - - # multiple argument nodes can have same module_path if there are multiple invocations. - for argument_node_name in argument_node_names: - node = graph.nodes[argument_node_name] - - # args for argument nodes are (module_path, batch_size, batch_start) - _, batch_size, batch_start = node.args - - # We set its result to the activations, indexed by only the relevant batch idxs. 
- node.future.set_result( - util.apply( - activations, - lambda x: x.narrow(0, batch_start, batch_size), - torch.Tensor, - ) - ) - - return activations - - -class HookModel: - def __init__( - self, - model: torch.nn.Module, - modules: List[str], - input_hook: Callable = None, - output_hook: Callable = None, - ) -> None: - self.model = model - self.modules: List[Tuple[torch.nn.Module, str]] = [ - (util.fetch_attr(self.model, module_path), module_path) - for module_path in modules - ] - self.input_hook = input_hook - self.output_hook = output_hook - - self.handles = [] - - def __enter__(self) -> HookModel: - for module, module_path in self.modules: - if self.input_hook is not None: - - def input_hook(module, input, module_path=module_path): - return self.input_hook(input, module_path) - - self.handles.append(module.register_forward_pre_hook(input_hook)) - - if self.output_hook is not None: - - def output_hook(module, input, output, module_path=module_path): - return self.output_hook(output, module_path) - - self.handles.append(module.register_forward_hook(output_hook)) - - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - for handle in self.handles: - handle.remove() diff --git a/engine/logger.py b/engine/logger.py deleted file mode 100644 index a859209..0000000 --- a/engine/logger.py +++ /dev/null @@ -1,14 +0,0 @@ -import logging -import os - -PATH = os.path.dirname(os.path.abspath(__file__)) -logging_handler = logging.FileHandler(os.path.join(PATH, f"engine.log"), "a") -logging_handler.setFormatter( - logging.Formatter( - "%(asctime)s %(processName)-10s %(name)s %(levelname)-8s %(message)s" - ) -) -logging_handler.setLevel(logging.DEBUG) -logger = logging.getLogger("engine") -logger.addHandler(logging_handler) -logger.setLevel(logging.DEBUG) diff --git a/engine/models/AbstractModel.py b/engine/models/AbstractModel.py deleted file mode 100644 index b66fb86..0000000 --- a/engine/models/AbstractModel.py +++ /dev/null @@ -1,244 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from typing import Any, Callable, List, Union - -import accelerate -import torch -from torch.utils.hooks import RemovableHandle - -from ..alteration import REPOID_TO_ALTERATION -from ..contexts.Generator import Generator -from ..contexts.Runner import Runner -from ..editing.Editor import Edit, Editor -from ..editing.GraphEdit import GraphEdit -from ..editing.WrapperModuleEdit import WrapperModuleEdit -from ..fx.Graph import Graph -from ..intervention import intervene, HookModel -from ..logger import logger -from ..Module import Module -from ..patching import Patcher - - -class AbstractModel: - """_summary_ - """ - def __init__(self, repoid_or_path:str, *args, alter:bool=True, **kwargs) -> None: - super().__init__() - - # TODO handle passing in a torch module - self.repoid_or_path = repoid_or_path - self.args = args - self.kwargs = kwargs - # Boolean on whether to check if alterations exist for this module and apply them. - self.alter = alter - # Boolean on whether this model has been dispatched (locally loaded) yet - self.dispatched = False - self.local_model: torch.nn.Module = None - self.edits: List[Edit] = list() - - logger.debug(f"Initializing `{self.repoid_or_path}`...") - - # If alter and alteration exist, use alteration patcher context while loading module. 
- with self.alteration() if self.alter else Patcher(): - # Use accelerate and .to('meta') to assure tensors are loaded to 'meta' device - with accelerate.init_empty_weights(include_buffers=True): - self.meta_model: torch.nn.Module = Module.wrap( - self._load_meta(self.repoid_or_path, *args, **kwargs).to("meta") - ) - - # Wrap all modules in our Module class. - for name, module in self.meta_model.named_children(): - module = Module.wrap(module) - - setattr(self.meta_model, name, module) - - # Set module_path attribute so Modules know their place. - for name, module in self.meta_model.named_modules(): - module.module_path = name - - # Run inital dummy string to populate Module shapes, dtypes etc - self._run_meta("_") - - logger.debug(f"Initialized `{self.repoid_or_path}`") - - def __repr__(self) -> str: - return repr(self.meta_model) - - def __getattr__(self, key) -> Any: - """Allows access of sub-modules on meta_model directly from AbstractModel object - - Args: - key (_type_): _description_ - - Returns: - Any: _description_ - """ - return getattr(self.meta_model, key) - - def __call__( - self, - fn: Callable, - inputs: Any, - graph: Graph, - *args, - edits: List[Edit] = None, - inference: bool = True, - **kwargs, - ) -> Any: - """Runs some function with some inputs and some graph with the approriate context for this model. - - Args: - fn (Callable): _description_ - inputs (Any): _description_ - graph (Graph): _description_ - edits (List[Edit], optional): _description_. Defaults to None. - inference (bool, optional): _description_. Defaults to True. - - Returns: - Any: _description_ - """ - if edits is None: - edits = self.edits - - - # If local_model not yet loaded, do so. - if not self.dispatched: - with self.alteration() if self.alter else Patcher(): - self.local_model = self._load_local( - self.repoid_or_path, *self.args, **self.kwargs - ) - - # By default, all params should be frozen. - for param in self.local_model.parameters(): - param.requires_grad = False - - - with Editor(self, edits): - - # Send local_model to graph to re-compile - graph.compile(self.local_model) - - increment_hook = self._register_increment_hook( - lambda module, input, output: graph.increment() - ) - - # The intervention graph for running a Model will have the modules that are involved - # in the graph's argument_node_names. 
- modules = set( - [ - ".".join(name.split(".")[:-2]) - for name in graph.argument_node_names.keys() - ] - ) - - logger.debug(f"Running `{self.repoid_or_path}`...") - - self.local_model.eval() if inference else self.local_model.train() - - with torch.inference_mode(mode=inference): - with HookModel( - self.local_model, - list(modules), - input_hook=lambda activations, module_path: intervene( - activations, module_path, graph, "input" - ), - output_hook=lambda activations, module_path: intervene( - activations, module_path, graph, "output" - ), - ): - output = fn(inputs, *args, **kwargs) - - increment_hook.remove() - - logger.debug(f"Completed `{self.repoid_or_path}`") - - return output - - def alteration(self) -> Patcher: - return REPOID_TO_ALTERATION.get(self.repoid_or_path, Patcher()) - - def generate(self, *args, **kwargs) -> Generator: - return Generator(self, *args, **kwargs) - - def forward(self, inputs, *args, **kwargs) -> Runner: - return Runner(self, inputs, *args, **kwargs) - - def modulize(self, module: Module, node_name: str, module_name: str) -> None: - """_summary_ - - Args: - module (Module): _description_ - node_name (str): _description_ - module_name (str): _description_ - """ - - # Create a WrapperModuleEdit which just adds a WrapperModule to an existing module at the given moduel_name. - wme = WrapperModuleEdit(module.module_path, module_name) - # Wrap with our Module and update new attributes. - wme.wrapper: Module = Module.wrap(wme.wrapper) - wme.wrapper.module_path = f"{module.module_path}.{module_name}" - wme.wrapper.generator = module.generator - wme.wrapper.output_shape = module.output_shape - # Carry out the edit on the meta_model. - wme.edit(self.meta_model) - - # Get/create the execution graph for the module's forward method. - graph = module.graph - - # Add two proxies/nodes, one to get the new WrapperModule we added and another to call it with the data from the original module. - # Passing the data through the wrapper module allows hooking of the module's output like usual. - module_proxy = getattr(graph.module_proxy, module_name) - module_proxy(graph.nodes[node_name]) - - # Create and carry out the edit on the meta_model. - ge = GraphEdit(module.module_path, module.graph) - ge.edit(self.meta_model) - - # Append to self.edits so when we call the local model, we temporarily edit the module in the same way as the meta model. - self.edits.append(wme) - self.edits.append(ge) - - @abstractmethod - def _prepare_inputs(self, inputs: Any, **kwargs) -> Any: - """Abstract method for Model type to process inputs. - - Args: - inputs (Any): _description_ - - Returns: - Any: _description_ - """ - raise NotImplementedError() - - @abstractmethod - def _load_meta(self, repoid_or_path, *args, **kwargs) -> torch.nn.Module: - """Abstract method for Model type to initialize what it needs for it's meta model. 
- - Args: - repoid_or_path (_type_): _description_ - - Returns: - torch.nn.Module: _description_ - """ - raise NotImplementedError() - - @abstractmethod - def _load_local(self, repoid_or_path, *args, **kwargs) -> torch.nn.Module: - raise NotImplementedError() - - @abstractmethod - def _run_meta(self, inputs, *args, **kwargs) -> Any: - raise NotImplementedError() - - @abstractmethod - def _run_local(self, inputs, *args, **kwargs) -> Any: - raise NotImplementedError() - - @abstractmethod - def _generation(self, inputs, *args, **kwargs) -> Any: - raise NotImplementedError() - - @abstractmethod - def _register_increment_hook(self, hook) -> RemovableHandle: - raise NotImplementedError() diff --git a/engine/models/DiffuserModel.py b/engine/models/DiffuserModel.py deleted file mode 100644 index db06328..0000000 --- a/engine/models/DiffuserModel.py +++ /dev/null @@ -1,247 +0,0 @@ -from __future__ import annotations - -from typing import Any, List, Union - -import diffusers -import torch -from diffusers import AutoencoderKL, SchedulerMixin, UNet2DConditionModel -from PIL import Image -from torch.utils.hooks import RemovableHandle -from transformers import CLIPTextModel, CLIPTokenizer - -from .AbstractModel import AbstractModel - - -class Diffuser(torch.nn.Module): - def __init__( - self, repoid_or_path, tokenizer: CLIPTokenizer, *args, **kwargs - ) -> None: - super().__init__() - - self.tokenizer = tokenizer - - self.vae: AutoencoderKL = AutoencoderKL.from_pretrained( - repoid_or_path, *args, **kwargs, subfolder="vae" - ) - self.unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained( - repoid_or_path, *args, **kwargs, subfolder="unet" - ) - self.text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained( - repoid_or_path, *args, **kwargs, subfolder="text_encoder" - ) - - def get_text_embeddings(self, text_tokens, n_imgs) -> torch.Tensor: - text_ids = text_tokens.input_ids.to(self.text_encoder.device) - - text_embeddings = self.text_encoder(text_ids)[0] - - unconditional_tokens = self.text_tokenize([""] * len(text_ids)) - - unconditional_ids = unconditional_tokens.input_ids.to(self.text_encoder.device) - - unconditional_embeddings = self.text_encoder(unconditional_ids)[0] - - text_embeddings = torch.repeat_interleave( - torch.cat([unconditional_embeddings, text_embeddings]), n_imgs, dim=0 - ) - - return text_embeddings - - def text_tokenize(self, prompts): - return self.tokenizer( - prompts, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - def text_detokenize(self, tokens): - return [ - self.tokenizer.decode(token) - for token in tokens - if token != self.tokenizer.vocab_size - 1 - ] - - def get_noise(self, batch_size, img_size) -> torch.Tensor: - return torch.randn( - (batch_size, self.unet.config.in_channels, img_size // 8, img_size // 8) - ) - - def get_initial_latents(self, n_imgs, img_size, n_prompts) -> torch.Tensor: - latents = self.get_noise(n_imgs, img_size).repeat(n_prompts, 1, 1, 1) - - return latents - - def decode(self, latents): - return self.vae.decode(1 / 0.18215 * latents).sample - - def encode(self, tensors): - return self.vae.encode(tensors).latent_dist.mode() * 0.18215 - - def predict_noise( - self, scheduler, iteration, latents, text_embeddings, guidance_scale=7.5 - ): - # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 
- latents = torch.cat([latents] * 2) - latents = scheduler.scale_model_input(latents, scheduler.timesteps[iteration]) - - # predict the noise residual - noise_prediction = self.unet( - latents, - scheduler.timesteps[iteration], - encoder_hidden_states=text_embeddings, - ).sample - - # perform guidance - noise_prediction_uncond, noise_prediction_text = noise_prediction.chunk(2) - noise_prediction = noise_prediction_uncond + guidance_scale * ( - noise_prediction_text - noise_prediction_uncond - ) - - return noise_prediction - - def diffusion( - self, - scheduler, - latents, - text_embeddings, - end_iteration=1000, - start_iteration=0, - **kwargs, - ): - for iteration in range(start_iteration, end_iteration): - noise_pred = self.predict_noise( - scheduler, iteration, latents, text_embeddings, **kwargs - ) - - # compute the previous noisy sample x_t -> x_t-1 - output = scheduler.step(noise_pred, scheduler.timesteps[iteration], latents) - - latents = output.prev_sample - - return latents - - -class DiffuserModel(AbstractModel): - def __init__(self, *args, **kwargs) -> None: - self.local_model: Diffuser = None - self.meta_model: Diffuser = None - self.tokenizer: CLIPTokenizer = None - - super().__init__(*args, **kwargs) - - def _register_increment_hook(self, hook) -> RemovableHandle: - return self.local_model.unet.register_forward_hook(hook) - - def _load_meta( - self, repoid_or_path, *args, device="cpu", **kwargs - ) -> torch.nn.Module: - self.tokenizer = CLIPTokenizer.from_pretrained( - repoid_or_path, *args, **kwargs, subfolder="tokenizer" - ) - - return Diffuser(repoid_or_path, self.tokenizer, *args, **kwargs) - - def _load_local( - self, repoid_or_path, *args, device="cpu", **kwargs - ) -> torch.nn.Module: - return Diffuser(repoid_or_path, self.tokenizer, *args, **kwargs).to(device) - - def _prepare_inputs( - self, - inputs, - n_imgs=1, - img_size=512, - ) -> Any: - if not isinstance(inputs, list): - inputs = [inputs] - - latents = self.meta_model.get_initial_latents(n_imgs, img_size, len(inputs)) - - text_tokens = self.meta_model.text_tokenize(inputs) - - return text_tokens, latents - - def _run_meta(self, inputs, *args, n_imgs=1, img_size=512, **kwargs) -> None: - text_tokens, latents = self._prepare_inputs( - inputs, n_imgs=n_imgs, img_size=img_size - ) - - text_embeddings = self.meta_model.get_text_embeddings(text_tokens, n_imgs) - - latents = torch.cat([latents] * 2).to("meta") - - self.meta_model.unet( - latents, - torch.zeros((1,), device="meta"), - encoder_hidden_states=text_embeddings, - ).sample - - self.meta_model.vae.decode(latents) - - return text_tokens.input_ids - - def _run_local(self, inputs, *args, n_imgs=1, img_size=512, **kwargs) -> None: - text_tokens, latents = self._prepare_inputs( - inputs, n_imgs=n_imgs, img_size=img_size - ) - - text_embeddings = self.meta_model.get_text_embeddings(text_tokens, n_imgs) - - latents = torch.cat([latents] * 2).to("meta") - - return self.meta_model.unet( - latents, - torch.zeros((1,), device="meta"), - encoder_hidden_states=text_embeddings, - ).sample - - def _generation( - self, - inputs, - *args, - n_steps=20, - scheduler="LMSDiscreteScheduler", - n_imgs=1, - img_size=512, - **kwargs, - ) -> None: - text_tokens, latents = self._prepare_inputs( - inputs, n_imgs=n_imgs, img_size=img_size - ) - - text_embeddings = self.local_model.get_text_embeddings(text_tokens, n_imgs) - - if isinstance(scheduler, str): - scheduler: SchedulerMixin = getattr(diffusers, scheduler).from_pretrained( - self.repoid_or_path, subfolder="scheduler" - ) - 
scheduler.set_timesteps(n_steps) - - latents = latents * scheduler.init_noise_sigma - - latents = self.local_model.diffusion( - scheduler, - latents.to(self.local_model.unet.device), - text_embeddings.to(self.local_model.unet.device), - *args, - **kwargs, - end_iteration=n_steps, - ) - - latents = (1 / 0.18215) * latents - - return self.local_model.vae.decode(latents).sample - - def to_image(self, latents) -> List[Image.Image]: - """ - Function to convert latents to images - """ - - image = (latents / 2 + 0.5).clamp(0, 1) - image = image.detach().cpu().permute(0, 2, 3, 1).numpy() - images = (image * 255).round().astype("uint8") - pil_images = [Image.fromarray(image) for image in images] - - return pil_images diff --git a/engine/models/LanguageModel.py b/engine/models/LanguageModel.py deleted file mode 100644 index 5d171cb..0000000 --- a/engine/models/LanguageModel.py +++ /dev/null @@ -1,83 +0,0 @@ -from __future__ import annotations - -from typing import Callable, List, Union - -import torch -from torch.utils.hooks import RemovableHandle -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoTokenizer, - BatchEncoding, - PretrainedConfig, - PreTrainedModel, - PreTrainedTokenizer, -) - -from .AbstractModel import AbstractModel - - -class LanguageModel(AbstractModel): - def __init__(self, *args, **kwargs) -> None: - self.config: PretrainedConfig = None - self.tokenizer: PreTrainedTokenizer = None - self.meta_model: PreTrainedModel = None - self.local_model: PreTrainedModel = None - - super().__init__(*args, **kwargs) - - def _register_increment_hook(self, hook: Callable) -> RemovableHandle: - return self.local_model.register_forward_hook(hook) - - def _load_meta(self, repoid_or_path, *args, **kwargs) -> PreTrainedModel: - self.config = AutoConfig.from_pretrained(repoid_or_path, *args, **kwargs) - - self.tokenizer = AutoTokenizer.from_pretrained( - repoid_or_path, config=self.config, padding_side="left" - ) - self.tokenizer.pad_token = self.tokenizer.eos_token - - return AutoModelForCausalLM.from_config(self.config, trust_remote_code=True) - - def _load_local(self, repoid_or_path, *args, **kwargs) -> PreTrainedModel: - return AutoModelForCausalLM.from_pretrained( - repoid_or_path, *args, config=self.config, **kwargs - ) - - def _prepare_inputs( - self, - inputs: Union[ - str, List[str], List[List[str]], List[int], List[List[int]], torch.Tensor - ], - ) -> BatchEncoding: - if isinstance(inputs, str) or ( - isinstance(inputs, list) and isinstance(inputs[0], int) - ): - inputs = [inputs] - - if isinstance(inputs, torch.Tensor) and inputs.ndim == 1: - inputs = inputs.unsqueeze(0) - - if not isinstance(inputs[0], str): - inputs = [self.tokenizer.decode(ids) for ids in inputs] - - return self.tokenizer(inputs, return_tensors="pt", padding=True) - - def _run_meta(self, inputs, *args, **kwargs) -> None: - inputs = self._prepare_inputs(inputs) - - self.meta_model(*args, **inputs.copy().to("meta"), **kwargs) - - return inputs["input_ids"] - - def _run_local(self, inputs, *args, **kwargs): - inputs = self._prepare_inputs(inputs) - - return self.local_model(*args, **inputs.to(self.local_model.device), **kwargs) - - def _generation(self, inputs, *args, **kwargs) -> None: - inputs = self._prepare_inputs(inputs) - - return self.local_model.generate( - *args, **inputs.to(self.local_model.device), **kwargs - ) diff --git a/engine/models/__init__.py b/engine/models/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/engine/patching.py b/engine/patching.py deleted file mode 
100644 index d68f208..0000000 --- a/engine/patching.py +++ /dev/null @@ -1,83 +0,0 @@ -from __future__ import annotations - -import importlib -from typing import List - - -class Patch: - def __init__(self, obj, replacement) -> None: - self.obj = obj - self.replacement = replacement - - def patch(self): - module = importlib.import_module(self.obj.__module__) - - setattr(module, self.obj.__name__, self.replacement) - - def restore(self): - module = importlib.import_module(self.obj.__module__) - - setattr(module, self.obj.__name__, self.obj) - - -class Patcher: - def __init__(self, patches: List[Patch] = None) -> None: - self.patches = patches or [] - - def add(self, patch: Patch): - self.patches.append(patch) - - patch.patch() - - def __enter__(self) -> Patcher: - for patch in self.patches: - patch.patch() - - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - for patch in self.patches: - patch.restore() - - -DEFAULT_PATCHER = Patcher() - -from functools import wraps - -import torch - - -def repeat_interleave_wrapper(fn): - @wraps(fn) - def repeat_interleave( - input: torch.Tensor, repeats: torch.LongTensor, dim=None, output_size=None - ): - if input.device.type == "meta": - if not isinstance(repeats, torch.Tensor): - repeats = torch.LongTensor([repeats]) - - if dim is None: - input = input.flatten() - dim = 0 - - if repeats.dim() == 0 or (repeats.dim() == 1 and repeats.size(0) == 1): - repeats = repeats.reshape([1]).expand([input.size(dim)]) - - new_dim_size = repeats.cumsum(0)[-1].item() - new_output_shape = list(input.shape) - new_output_shape[dim] = new_dim_size - - return torch.empty(new_output_shape, device="meta") - - else: - return fn(input, repeats, dim=dim, output_size=output_size) - - return repeat_interleave - - -DEFAULT_PATCHER.add( - Patch(torch.repeat_interleave, repeat_interleave_wrapper(torch.repeat_interleave)) -) - - -DEFAULT_PATCHER.__enter__() diff --git a/engine/pydantics/Config.py b/engine/pydantics/Config.py deleted file mode 100644 index a059445..0000000 --- a/engine/pydantics/Config.py +++ /dev/null @@ -1,9 +0,0 @@ -from pydantic import BaseModel - - -class ApiConfigModel(BaseModel): - HOST: str - - -class ConfigModel(BaseModel): - API: ApiConfigModel diff --git a/engine/pydantics/Request.py b/engine/pydantics/Request.py deleted file mode 100644 index d0ad012..0000000 --- a/engine/pydantics/Request.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import annotations - -import pickle -from datetime import datetime -from typing import Dict, List, Type, Union - -from pydantic import ( - BaseModel, - ConfigDict, - field_serializer -) - -from ..fx.Graph import Graph -from .fx import NodeModel - - -class RequestModel(BaseModel): - model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True) - - args: List - kwargs: Dict - model_name: str - prompts: List[str] - intervention_graph: Union[Graph, bytes, Dict[str, NodeModel]] - # Edits - # altered - - id: str = None - recieved: datetime = None - blocking: bool = False - - @field_serializer("intervention_graph") - def intervention_graph_serialize(self, value: Union[str, Graph], _info) -> str: - if isinstance(value, Graph): - nodes = dict() - - for node in value.nodes.values(): - node = NodeModel.from_node(node) - nodes[node.name] = node - - value = nodes - - return pickle.dumps(value) - - def graph(self): - - graph = Graph(None) - - for node in self.intervention_graph.values(): - NodeModel.to_node(graph, self.intervention_graph, node) - - return graph diff --git 
a/engine/pydantics/Response.py b/engine/pydantics/Response.py deleted file mode 100644 index 0d26f38..0000000 --- a/engine/pydantics/Response.py +++ /dev/null @@ -1,40 +0,0 @@ -from __future__ import annotations - -import logging -from datetime import datetime -from enum import Enum -from typing import Any, Dict - -from pydantic import BaseModel -from transformers.generation.utils import GenerateOutput - - -class JobStatus(Enum): - RECIEVED = "RECIEVED" - APPROVED = "APPROVED" - SUBMITTED = "SUBMITTED" - COMPLETED = "COMPLETED" - - ERROR = "ERROR" - - -class ResponseModel(BaseModel): - id: str - status: JobStatus - description: str - - output: Any = None - recieved: datetime = None - saves: Dict[str, Any] = None - blocking: bool = False - - def __str__(self) -> str: - return f"{self.id} - {self.status.name}: {self.description}" - - def log(self, logger: logging.Logger) -> ResponseModel: - if self.status == JobStatus.ERROR: - logger.error(str(self)) - else: - logger.info(str(self)) - - return self diff --git a/engine/pydantics/__init__.py b/engine/pydantics/__init__.py deleted file mode 100644 index c7eaf3b..0000000 --- a/engine/pydantics/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .Response import ResponseModel, JobStatus -from .Request import RequestModel -from .Config import ConfigModel -from .fx import NodeModel \ No newline at end of file diff --git a/engine/pydantics/fx.py b/engine/pydantics/fx.py deleted file mode 100644 index 040dccf..0000000 --- a/engine/pydantics/fx.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import annotations - -from typing import Any, Callable, Dict, List, Union - -from pydantic import BaseModel - -from .. import util -from ..fx.Graph import Graph -from ..fx.Node import Node - - -class NodeModel(BaseModel): - class Reference(BaseModel): - name: str - - name: str - target: Union[Callable, str] - args: List[Any] - kwargs: Dict[str, Any] - - @staticmethod - def from_node(node: Node): - def _reference(node: Node): - return NodeModel.Reference(name=node.name) - - args = util.apply(node.args, _reference, Node) - kwargs = util.apply(node.kwargs, _reference, Node) - - return NodeModel(name=node.name, target=node.target, args=args, kwargs=kwargs) - - @staticmethod - def to_node(graph: Graph, nodes: Dict[str, NodeModel], node_model: NodeModel): - def _dereference(reference: NodeModel.Reference): - return NodeModel.to_node(graph, nodes, nodes[reference.name]) - - # Arguments might be interventions themselves so recurse. - args = util.apply(node_model.args, _dereference, NodeModel.Reference) - kwargs = util.apply(node_model.kwargs, _dereference, NodeModel.Reference) - - # Processing of args may have already created an Intervention for this node so just return it. - if node_model.name in graph.nodes: - return graph.nodes[node_model.name] - - graph.add( - graph=graph, - value=None, - target=node_model.target, - args=args, - kwargs=kwargs, - name=node_model.name, - ) diff --git a/engine/tests/__init__.py b/engine/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/engine/tests/conftest.py b/engine/tests/conftest.py deleted file mode 100644 index 7e533e4..0000000 --- a/engine/tests/conftest.py +++ /dev/null @@ -1,10 +0,0 @@ -def pytest_addoption(parser): - parser.addoption("--device", action="store", default="cuda:0") - - -def pytest_generate_tests(metafunc): - # This is called for every test. Only get/set command line arguments - # if the argument is specified in the list of test "fixturenames". 
- option_value = metafunc.config.option.device - if 'device' in metafunc.fixturenames and option_value is not None: - metafunc.parametrize("device", [option_value], scope='module') \ No newline at end of file diff --git a/engine/tests/test_lm.py b/engine/tests/test_lm.py deleted file mode 100644 index fa59432..0000000 --- a/engine/tests/test_lm.py +++ /dev/null @@ -1,93 +0,0 @@ -import engine -import pytest -import torch - - -@pytest.fixture(scope="module") -def gpt2(device: str): - return engine.LanguageModel("gpt2", device_map=device) - - -@pytest.fixture -def MSG_prompt(): - return "Madison Square Garden is located in the city of" - - -def test_generation(gpt2: engine.LanguageModel, MSG_prompt: str): - with gpt2.generate(max_new_tokens=3) as generator: - with generator.invoke(MSG_prompt) as invoker: - pass - - output = gpt2.tokenizer.decode(generator.output[0]) - - assert output == "Madison Square Garden is located in the city of New York City" - - -def test_save(gpt2: engine.LanguageModel): - with gpt2.generate(max_new_tokens=1) as generator: - with generator.invoke("Hello world") as invoker: - hs = gpt2.transformer.h[-1].output[0].save() - - assert hs.node.done() - assert isinstance(hs.value, torch.Tensor) - assert hs.value.ndim == 3 - - -def test_set(gpt2: engine.LanguageModel): - with gpt2.generate(max_new_tokens=1) as generator: - with generator.invoke("Hello world") as invoker: - pre = gpt2.transformer.h[-1].output[0].save() - - gpt2.transformer.h[-1].output[0] = 0 - - post = gpt2.transformer.h[-1].output[0].save() - - output = gpt2.tokenizer.decode(generator.output[0]) - - assert not (pre.value == 0).all().item() - assert (post.value == 0).all().item() - assert output != "Madison Square Garden is located in the city of New" - - -def test_adhoc_module(gpt2: engine.LanguageModel): - with gpt2.generate() as generator: - with generator.invoke("The Eiffel Tower is in the city of") as invoker: - hidden_states = gpt2.transformer.h[-1].output[0] - hidden_states = gpt2.lm_head(gpt2.transformer.ln_f(hidden_states)) - tokens = torch.softmax(hidden_states, dim=2).argmax(dim=2).save() - - output = gpt2.tokenizer.decode(tokens.value[0]) - - assert output == "\n-el Tower is a the middle centre Paris" - - -def test_embeddings_set1(gpt2: engine.LanguageModel, MSG_prompt: str): - with gpt2.generate(max_new_tokens=3) as generator: - with generator.invoke(MSG_prompt) as invoker: - embeddings = gpt2.transformer.wte.output - - with generator.invoke("_ _ _ _ _ _ _ _ _") as invoker: - gpt2.transformer.wte.output = embeddings - - output1 = gpt2.tokenizer.decode(generator.output[0]) - output2 = gpt2.tokenizer.decode(generator.output[1]) - - assert output1 == "Madison Square Garden is located in the city of New York City" - assert output2 == "_ _ _ _ _ _ _ _ _ New York City" - - -def test_embeddings_set2(gpt2: engine.LanguageModel, MSG_prompt: str): - with gpt2.generate(max_new_tokens=3) as generator: - with generator.invoke(MSG_prompt) as invoker: - embeddings = gpt2.transformer.wte.output.save() - - output1 = gpt2.tokenizer.decode(generator.output[0]) - - with gpt2.generate(max_new_tokens=3) as generator: - with generator.invoke("_ _ _ _ _ _ _ _ _") as invoker: - gpt2.transformer.wte.output = embeddings.value - - output2 = gpt2.tokenizer.decode(generator.output[0]) - - assert output1 == "Madison Square Garden is located in the city of New York City" - assert output2 == "_ _ _ _ _ _ _ _ _ New York City" diff --git a/engine/toolbox/__init__.py b/engine/toolbox/__init__.py deleted file mode 100644 index 
e69de29..0000000 diff --git a/engine/toolbox/optim/__init__.py b/engine/toolbox/optim/__init__.py deleted file mode 100644 index d040e01..0000000 --- a/engine/toolbox/optim/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ - - -from abc import abstractmethod -from typing import Any - - -class Optimization: - - @abstractmethod - def parameters(self): - pass - - @abstractmethod - def __call__(self) -> Any: - pass \ No newline at end of file diff --git a/engine/toolbox/optim/lora.py b/engine/toolbox/optim/lora.py deleted file mode 100644 index 7cfdef8..0000000 --- a/engine/toolbox/optim/lora.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import Any -import torch - -from ...Module import Module -from . import Optimization - - -class LORA(Optimization): - def __init__(self, module: Module, r: int) -> None: - self.module = module - self.r = r - - self.WA = torch.nn.Parameter(torch.empty(self.module.input_shape[0][-1], self.r), requires_grad=True) - self.WB = torch.nn.Parameter(torch.empty(self.r, self.module.output_shape[-1]), requires_grad=True) - - def __call__(self, alpha:float=1.0) -> Any: - - inp = self.module.input[0] - - self.module.output = (torch.matmul(torch.matmul(inp, self.WA), self.WB) + self.module.output) * alpha - - def parameters(self): - return [self.WA, self.WB] diff --git a/engine/toolbox/optim/softprompt.py b/engine/toolbox/optim/softprompt.py deleted file mode 100644 index 2a382fb..0000000 --- a/engine/toolbox/optim/softprompt.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Any -import torch - -from ...Module import Module -from . import Optimization - - -class SoftPrompt(Optimization): - def __init__(self, module: Module, n: int) -> None: - self.module = module - self.n = n - - self.embedding = torch.nn.Parameter( - torch.zeros((self.n, self.module.embedding_dim)), requires_grad=True - ) - - def __call__(self) -> Any: - self.module.output = self.embedding[:] - - def parameters(self): - return [self.embedding] diff --git a/engine/util.py b/engine/util.py deleted file mode 100644 index 9d253ec..0000000 --- a/engine/util.py +++ /dev/null @@ -1,109 +0,0 @@ -import time -from functools import wraps -from typing import Any, Callable, Type, Union - -import torch - -Primative = Union[str, int, float, bool] -Value = Union[Primative, torch.Tensor] - - -def apply(data: Any, fn: Callable, cls: type): - if isinstance(data, cls): - return fn(data) - - if isinstance(data, list): - return [apply(_data, fn, cls) for _data in data] - - if isinstance(data, tuple): - return tuple([apply(_data, fn, cls) for _data in data]) - - if isinstance(data, dict): - return {key: apply(value, fn, cls) for key, value in data.items()} - - return data - - -def fetch_attr(object: object, target: str): - target_atoms = target.split(".") - for i, atom in enumerate(target_atoms): - object = getattr(object, atom) - return object - - -def wrap(object: object, wrapper: Type, *args, **kwargs): - if isinstance(object, wrapper): - return object - - object.__class__ = type(object.__class__.__name__, (wrapper, object.__class__), {}) - - wrapper.__init__(object, *args, **kwargs) - - return object - - -def timed(func, lggr): - """This decorator prints the execution time for the decorated function.""" - - @wraps(func) - def wrapper(*args, **kwargs): - start = time.time() - result = func(*args, **kwargs) - end = time.time() - lggr.debug(f"Method `{func.__qualname__}` ran in {round(end - start, 6)}s") - return result - - return wrapper - - -def cross_entropy_loss( - logits: torch.Tensor, - target_ids: torch.Tensor, - shift: 
bool = False, - avg_batch: bool = True, - avg_token: bool = True, -): - logits = logits.cpu() - target_ids = target_ids.cpu() - - if logits.ndim == 2: - logits = logits.unsqueeze(0) - - if target_ids.ndim == 1: - target_ids = target_ids.unsqueeze(0) - - assert logits.ndim == 3 - assert target_ids.ndim == 2 - assert logits.size(0) == target_ids.size(0) - assert logits.size(1) == target_ids.size(1) - - if shift: - logits = logits[:, :-1] - target_ids = target_ids[:, 1:] - - target_ids = target_ids.long() - - batch_losses = [] - - for batch_idx in range(len(logits)): - batch_loss = torch.nn.functional.cross_entropy( - logits[batch_idx], - target_ids[batch_idx], - reduction="mean" if avg_token else "none", - ) - batch_losses.append(batch_loss) - - batch_losses = torch.stack(batch_losses) - - if avg_batch: - batch_losses = batch_losses.mean(dim=0) - - return batch_losses - - -class WrapperModule(torch.nn.Module): - def forward(self, *args, **kwargs): - if len(args) == 1: - args = args[0] - - return args diff --git a/server/environment.yaml b/environment.yaml similarity index 100% rename from server/environment.yaml rename to environment.yaml diff --git a/examples/adhoc_module.py b/examples/adhoc_module.py deleted file mode 100644 index d24ae1c..0000000 --- a/examples/adhoc_module.py +++ /dev/null @@ -1,15 +0,0 @@ -from engine import LanguageModel -import torch - -model = LanguageModel("gpt2", device_map='cuda:0') - -with model.generate() as generator: - with generator.invoke('The Eiffel Tower is in the city of') as invoker: - - hidden_states = model.transformer.h[-1].output[0] - hidden_states = model.lm_head(model.transformer.ln_f(hidden_states)).save() - tokens = torch.softmax(hidden_states, dim=2).argmax(dim=2).save() - -print(hidden_states.value) -print(tokens.value) -print(model.tokenizer.decode(tokens.value[0])) diff --git a/examples/diffuser.py b/examples/diffuser.py deleted file mode 100644 index 2db0085..0000000 --- a/examples/diffuser.py +++ /dev/null @@ -1,13 +0,0 @@ -from engine import DiffuserModel - - -diffuser = DiffuserModel("CompVis/stable-diffusion-v1-4", device='cuda:0') - - -with diffuser.generate() as generator: - - with generator.invoke(["Blue elephant", "GREEN"]) as invoker: - - pass -diffuser.to_image(generator.output)[0].save('ummb.png') -diffuser.to_image(generator.output)[1].save('umm.png') \ No newline at end of file diff --git a/examples/embeddings.py b/examples/embeddings.py deleted file mode 100644 index 0232f48..0000000 --- a/examples/embeddings.py +++ /dev/null @@ -1,18 +0,0 @@ -from engine import LanguageModel - -model = LanguageModel('gpt2', device_map='cuda:0') - -print(model) - -with model.generate(max_new_tokens=3) as generator: - - with generator.invoke("Madison square garden is located in the city of New") as invoker: - - embeddings = model.transformer.wte.output - - with generator.invoke("_ _ _ _ _ _ _ _ _ _") as invoker: - - model.transformer.wte.output = embeddings - -print(model.tokenizer.decode(generator.output[0])) -print(model.tokenizer.decode(generator.output[1])) \ No newline at end of file diff --git a/examples/lora.py b/examples/lora.py deleted file mode 100644 index 58528ab..0000000 --- a/examples/lora.py +++ /dev/null @@ -1,61 +0,0 @@ -from typing import Any - -import torch -from engine import AbstractModel, LanguageModel, util -from engine.Module import Module -from engine.toolbox.optim.lora import LORA -from torch.utils.data import DataLoader, Dataset - -model = LanguageModel("gpt2", device_map="cuda:0") - -n_tokens = 10 -epochs = 1 -answer 
= "Paris" -answer_tokens = model.tokenizer(answer) -answer_token = answer_tokens["input_ids"][0] - -lora = LORA(model.transformer.h[0].mlp, 10) - -optimizer = torch.optim.AdamW(lora.parameters(), lr=.1) -dataset = [[" ".join(["_"] * n_tokens), answer_token]] * 100 -dataloader = DataLoader(dataset, batch_size=10) - - -lossfn = util.cross_entropy_loss - -for epoch in range(epochs): - print(epoch) - - for i, (inputs, targets) in enumerate(dataloader): - print(f" {i}") - - optimizer.zero_grad() - - with model.forward(inputs, inference=False) as runner: - - - lora() - - logits = model.lm_head.output.save() - - print(lora.WA) - loss = lossfn(logits.value[:, -1], targets) - print(loss) - - loss.backward() - - optimizer.step() - - -with model.generate() as generator: - with generator.invoke(dataset[0][0]) as invoker: - pass - -print(model.tokenizer.decode(generator.output[0])) - - -with model.generate() as generator: - with generator.invoke(dataset[0][0]) as invoker: - lora() - -print(model.tokenizer.decode(generator.output[0])) diff --git a/examples/modulize.py b/examples/modulize.py deleted file mode 100644 index 6e90d3d..0000000 --- a/examples/modulize.py +++ /dev/null @@ -1,16 +0,0 @@ -from engine import LanguageModel - -model = LanguageModel('gpt2', device_map='cuda:0') - -print(model.transformer.h[1].attn.graph) - -model.modulize(model.transformer.h[1].attn, 'softmax_0', 'attention_probs') - -with model.generate(max_new_tokens=3) as generator: - - with generator.invoke('Hello world') as invoker: - - attention_probs = model.transformer.h[1].attn.attention_probs.output.save() - -print(attention_probs.value) - diff --git a/examples/multitoken.py b/examples/multitoken.py deleted file mode 100644 index 7e68888..0000000 --- a/examples/multitoken.py +++ /dev/null @@ -1,51 +0,0 @@ -from engine import LanguageModel - -model = LanguageModel("gpt2", device_map='cuda:0') - - -def get_scores(): - hs = model.transformer.h[-1].output[0] - - return model.lm_head(model.transformer.ln_f(hs)) - - -def decode(scores): - scores = scores.argmax(dim=2)[0, -1] - return model.tokenizer.decode(scores) - - -with model.generate( - max_new_tokens=3, - return_dict_in_generate=True, - output_scores=True, -) as generator: - with generator.invoke( - "Madison square garden is located in the city of New" - ) as invoker: - tokenized = invoker.tokens - - # Reference the hidden states of the last layer for each token of the nine tokens (shape: (1,9,768)) - # Apply lm_head (decode into vocabulary space) and copy and return value (shape: (1,9,50257)) - logits1 = get_scores().save() - - # Denote that you are generating a token and subsequent interventions will apply to that generation - # and not the previous ones. - invoker.next() - - # Here the shape of the hidden states is (1, 1, 768) as there is just the one token - # Get its hidden states of the last layer decoded as well - logits2 = get_scores().save() - - # And again.... 
- invoker.next() - - logits3 = get_scores().save() - - -pred1 = decode(logits1.value) -pred2 = decode(logits2.value) -pred3 = decode(logits3.value) - -print(pred1) -print(pred2) -print(pred3) diff --git a/examples/optim.py b/examples/optim.py deleted file mode 100644 index b942304..0000000 --- a/examples/optim.py +++ /dev/null @@ -1,93 +0,0 @@ -from typing import Any - -import torch -from engine import AbstractModel, LanguageModel, util -from engine.Module import Module -from torch.utils.data import DataLoader, Dataset - -model = LanguageModel("gpt2", device_map="cuda:0") - -n_tokens = 10 -epochs = 1 -answer = "Paris" -answer_tokens = model.tokenizer(answer) -answer_token = answer_tokens["input_ids"][0] - - -class SoftPrompt: - def __init__(self, module: Module, n: int) -> None: - self.module = module - self.n = n - - self.embedding = torch.nn.Parameter( - torch.zeros((self.n, self.module.embedding_dim)), requires_grad=True - ) - - def __call__(self) -> Any: - self.module.output = self.embedding[:] - - def parameters(self): - return [self.embedding] - - -sp = SoftPrompt(model.transformer.wte, n_tokens) - -optimizer = torch.optim.AdamW(sp.parameters()) -dataset = [[" ".join(["_"] * n_tokens), answer_token]] * 100 -dataloader = DataLoader(dataset, batch_size=10) - - -def decode(hs): - return torch.log_softmax(model.lm_head(model.transformer.ln_f(hs)), dim=-1) - -with model.generate(max_new_tokens=2) as generator: - - with generator.invoke("Madison Square Garden is located in New") as invoker: - - hs = model.transformer.h[1].output[0].t[-1].save() - - invoker.next() - - target = decode(model.transformer.h[-1].output[0]).save() - -target = target.value -hs = hs.value - -lossfn = torch.nn.KLDivLoss(reduction="batchmean", log_target=True) - -for epoch in range(epochs): - print(epoch) - - for i, (inputs, targets) in enumerate(dataloader): - print(f" {i}") - - optimizer.zero_grad() - - with model.forward(inputs, inference=False) as runner: - - sp() - - model.transformer.h[1].output[0].t[-1] = hs - - pred = decode(model.transformer.h[-1].output[0]).save() - - - - loss = lossfn(pred.value, target) - loss.backward() - - optimizer.step() - - -with model.generate() as generator: - with generator.invoke(dataset[0][0]) as invoker: - pass - -print(model.tokenizer.decode(generator.output[0])) - - -with model.generate() as generator: - with generator.invoke(dataset[0][0]) as invoker: - sp() - -print(model.tokenizer.decode(generator.output[0])) diff --git a/examples/optim2.py b/examples/optim2.py deleted file mode 100644 index b942304..0000000 --- a/examples/optim2.py +++ /dev/null @@ -1,93 +0,0 @@ -from typing import Any - -import torch -from engine import AbstractModel, LanguageModel, util -from engine.Module import Module -from torch.utils.data import DataLoader, Dataset - -model = LanguageModel("gpt2", device_map="cuda:0") - -n_tokens = 10 -epochs = 1 -answer = "Paris" -answer_tokens = model.tokenizer(answer) -answer_token = answer_tokens["input_ids"][0] - - -class SoftPrompt: - def __init__(self, module: Module, n: int) -> None: - self.module = module - self.n = n - - self.embedding = torch.nn.Parameter( - torch.zeros((self.n, self.module.embedding_dim)), requires_grad=True - ) - - def __call__(self) -> Any: - self.module.output = self.embedding[:] - - def parameters(self): - return [self.embedding] - - -sp = SoftPrompt(model.transformer.wte, n_tokens) - -optimizer = torch.optim.AdamW(sp.parameters()) -dataset = [[" ".join(["_"] * n_tokens), answer_token]] * 100 -dataloader = DataLoader(dataset, 
batch_size=10) - - -def decode(hs): - return torch.log_softmax(model.lm_head(model.transformer.ln_f(hs)), dim=-1) - -with model.generate(max_new_tokens=2) as generator: - - with generator.invoke("Madison Square Garden is located in New") as invoker: - - hs = model.transformer.h[1].output[0].t[-1].save() - - invoker.next() - - target = decode(model.transformer.h[-1].output[0]).save() - -target = target.value -hs = hs.value - -lossfn = torch.nn.KLDivLoss(reduction="batchmean", log_target=True) - -for epoch in range(epochs): - print(epoch) - - for i, (inputs, targets) in enumerate(dataloader): - print(f" {i}") - - optimizer.zero_grad() - - with model.forward(inputs, inference=False) as runner: - - sp() - - model.transformer.h[1].output[0].t[-1] = hs - - pred = decode(model.transformer.h[-1].output[0]).save() - - - - loss = lossfn(pred.value, target) - loss.backward() - - optimizer.step() - - -with model.generate() as generator: - with generator.invoke(dataset[0][0]) as invoker: - pass - -print(model.tokenizer.decode(generator.output[0])) - - -with model.generate() as generator: - with generator.invoke(dataset[0][0]) as invoker: - sp() - -print(model.tokenizer.decode(generator.output[0])) diff --git a/examples/test.py b/examples/test.py deleted file mode 100644 index e6ea2d1..0000000 --- a/examples/test.py +++ /dev/null @@ -1,45 +0,0 @@ -# The library is called engine -from engine import LanguageModel -import torch -# Get model wrapper for any model you can get with AutoConfig.from_pretrained(model_name) -model = LanguageModel('gpt2',device_map='cuda:0') - -# Prints normal module tree to show access tree for modules -print(model) - -# Invoke using a prompt -with model.generate(max_new_tokens=3) as generator: - with generator.invoke('Hello world ') as invoker: - - # See the input prompt seperated into token strings - tokenized = invoker.tokens - - # Use normal module access and .output to get output activations. - # Then save the activations at this point in the execution tree - # not only to retrieve later on (by calling mlp0.value after the model has been ran), - # but if this value changes throughout interventions and you want the value before those alterations. - - mlp0 = model.transformer.h[0].mlp.output.save() - - # Use .token[idx] or .t[idx] to index by token - mlp0_t1_t = model.transformer.h[0].mlp.output.t[0].save() - - mmlp0 = model.transformer.h[0].mlp.output - mmlp1 = model.transformer.h[1].mlp.output - # Addition works like you normally would either with tensors or primatives ( will add other operations later) - noise = (0.001**0.5)*torch.randn(mmlp1.t[1].shape) - mmlp1 = mmlp1.t[1] + noise - - mmlp2_before = model.transformer.h[2].mlp.output.save() - - # Easily set the output of a module to whatever you want - model.transformer.h[2].mlp.output = mmlp0 + mmlp1 - - # See the before and after the intervation - mmlp2_after = model.transformer.h[2].mlp.output.save() - - with generator.invoke('Goodbye world') as invoker: - - # Operations work cross-prompt! 
- model.transformer.h[1].mlp.output = mmlp0 - diff --git a/examples/test_server.py b/examples/test_server.py deleted file mode 100644 index 41de0f6..0000000 --- a/examples/test_server.py +++ /dev/null @@ -1,12 +0,0 @@ -from engine import LanguageModel - -model = LanguageModel("gpt2") - -print(model) - -with model.generate(server=True) as generator: - with generator.invoke("Hello world") as invoker: - hiddenstates = model.transformer.h[2].output.save() - - -print(hiddenstates.value) diff --git a/examples/test_server_llama.py b/examples/test_server_llama.py deleted file mode 100644 index d7c6a4a..0000000 --- a/examples/test_server_llama.py +++ /dev/null @@ -1,12 +0,0 @@ -from engine import LanguageModel - -model = LanguageModel('decapoda-research/llama-65b-hf') - -print(model) - -with model.generate(server=True, max_new_tokens=1, return_dict_in_generate=True, output_scores=True) as generator: - with generator.invoke('Hello world') as invoker: - - hiddenstates = model.model.layers[0].output.save() - -print(hiddenstates.value) diff --git a/examples/tl/.ipynb_checkpoints/attention-checkpoint.ipynb b/examples/tl/.ipynb_checkpoints/attention-checkpoint.ipynb deleted file mode 100644 index c184f42..0000000 --- a/examples/tl/.ipynb_checkpoints/attention-checkpoint.ipynb +++ /dev/null @@ -1,104 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using renderer: notebook_connected\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import circuitsvis as cv\n", - "# Testing that the library works\n", - "cv.examples.hello(\"Neel\")" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "vis" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/tl/Main_Demo.ipynb b/examples/tl/Main_Demo.ipynb deleted file mode 100644 index c81bee1..0000000 --- a/examples/tl/Main_Demo.ipynb +++ /dev/null @@ -1,3212 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "9-OfK57xchiW" - }, - "source": [ - "\n", - " \"Open\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ONjp2zU_chiY" - }, - "source": [ - "# Transformer Lens Main Demo Notebook\n", - "\n", - "To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n", - "\n", - "This is a reference notebook covering the main features of the [TransformerLens](https://github.com/neelnanda-io/TransformerLens) library for mechanistic interpretability. See [Callum McDougall's tutorial](https://transformerlens-intro.streamlit.app/TransformerLens_&_induction_circuits) for a more structured and gentler introduction to the library" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ad0azi0PchiZ" - }, - "source": [ - "**Tips for reading this Colab:**\n", - "* You can run all this code for yourself!\n", - "* The graphs are interactive!\n", - "* Use the table of contents pane in the sidebar to navigate\n", - "* Collapse irrelevant sections with the dropdown arrows\n", - "* Search the page using the search in the sidebar, not CTRL+F" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MGG6FbJzchiZ" - }, - "source": [ - "# Setup\n", - "(No need to read)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "2BksZggmchiZ", - "outputId": "ee0c2450-ddec-43f4-c59d-751a6eae2a4f" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running as a Jupyter notebook - intended for development only!\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/tmp/ipykernel_13657/2955938409.py:20: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", - " ipython.magic(\"load_ext autoreload\")\n", - "/tmp/ipykernel_13657/2955938409.py:21: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", - " ipython.magic(\"autoreload 2\")\n" - ] - } - ], - "source": [ - "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", - "DEVELOPMENT_MODE = False\n", - "try:\n", - " import google.colab\n", - " IN_COLAB = True\n", - " print(\"Running as a Colab notebook\")\n", - "\n", - " %pip install git+https://github.com/JadenFiotto-Kaufman/ndif\n", - " %pip install circuitsvis\n", - "\n", - "\n", - " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", - " # # Install another version of node that makes PySvelte work way faster\n", - " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash 
-; sudo apt-get install -y nodejs\n", - " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", - "except:\n", - " IN_COLAB = False\n", - " print(\"Running as a Jupyter notebook - intended for development only!\")\n", - " from IPython import get_ipython\n", - "\n", - " ipython = get_ipython()\n", - " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", - " ipython.magic(\"load_ext autoreload\")\n", - " ipython.magic(\"autoreload 2\")" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "tbRafBLnchib", - "outputId": "536f37aa-3118-4f98-a6cd-31d434915afa" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using renderer: colab\n" - ] - } - ], - "source": [ - "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n", - "import plotly.io as pio\n", - "if IN_COLAB or not DEVELOPMENT_MODE:\n", - " pio.renderers.default = \"colab\"\n", - "else:\n", - " pio.renderers.default = \"notebook_connected\"\n", - "print(f\"Using renderer: {pio.renderers.default}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "fzcdMBqxchic", - "outputId": "1fad4271-2478-48c0-a25a-6287911b2073" - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import circuitsvis as cv\n", - "# Testing that the library works\n", - "cv.examples.hello(\"Jaden\")" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "VFPohiZEchic" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/jadenfk/miniconda3/envs/ndif/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning:\n", - "\n", - "IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - "\n" - ] - } - ], - "source": [ - "import engine\n", - "import torch\n", - "import plotly.express as px\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5l6xsWBFchid" - }, - "source": [ - "We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "3IVs7Jhjchid", - "outputId": "f3e856e2-5642-4fad-b1f9-f8f51c97bde2" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.set_grad_enabled(False)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TJ8jv_pjchid" - }, - "source": [ - "Plotting helper functions:" - ] - }, - { - "cell_type": "code", - "execution_count": 84, - "metadata": { - "id": "s9a1yjoAchid" - }, - "outputs": [], - "source": [ - "def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n", - " px.imshow(tensor, color_continuous_midpoint=0.0, color_continuous_scale=\"RdBu\", labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n", - "\n", - "def line(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n", - " px.line(tensor, labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n", - "\n", - "def scatter(x, y, xaxis=\"\", yaxis=\"\", caxis=\"\", renderer=None, **kwargs):\n", - "\n", - " px.scatter(y=y, x=x, labels={\"x\":xaxis, \"y\":yaxis, \"color\":caxis}, **kwargs).show(renderer)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fSO4RkLwchie" - }, - "source": [ - "# Introduction" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "njTGUhgochie" - }, - "source": [ - "This is a demo notebook for [TransformerLens](https://github.com/neelnanda-io/TransformerLens), **a library I ([Neel Nanda](neelnanda.io)) wrote for doing [mechanistic interpretability](https://distill.pub/2020/circuits/zoom-in/) of GPT-2 Style language models.** The goal of mechanistic interpretability is to take a trained model and reverse engineer the algorithms the model learned during training from its weights. It is a fact about the world today that we have computer programs that can essentially speak English at a human level (GPT-3, PaLM, etc), yet we have no idea how they work nor how to write one ourselves. This offends me greatly, and I would like to solve this! Mechanistic interpretability is a very young and small field, and there are a *lot* of open problems - if you would like to help, please try working on one! 
**If you want to skill up, check out [my guide to getting started](https://neelnanda.io/getting-started), and if you want to jump into an open problem check out my sequence [200 Concrete Open Problems in Mechanistic Interpretability](https://neelnanda.io/concrete-open-problems).**\n", - "\n", - "I wrote this library because after I left the Anthropic interpretability team and started doing independent research, I got extremely frustrated by the state of open source tooling. There's a lot of excellent infrastructure like HuggingFace and DeepSpeed to *use* or *train* models, but very little to dig into their internals and reverse engineer how they work. **This library tries to solve that**, and to make it easy to get into the field even if you don't work at an industry org with real infrastructure! The core features were heavily inspired by [Anthropic's excellent Garcon tool](https://transformer-circuits.pub/2021/garcon/index.html). Credit to Nelson Elhage and Chris Olah for building Garcon and showing me the value of good infrastructure for accelerating exploratory research!\n", - "\n", - "The core design principle I've followed is to enable exploratory analysis - one of the most fun parts of mechanistic interpretability compared to normal ML is the extremely short feedback loops! The point of this library is to keep the gap between having an experiment idea and seeing the results as small as possible, to make it easy for **research to feel like play** and to enter a flow state. This notebook demonstrates how the library works and how to use it, but if you want to see how well it works for exploratory research, check out [my notebook analysing Indirect Objection Identification](https://neelnanda.io/exploratory-analysis-demo) or [my recording of myself doing research](https://www.youtube.com/watch?v=yo4QvDn-vsU)!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UPQJt_NHchie" - }, - "source": [ - "## Loading and Running Models\n", - "\n", - "TransformerLens comes loaded with >40 open source GPT-style models. You can load any of them in with `HookedTransformer.from_pretrained(MODEL_NAME)`. For this demo notebook we'll look at GPT-2 Small, an 80M parameter model, see the Available Models section for info on the rest." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "bxBfAVwmchie" - }, - "outputs": [], - "source": [ - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "O316vs8Bchie", - "outputId": "d3ed33cb-b4de-42c8-b98a-f6ea255d881e" - }, - "outputs": [], - "source": [ - "model = engine.LanguageModel('gpt2', device_map='cuda:0')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OVRRpdynchie" - }, - "source": [ - "To try the model out, let's find the loss on this text! 
Models can be run on a single string or a tensor of tokens (shape: [batch, position], all integers), and the possible return types are:\n", - "* \"logits\" (shape [batch, position, d_vocab], floats),\n", - "* \"loss\" (the cross-entropy loss when predicting the next token),\n", - "* \"both\" (a tuple of (logits, loss))\n", - "* None (run the model, but don't calculate the logits - this is faster when we only want to use intermediate activations)" - ] - }, - { - "cell_type": "code", - "execution_count": 92, - "metadata": { - "id": "5LoEbLruchif", - "outputId": "592777e2-ca5b-45ee-ceec-5cdfc188a4e2" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model loss: 4.137\n" - ] - } - ], - "source": [ - "model_description_text = \"\"\"## Loading Models\n", - "\n", - "HookedTransformer comes loaded with >40 open source GPT-style models. You can load any of them in with `HookedTransformer.from_pretrained(MODEL_NAME)`. See my explainer for documentation of all supported models, and this table for hyper-parameters and the name used to load them. Each model is loaded into the consistent HookedTransformer architecture, designed to be clean, consistent and interpretability-friendly.\n", - "\n", - "For this demo notebook we'll look at GPT-2 Small, an 80M parameter model. To try the model the model out, let's find the loss on this paragraph!\"\"\"\n", - "\n", - "with model.generate(max_new_tokens=1) as generator:\n", - "\n", - " with generator.invoke(model_description_text) as invoker:\n", - "\n", - " token_ids = invoker.ids\n", - "\n", - " logits = model.lm_head.output.save()\n", - "\n", - "from engine.util import cross_entropy_loss\n", - "\n", - "loss = cross_entropy_loss(logits.value, token_ids, shift=True)\n", - "\n", - "print(f\"Model loss: {loss.item():.3f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nmVJ-XWwchif" - }, - "source": [ - "## Caching all Activations\n", - "\n", - "The first basic operation when doing mechanistic interpretability is to break open the black box of the model and look at all of the internal activations of a model. This can be done with `logits, cache = model.run_with_cache(tokens)`. Let's try this out on the first line of the abstract of the GPT-2 paper.\n", - "\n", - "
On `remove_batch_dim`\n",
 - "\n",
 - "Every activation inside the model begins with a batch dimension. Here, because we only entered a single prompt, the batch dimension is always length 1 and kinda annoying, so passing in the `remove_batch_dim=True` keyword removes it. `gpt2_cache_no_batch_dim = gpt2_cache.remove_batch_dim()` would have achieved the same effect.\n",
 - ""
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 15,
 - "metadata": {
 - "id": "U7W7HG7Cchif",
 - "outputId": "162203bb-762e-498c-deb1-081830e7f65e"
 - },
 - "outputs": [
 - {
 - "name": "stderr",
 - "output_type": "stream",
 - "text": [
 - "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
 - ]
 - }
 - ],
 - "source": [
 - "gpt2_text = \"Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets.\"\n",
 - "\n",
 - "with model.generate(max_new_tokens=1, output_attentions=True) as generator:\n",
 - "\n",
 - " with generator.invoke(gpt2_text, output_attentions=True) as invoker:\n",
 - "\n",
 - " gpt2_cache = invoker.save()\n",
 - "\n",
 - " gpt2_attn = model.transformer.h[0].attn.output[2][0].save()\n",
 - "\n",
 - " gpt2_str_tokens = invoker.tokens\n"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "5xLrULRmchif"
 - },
 - "source": [
 - "Let's visualize the attention pattern of all the heads in layer 0, using [Alan Cooney's CircuitsVis library](https://github.com/alan-cooney/CircuitsVis) (based on [Anthropic's PySvelte library](https://github.com/anthropics/PySvelte)).\n",
 - "\n",
 - "We look at the attention pattern in `gpt2_cache`, an `ActivationCache` object, by entering in the name of the activation, followed by the layer index (here, the activation is called \"attn\" and the layer index is 0). This has shape [head_index, destination_position, source_position], and we use the `model.to_str_tokens` method to convert the text to a list of tokens as strings, since there is an attention weight between each pair of tokens.\n",
 - "\n",
 - "This visualization is interactive! Try hovering over a token or head, and click to lock. The grid on the top left and for each head is the attention pattern as a destination position by source position grid. It's lower triangular because GPT-2 has **causal attention**, attention can only look backwards, so information can only move forwards in the network.\n",
 - "\n",
 - "See the ActivationCache section for more on what `gpt2_cache` can do."
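 - "\n",
 - "As a quick sanity check of those shape and causality claims, here's a minimal sketch (using the `gpt2_attn` proxy saved in the cell above; the variable names are just for illustration):\n",
 - "\n",
 - "```python\n",
 - "# Sketch: gpt2_attn.value has shape [head_index, dest_position, source_position].\n",
 - "pattern = gpt2_attn.value\n",
 - "# Causal attention can only look backwards, so everything above the diagonal should be ~0.\n",
 - "upper = torch.triu(torch.ones_like(pattern), diagonal=1)\n",
 - "print(\"Max attention above the diagonal:\", (pattern * upper).max().item())\n",
 - "# Each destination row is a softmax over source positions, so it should sum to ~1.\n",
 - "print(\"Rows sum to 1:\", torch.allclose(pattern.sum(-1), torch.ones_like(pattern.sum(-1))))\n",
 - "```"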
- ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": { - "id": "NEI_wjbrchif", - "outputId": "f773e3a4-8f55-41fa-8d2b-a97e1434586d" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([12, 32, 32])\n" - ] - } - ], - "source": [ - "attention_pattern = gpt2_cache['transformer.h.0.attn'].value[2][0]\n", - "print(attention_pattern.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Or" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([12, 32, 32])\n" - ] - } - ], - "source": [ - "attention_pattern = gpt2_attn.value\n", - "print(attention_pattern.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": { - "id": "iIfy03t8chif", - "outputId": "7400cbf1-3741-4a44-e097-29b53e65e5fe" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Layer 0 Head Attention Patterns:\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "print(\"Layer 0 Head Attention Patterns:\")\n", - "cv.attention.attention_patterns(tokens=gpt2_str_tokens, attention=attention_pattern)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "T917LM02chig" - }, - "source": [ - "## Hooks: Intervening on Activations" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kiLyitPYchig" - }, - "source": [ - "One of the great things about interpreting neural networks is that we have *full control* over our system. From a computational perspective, we know exactly what operations are going on inside (even if we don't know what they mean!). And we can make precise, surgical edits and see how the model's behaviour and other internals change. This is an extremely powerful tool, because it can let us eg set up careful counterfactuals and causal intervention to easily understand model behaviour.\n", - "\n", - "Accordingly, being able to do this is a pretty core operation, and this is one of the main things TransformerLens supports! The key feature here is **hook points**. Every activation inside the transformer is surrounded by a hook point, which allows us to edit or intervene on it.\n", - "\n", - "We do this by adding a **hook function** to that activation. The hook function maps `current_activation_value, hook_point` to `new_activation_value`. As the model is run, it computes that activation as normal, and then the hook function is applied to compute a replacement, and that is substituted in for the activation. The hook function can be an arbitrary Python function, so long as it returns a tensor of the correct shape.\n", - "\n", - "
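\n",
 - "As a rough sketch, a hypothetical hook function for the attention value activation (shape `[batch, position, head_index, d_head]`) that zeroes head 7 might look like this - note that this notebook gets the same effect by assigning to the engine proxies instead:\n",
 - "\n",
 - "```python\n",
 - "# Sketch of a hook function with the (activation, hook) signature described above.\n",
 - "def zero_head_7_hook(value, hook):\n",
 - "    # value: [batch, position, head_index, d_head]\n",
 - "    value[:, :, 7, :] = 0.0\n",
 - "    # Returning the (edited) tensor substitutes it back into the forward pass.\n",
 - "    return value\n",
 - "```\n",
 - "\n",
 - "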
Relationship to PyTorch hooks\n", - "\n", - "[PyTorch hooks](https://blog.paperspace.com/pytorch-hooks-gradient-clipping-debugging/) are a great and underrated, yet incredibly janky, feature. They can act on a layer, and edit the input or output of that layer, or the gradient when applying autodiff. The key difference is that **Hook points** act on *activations* not layers. This means that you can intervene within a layer on each activation, and don't need to care about the precise layer structure of the transformer. And it's immediately clear exactly how the hook's effect is applied. This adjustment was shamelessly inspired by [Garcon's use of ProbePoints](https://transformer-circuits.pub/2021/garcon/index.html).\n", - "\n", - "They also come with a range of other quality of life improvements, like the model having a `model.reset_hooks()` method to remove all hooks, or helper methods to temporarily add hooks for a single forward pass - it is *incredibly* easy to shoot yourself in the foot with standard PyTorch hooks!\n", - "
" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Niu5-Pnkchig" - }, - "source": [ - "As a basic example, let's [ablate](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=fh-HJyz1CgUVrXuoiban6bYx) head 7 in layer 0 on the text above.\n", - "\n", - "We define a `head_ablation_hook` function. This takes the value tensor for attention layer 0, and sets the component with `head_index==7` to zero and returns it (Note - we return by convention, but since we're editing the activation in-place, we don't strictly *need* to).\n", - "\n", - "We then use the `run_with_hooks` helper function to run the model and *temporarily* add in the hook for just this run. We enter in the hook as a tuple of the activation name (also the hook point name - found with `utils.get_act_name`) and the hook function." - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": { - "id": "cH83nFR5chig", - "outputId": "755034b6-d29e-4a40-bf4f-b229b288b4fc" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Original Loss: 4.019\n", - "Ablated Loss: 4.291\n" - ] - } - ], - "source": [ - "layer_to_ablate = 0\n", - "head_index_to_ablate = 7\n", - "\n", - "\n", - "with model.generate(device_map='cuda:0', max_new_tokens=1) as generator:\n", - "\n", - " with generator.invoke(gpt2_text) as invoker:\n", - "\n", - " token_ids = invoker.ids\n", - "\n", - " clean_logits = model.lm_head.output.save()\n", - "\n", - " with generator.invoke(gpt2_text) as invoker:\n", - "\n", - " attn_module = model.transformer.h[layer_to_ablate].attn\n", - "\n", - " split_size = attn_module.split_size\n", - " head_dim = attn_module.head_dim\n", - "\n", - " qkv = attn_module.c_attn.output.split(split_size, dim=2)\n", - " v = qkv[2]\n", - " v_heads = v.split(head_dim, dim=2)\n", - " v_heads[head_index_to_ablate] = 0\n", - "\n", - " ablt_logits = model.lm_head.output.save()\n", - "\n", - "clean_logits = clean_logits.value\n", - "ablt_logits = ablt_logits.value\n", - "\n", - "clean_loss = cross_entropy_loss(clean_logits, token_ids, shift=True)\n", - "ablt_loss = cross_entropy_loss(ablt_logits, token_ids, shift=True)\n", - "\n", - "print(f\"Original Loss: {clean_loss.item():.3f}\")\n", - "print(f\"Ablated Loss: {ablt_loss.item():.3f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Bz1S4rtVchig" - }, - "source": [ - "**Gotcha:** Hooks are global state - they're added in as part of the model, and stay there until removed. `run_with_hooks` tries to create an abstraction where these are local state, by removing all hooks at the end of the function. But you can easily shoot yourself in the foot if there's, eg, an error in one of your hooks so the function never finishes. If you start getting bugs, try `model.reset_hooks()` to clean things up. 
Further, if you *do* add hooks of your own that you want to keep, you can do so with `add_perma_hook` on the relevant HookPoint."
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "lX0W0Tbgchig"
 - },
 - "source": [
 - "### Activation Patching on the Indirect Object Identification Task"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "8CWZ2RPcchig"
 - },
 - "source": [
 - "For a somewhat more involved example, let's use hooks to apply **[activation patching](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=qeWBvs-R-taFfcCq-S_hgMqx)** on the **[Indirect Object Identification](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=iWsV3s5Kdd2ca3zNgXr5UPHa)** (IOI) task.\n",
 - "\n",
 - "The IOI task is the task of identifying that a sentence like \"After John and Mary went to the store, Mary gave a bottle of milk to\" continues with \" John\" rather than \" Mary\" (ie, finding the indirect object), and Redwood Research have [an excellent paper studying the underlying circuit in GPT-2 Small](https://arxiv.org/abs/2211.00593).\n",
 - "\n",
 - "**[Activation patching](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=qeWBvs-R-taFfcCq-S_hgMqx)** is a technique from [Kevin Meng and David Bau's excellent ROME paper](https://rome.baulab.info/). The goal is to identify which model activations are important for completing a task. We do this by setting up a **clean prompt** and a **corrupted prompt** and a **metric** for performance on the task. We then pick a specific model activation, run the model on the corrupted prompt, but then *intervene* on that activation, patching in the value it took when the model was run on the clean prompt. We then apply the metric, and see how much this patch has recovered the clean performance.\n",
 - "(See [a more detailed demonstration of activation patching here](https://colab.research.google.com/github.com/neelnanda-io/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb))"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "-Wmn3EXkchig"
 - },
 - "source": [
 - "Here, our clean prompt is \"After John and Mary went to the store, **Mary** gave a bottle of milk to\", our corrupted prompt is \"After John and Mary went to the store, **John** gave a bottle of milk to\", and our metric is the difference between the correct logit ( John) and the incorrect logit ( Mary) on the final token.\n",
 - "\n",
 - "We see that the logit difference is significantly positive on the clean prompt, and significantly negative on the corrupted prompt, showing that the model is capable of doing the task!"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "ilJVVnVschih"
 - },
 - "source": [
 - "We now set up the hook function to do activation patching. Here, we'll patch in the [residual stream](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=DHp9vZ0h9lA9OCrzG2Y3rrzH) at the start of a specific layer and at a specific position. This will let us see how much the model is using the residual stream at that layer and position to represent the key information for the task.\n",
 - "\n",
 - "We want to iterate over all layers and positions, so we write the hook to take in a position parameter. 
Hook functions must have the input signature (activation, hook), but we can use `functools.partial` to set the position parameter before passing it to `run_with_hooks`" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": { - "colab": { - "referenced_widgets": [ - "b93d6b6e8c37495f84b7a00f2caf81c3" - ] - }, - "id": "slK2OwJ1chih", - "outputId": "df1ca47b-8629-4d94-ed73-b8c5a88b48c6" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Clean logit difference: 4.124\n", - "Corrupted logit difference: -2.272\n" - ] - } - ], - "source": [ - "clean_prompt = \"After John and Mary went to the store, Mary gave a bottle of milk to\"\n", - "corrupted_prompt = \"After John and Mary went to the store, John gave a bottle of milk to\"\n", - "\n", - "correct_index = model.tokenizer(\" John\")['input_ids'][0]\n", - "incorrect_index = model.tokenizer(\" Mary\")['input_ids'][0]\n", - "\n", - "with model.generate(max_new_tokens=1) as generator:\n", - "\n", - " with generator.invoke(clean_prompt) as invoker:\n", - "\n", - " clean_tokens = invoker.tokens\n", - "\n", - " clean_hs = [model.transformer.h[layer_idx].output[0] for layer_idx in range(len(model.transformer.h))]\n", - "\n", - " clean_logits = model.lm_head.output\n", - "\n", - " clean_logit_diff = (clean_logits[0, -1, correct_index] - clean_logits[0, -1, incorrect_index]).save()\n", - "\n", - " with generator.invoke(corrupted_prompt) as invoker:\n", - "\n", - " corrupted_tokens = invoker.tokens\n", - "\n", - " corrupted_logits = model.lm_head.output\n", - "\n", - " corrupted_logit_diff = (corrupted_logits[0, -1, correct_index] - corrupted_logits[0, -1, incorrect_index]).save()\n", - "\n", - " ioi_patching_results = []\n", - "\n", - " for layer_idx in range(len(model.transformer.h)):\n", - "\n", - " _ioi_patching_results = []\n", - "\n", - " for token_idx in range(len(clean_tokens)):\n", - "\n", - " with generator.invoke(corrupted_prompt) as invoker:\n", - "\n", - " model.transformer.h[layer_idx].output[0].t[token_idx] = clean_hs[layer_idx].t[token_idx]\n", - "\n", - " patched_logits = model.lm_head.output\n", - "\n", - " patched_logit_diff = patched_logits[0, -1, correct_index] - patched_logits[0, -1, incorrect_index]\n", - "\n", - " patched_result = (patched_logit_diff - corrupted_logit_diff)/(clean_logit_diff - corrupted_logit_diff)\n", - "\n", - " _ioi_patching_results.append(patched_result.save())\n", - "\n", - " ioi_patching_results.append(_ioi_patching_results)\n", - "\n", - "print(f\"Clean logit difference: {clean_logit_diff.value:.3f}\")\n", - "print(f\"Corrupted logit difference: {corrupted_logit_diff.value:.3f}\")\n", - "\n", - "from engine.fx.Proxy import Proxy\n", - "from engine import util\n", - "\n", - "ioi_patching_results = util.apply(ioi_patching_results, lambda x : x.value.item(), Proxy)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uPCoNwrschih" - }, - "source": [ - "We can now visualize the results, and see that this computation is extremely localised within the model. 
Initially, the second subject (Mary) token is all that matters (naturally, as it's the only different token), and all relevant information remains here until heads in layer 7 and 8 move this to the final token where it's used to predict the indirect object.\n", - "(Note - the heads are in layer 7 and 8, not 8 and 9, because we patched in the residual stream at the *start* of each layer)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": { - "id": "Yp5kK92vchih", - "outputId": "491359e7-c3a0-48f7-c691-6dcdf02628a1" - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "%matplotlib inline\n", - "# Add the index to the end of the label, because plotly doesn't like duplicate labels\n", - "token_labels = [f\"{token}_{index}\" for index, token in enumerate(clean_tokens)]\n", - "imshow(ioi_patching_results, x=token_labels, xaxis=\"Position\", yaxis=\"Layer\", title=\"Normalized Logit Difference After Patching Residual Stream on the IOI Task\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "r0eQFuycchih" - }, - "source": [ - "## Hooks: Accessing Activations" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "T_8-Bq3qchil" - }, - "source": [ - "Hooks can also be used to just **access** an activation - to run some function using that activation value, *without* changing the activation value. This can be achieved by just having the hook return nothing, and not editing the activation in place.\n", - "\n", - "This is useful for eg extracting activations for a specific task, or for doing some long-running calculation across many inputs, eg finding the text that most activates a specific neuron. (Note - everything this can do *could* be done with `run_with_cache` and post-processing, but this workflow can be more intuitive and memory efficient.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZHQLnzHichil" - }, - "source": [ - "To demonstrate this, let's look for **[induction heads](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html)** in GPT-2 Small.\n", - "\n", - "Induction circuits are a very important circuit in generative language models, which are used to detect and continue repeated subsequences. They consist of two heads in separate layers that compose together, a **previous token head** which always attends to the previous token, and an **induction head** which attends to the token *after* an earlier copy of the current token.\n", - "\n", - "To see why this is important, let's say that the model is trying to predict the next token in a news article about Michael Jordan. The token \" Michael\", in general, could be followed by many surnames. But an induction head will look from that occurence of \" Michael\" to the token after previous occurences of \" Michael\", ie \" Jordan\" and can confidently predict that that will come next." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_9TxWCtwchil" - }, - "source": [ - "An interesting fact about induction heads is that they generalise to arbitrary sequences of repeated tokens. We can see this by generating sequences of 50 random tokens, repeated twice, and plotting the average loss at predicting the next token, by position. We see that the model goes from terrible to very good at the halfway point." 
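 - "\n",
 - "To make the setup concrete, here's a minimal sketch of what 'repeated twice' means, using plain `torch.cat` (the cell below builds the real batch with `einops.repeat`, which does the same thing; `toy` is just an illustrative name):\n",
 - "\n",
 - "```python\n",
 - "# Sketch: the second half of the sequence is an exact copy of the first half,\n",
 - "# so every token in it already appeared exactly seq_len tokens earlier.\n",
 - "toy = torch.randint(1000, 10000, (1, 5))\n",
 - "toy_repeated = torch.cat([toy, toy], dim=-1)  # shape [1, 10]\n",
 - "print(toy)\n",
 - "print(toy_repeated)\n",
 - "```"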
- ] - }, - { - "cell_type": "code", - "execution_count": 63, - "metadata": { - "id": "QMAYrNd3chil", - "outputId": "2ea6db74-d983-4072-ba97-d1c25f07b454" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[' seeking', ' restaur', 'ixture', ' issued', 'ラ', 'ik', ' NHL', 'ifying', ' battle', 'ogn', ' Bell', ' ideal', ' nature', 'lla', ' cance', 'res', 'sel', ' incorpor', ' carbon', 'roid', ' auto', ' European', ' NHL', ' arch', ' belief', 'gery', ' substantial', ' key', ' Tok', 'bec', 'esh', ' Association', \"'ll\", 'iz', 'ere', 'fully', ' miles', ' Den', ' Ear', 'ー', 'used', 'EP', 'aste', ' assistant', ' drug', 'ornia', ' behalf', ' explore', ' talent', 'uild', ' est', ' Low', ' seeking', ' restaur', 'ixture', ' issued', 'ラ', 'ik', ' NHL', 'ifying', ' battle', 'ogn', ' Bell', ' ideal', ' nature', 'lla', ' cance', 'res', 'sel', ' incorpor', ' carbon', 'roid', ' auto', ' European', ' NHL', ' arch', ' belief', 'gery', ' substantial', ' key', ' Tok', 'bec', 'esh', ' Association', \"'ll\", 'iz', 'ere', 'fully', ' miles', ' Den', ' Ear', 'ー', 'used', 'EP', 'aste', ' assistant', ' drug', 'ornia', ' behalf', ' explore', ' talent', 'uild', ' est', ' Low'], ['<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>', ' complex', 'store', 'yright', ' operate', ' Captain', 'hest', ' correctly', ' Queen', ' enforce', ' gang', 'phone', ' Asian', ' Israel', ' ', '<|endoftext|>', ' Key', ' Daily', ' Pal', ' surprised', ' below', ' recording', ' camer', ' discussion', ' often', ' shall', 'iled', 'roud', ' Bry', 'osen', ' adult', ' •', ' investigate', ' external', 'ournament', '()', ' install', 'izations', 'aving', ' intent', ' Cop', 'hen', 'sell', 'ig', ' intern', ' colors', ' Moore', ' contain', ' added', 'iled', ' topic', ' kind', ' blog', ' sequ', ' Protection', ' comprom', ' lic', 'aze', ' contem', ' relevant', ' expectations', 'uthor', ' alleged', ' ap', ' try', 'iah', ' producer', ' Key', ' Daily', ' Pal', ' surprised', ' below', ' recording', ' camer', ' discussion', ' often', ' shall', 'iled', 'roud', ' Bry', 'osen', ' adult', ' •', ' investigate', ' external', 'ournament', '()', ' install', 'izations', 'aving', ' intent', ' Cop', 'hen', 'sell', 'ig', ' intern', ' colors', ' Moore', ' contain', ' added', 'iled', ' topic', ' kind', ' blog', ' sequ', ' Protection', ' comprom', ' lic', 'aze', ' contem', ' relevant', ' expectations', 'uthor', ' alleged', ' ap', ' try', 'iah', ' producer'], ['<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>', ' affairs', 'Pr', 'AM', ' Saf', 'iding', 'uable', 'irty', 'head', 'uto', ' obst', ' swe', ' factors', 'ORT', 'aint', ' tried', ' marketing', ' feedback', ' premium', ' Code', ' depression', ' kept', ' ability', ' inject', ' grade', ' Max', 'Then', ' honor', ' samples', ' computers', ' Hel', ' figure', ' intelligence', ' preced', ' although', ' repeat', ' sign', 'ked', ' month', ' surgery', 'view', ' therefore', ' campaigns', ' sign', ' redu', ' thin', ' Because', ' wages', 'works', 'Qu', ' enormous', ' affairs', 'Pr', 'AM', ' Saf', 'iding', 'uable', 'irty', 'head', 'uto', ' obst', ' swe', ' factors', 'ORT', 'aint', ' tried', ' marketing', ' feedback', ' premium', ' Code', ' depression', ' kept', ' ability', ' inject', ' grade', ' Max', 'Then', ' honor', ' samples', ' computers', ' Hel', ' figure', ' intelligence', ' preced', ' although', ' repeat', ' sign', 
'ked', ' month', ' surgery', 'view', ' therefore', ' campaigns', ' sign', ' redu', ' thin', ' Because', ' wages', 'works', 'Qu', ' enormous'], ['ox', 'gress', ' environment', 'als', 'hip', ' silver', ' Sol', ' Saint', ' categories', 'AR', ' que', ' text', ' convention', ' poverty', ' possession', ' plants', ' job', ' draft', ' Hon', ' House', ' mir', ' joke', ' designed', 'vis', 'ugg', ' First', ' cris', 'iss', 'he', ' resource', 'hab', ' Ret', 'IR', ' estim', ' Fant', ' behavi', ' invest', 'Read', 'While', ' century', ' speed', ' hell', ' instruct', ' discount', 'allas', ' campaign', ' elig', ' family', ' broke', ' setup', 'oral', ' posted', 'ox', 'gress', ' environment', 'als', 'hip', ' silver', ' Sol', ' Saint', ' categories', 'AR', ' que', ' text', ' convention', ' poverty', ' possession', ' plants', ' job', ' draft', ' Hon', ' House', ' mir', ' joke', ' designed', 'vis', 'ugg', ' First', ' cris', 'iss', 'he', ' resource', 'hab', ' Ret', 'IR', ' estim', ' Fant', ' behavi', ' invest', 'Read', 'While', ' century', ' speed', ' hell', ' instruct', ' discount', 'allas', ' campaign', ' elig', ' family', ' broke', ' setup', 'oral', ' posted'], ['<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>', ' listed', 'pired', ' gotten', 'eping', ' minds', 'gery', 'emic', 'style', ' MP', ' marks', ' Bush', ' Haw', ' exposed', 'secut', ' containing', ' childhood', ' slightly', ' search', 'pg', ' four', ' Women', ' cause', ' After', ' regulation', ' institution', 'iles', 'ability', ' Taylor', ' recently', 'angu', ' Jews', ' mail', ' select', 'mate', ' threats', 'agger', ' Organ', ' senior', 'action', ' field', '00000000', 'rep', ' average', 'athan', ' Pr', 'More', ' term', ' und', ' develop', ' referring', ' listed', 'pired', ' gotten', 'eping', ' minds', 'gery', 'emic', 'style', ' MP', ' marks', ' Bush', ' Haw', ' exposed', 'secut', ' containing', ' childhood', ' slightly', ' search', 'pg', ' four', ' Women', ' cause', ' After', ' regulation', ' institution', 'iles', 'ability', ' Taylor', ' recently', 'angu', ' Jews', ' mail', ' select', 'mate', ' threats', 'agger', ' Organ', ' senior', 'action', ' field', '00000000', 'rep', ' average', 'athan', ' Pr', 'More', ' term', ' und', ' develop', ' referring'], ['<|endoftext|>', '<|endoftext|>', 'ownt', ' Jersey', ' paint', ' So', ' scored', ' Ty', 'itten', ' writing', ' Since', ' talk', ' 43', 'encies', ' hyp', ' chicken', ' Albert', ' advis', ' memory', ' 43', ' OK', ' GM', ' signific', 'gu', ' Football', ' butt', ' Daily', 'isions', ' Cath', ' sequ', 'lege', 'elve', ' armed', ' Peter', ' happens', ' earlier', 'iki', 'urd', ' struggle', ' focus', 'df', 'ades', 'ch', 'ol', 'ah', 'List', ' technologies', ' De', ' Kh', ' Const', ' institution', ' estim', 'Still', 'ownt', ' Jersey', ' paint', ' So', ' scored', ' Ty', 'itten', ' writing', ' Since', ' talk', ' 43', 'encies', ' hyp', ' chicken', ' Albert', ' advis', ' memory', ' 43', ' OK', ' GM', ' signific', 'gu', ' Football', ' butt', ' Daily', 'isions', ' Cath', ' sequ', 'lege', 'elve', ' armed', ' Peter', ' happens', ' earlier', 'iki', 'urd', ' struggle', ' focus', 'df', 'ades', 'ch', 'ol', 'ah', 'List', ' technologies', ' De', ' Kh', ' Const', ' institution', ' estim', 'Still'], ['<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>', ' Hold', ' wid', ' ages', ' members', 'ane', ' alongside', ' tow', ' Bus', ' regime', ' bare', ' view', 'ature', '66', ' Roy', ' Spring', 'Set', 'ston', ' milk', '98', ' population', 'bar', 'olve', ' >', ' tim', ' sees', ' shel', ' surveillance', ' 61', 
'engers', 'orts', 'ither', ' stats', ' Report', ' Stre', ' brill', '34', ' Lou', ' nic', ' Low', '23', 'Gu', ' mult', 'This', ' stock', ' General', ' maintenance', ' Car', '05', ' weapons', 'mm', ' Hold', ' wid', ' ages', ' members', 'ane', ' alongside', ' tow', ' Bus', ' regime', ' bare', ' view', 'ature', '66', ' Roy', ' Spring', 'Set', 'ston', ' milk', '98', ' population', 'bar', 'olve', ' >', ' tim', ' sees', ' shel', ' surveillance', ' 61', 'engers', 'orts', 'ither', ' stats', ' Report', ' Stre', ' brill', '34', ' Lou', ' nic', ' Low', '23', 'Gu', ' mult', 'This', ' stock', ' General', ' maintenance', ' Car', '05', ' weapons', 'mm'], ['<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>', ' neither', ' duty', ' Louis', ' city', ' reported', 'enger', ' came', ' Bas', ' x', 'osh', 'teen', ' acqu', ' il', 'ga', 'aste', ' sing', ' context', ' spin', ' spec', ' eligible', ' element', 'otic', ' Nation', ' Fore', ' explained', 'ologists', '########', ' gle', 'b', ' pace', ' subsection', ' Henry', 'anged', 'arrass', ' city', 'eder', ' danger', ' ;', ' Christ', 'owing', ' Des', ' ;', ' Law', 'aska', ' lines', '29', 'met', ' cook', 'EO', ' support', ' neither', ' duty', ' Louis', ' city', ' reported', 'enger', ' came', ' Bas', ' x', 'osh', 'teen', ' acqu', ' il', 'ga', 'aste', ' sing', ' context', ' spin', ' spec', ' eligible', ' element', 'otic', ' Nation', ' Fore', ' explained', 'ologists', '########', ' gle', 'b', ' pace', ' subsection', ' Henry', 'anged', 'arrass', ' city', 'eder', ' danger', ' ;', ' Christ', 'owing', ' Des', ' ;', ' Law', 'aska', ' lines', '29', 'met', ' cook', 'EO', ' support'], ['<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>', ' However', 'anny', 'esc', '67', ' init', ' satisf', 'Like', ' driving', ' seek', ' Awoken', ' Still', ' Town', ' reverse', ' vehicles', 'list', ' estimates', ' accord', ' extra', ' adds', ' parts', ' maintain', ' uncertain', ' fired', ' Bush', ' Israeli', ' arrived', ' pretty', ' Offic', ' narrow', ' overwhelming', ' Hop', ' tiss', 'pected', ' shooting', ' Med', 'organ', 'Level', 'inc', 'Pol', 'bre', ' spect', ' tweet', 'ords', 'oral', ' wake', ' employed', 'CC', ' further', ' legislation', ' regul', ' However', 'anny', 'esc', '67', ' init', ' satisf', 'Like', ' driving', ' seek', ' Awoken', ' Still', ' Town', ' reverse', ' vehicles', 'list', ' estimates', ' accord', ' extra', ' adds', ' parts', ' maintain', ' uncertain', ' fired', ' Bush', ' Israeli', ' arrived', ' pretty', ' Offic', ' narrow', ' overwhelming', ' Hop', ' tiss', 'pected', ' shooting', ' Med', 'organ', 'Level', 'inc', 'Pol', 'bre', ' spect', ' tweet', 'ords', 'oral', ' wake', ' employed', 'CC', ' further', ' legislation', ' regul']]\n" - ] - } - ], - "source": [ - "import einops\n", - "\n", - "batch_size = 10\n", - "seq_len = 50\n", - "random_tokens = torch.randint(1000, 10000, (batch_size, seq_len))\n", - "repeated_tokens = einops.repeat(random_tokens, \"batch seq_len -> batch (2 seq_len)\")\n", - "\n", - "with model.generate(device_map='cuda:0', max_new_tokens=1) as generator:\n", - "\n", - " with generator.invoke(repeated_tokens) as invoker:\n", - "\n", - " token_ids = invoker.ids\n", - "\n", - " logits = model.lm_head.output.save()\n", - "\n", - "loss = cross_entropy_loss(logits.value, token_ids, shift=True, avg_token=False)\n", - "\n", - "line(loss, xaxis=\"Position\", yaxis=\"Loss\", title=\"Loss by position on random repeated tokens\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QC2rve4Pchil" - }, - "source": [ - "The 
induction heads will be attending from the second occurrence of each token to the token *after* its first occurrence, ie the token `50-1==49` places back. So by looking at the average attention paid 49 tokens back, we can identify induction heads! Let's define a hook to do this!\n",
 - "\n",
 - "
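\n",
 - "The cell after the technical details pulls out exactly this 'look `seq_len - 1` back' stripe with `Tensor.diagonal(dim1=-2, dim2=-1, offset=1 - seq_len)`. As a toy sketch of why that offset is the right one (for `seq_len = 50`; `toy` and `stripe` are just illustrative names):\n",
 - "\n",
 - "```python\n",
 - "# Sketch: for a [dest_position, source_position] matrix, offset = -(50 - 1) picks the\n",
 - "# entries [d, d - 49], i.e. the attention paid to the token 49 positions back.\n",
 - "toy = torch.arange(100 * 100).reshape(100, 100)\n",
 - "stripe = toy.diagonal(dim1=-2, dim2=-1, offset=1 - 50)\n",
 - "print(stripe.shape)             # 51 destination positions can look 49 back\n",
 - "print(stripe[0] == toy[49, 0])  # dest position 49 attending to source position 0\n",
 - "```\n",
 - "\n",
 - "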
Technical details\n", - "\n", - "* We attach the hook to the attention pattern activation. There's one big pattern activation per layer, stacked across all heads, so we need to do some tensor manipulation to get a per-head score.\n", - "* Hook functions can access global state, so we make a big tensor to store the induction head score for each head, and then we just add the score for each head to the appropriate position in the tensor.\n", - "* To get a single hook function that works for each layer, we use the `hook.layer()` method to get the layer index (internally this is just inferred from the hook names).\n", - "* As we want to add this to *every* activation pattern hook point, rather than giving the string for an activation name, this time we give a **name filter**. This is a Boolean function on hook point names, and it adds the hook function to every hook point where the function evaluates as true.\n", - " * `run_with_hooks` allows us to enter a list of (act_name, hook_function) pairs to all be added at once, so we could also have done this by inputting a list with a hook for each layer.\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": 78, - "metadata": { - "id": "sR3Vyu_ychil", - "outputId": "5fe99985-3ea3-4cd4-db7a-44688122db80" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([12, 12])\n" - ] - } - ], - "source": [ - "\n", - "with model.generate(max_new_tokens=1, output_attentions=True) as generator:\n", - "\n", - " with generator.invoke(repeated_tokens, output_attentions=True) as invoker:\n", - "\n", - " attn_hidden_states = [model.transformer.h[layer_idx].attn.output[2].save() for layer_idx in range(len(model.transformer.h))]\n", - "\n", - "attn_hidden_states = torch.stack([hs.value for hs in attn_hidden_states]).diagonal(dim1=-2, dim2=-1, offset=1-seq_len)\n", - "induction_score = attn_hidden_states.mean(1).mean(-1).cpu()\n", - "\n", - "imshow(induction_score, xaxis=\"Head\", yaxis=\"Layer\", title=\"Induction Score by Head\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gRn5MlWKchil" - }, - "source": [ - "Head 5 in Layer 5 scores extremely highly on this score, and we can feed in a shorter repeated random sequence, visualize the attention pattern for it and see this directly - including the \"induction stripe\" at `seq_len-1` tokens back.\n", - "\n", - "This time we put in a hook on the attention pattern activation to visualize the pattern of the relevant head." - ] - }, - { - "cell_type": "code", - "execution_count": 81, - "metadata": { - "id": "AE-Pozddchim", - "outputId": "f61aacda-2cf4-4ced-a7d5-ca6de8cd1724" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 81, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "induction_head_layer = 5\n", - "induction_head_index = 5\n", - "single_random_sequence = torch.randint(1000, 10000, (1, 20))\n", - "repeated_random_sequence = einops.repeat(single_random_sequence, \"batch seq_len -> batch (2 seq_len)\")\n", - "\n", - "with model.generate(device_map='cuda:0', max_new_tokens=1, output_attentions=True) as generator:\n", - "\n", - " with generator.invoke(repeated_random_sequence, output_attentions=True) as invoker:\n", - "\n", - " attn_hidden_states = model.transformer.h[induction_head_layer].attn.output[2][:, induction_head_index].save()\n", - "\n", - "cv.attention.attention_patterns(\n", - " tokens=invoker.tokens, \n", - " attention=attn_hidden_states.value\n", - " )\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ppCsXi6echim" - }, - "source": [ - "## Available Models" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tO5xySTFchim" - }, - "source": [ - "TransformerLens comes with over 40 open source models available, all of which can be loaded into a consistent(-ish) architecture by just changing the name in `from_pretrained`. The open source models available are [documented here](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=jHj79Pj58cgJKdq4t-ygK-4h), and a set of interpretability friendly models I've trained are [documented here](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=NCJ6zH_Okw_mUYAwGnMKsj2m), including a set of toy language models (tiny one to four layer models) and a set of [SoLU models](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=FZ5W6GGcy6OitPEaO733JLqf) up to GPT-2 Medium size (300M parameters). You can see [a table of the official alias and hyper-parameters of available models here](https://github.com/neelnanda-io/TransformerLens/blob/main/transformer_lens/model_properties_table.md).\n", - "\n", - "**Note:** TransformerLens does not currently support multi-GPU models (which you want for models above eg 7B parameters), but this feature is coming soon!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mcVVpRnechim" - }, - "source": [ - "\n", - "Notably, this means that analysis can be near immediately re-run on a different model by just changing the name - to see this, let's load in DistilGPT-2 (a distilled version of GPT-2, with half as many layers) and copy the code from above to see the induction heads in that model." 
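 - "\n",
 - "One way to make that 'just change the name' point concrete is to wrap the induction-score computation from the cells above in a small helper - a sketch, not a library function:\n",
 - "\n",
 - "```python\n",
 - "# Sketch: the induction-score computation above, wrapped so it takes the model as an argument.\n",
 - "def induction_scores(lm, tokens, seq_len):\n",
 - "    with lm.generate(max_new_tokens=1, output_attentions=True) as generator:\n",
 - "        with generator.invoke(tokens, output_attentions=True) as invoker:\n",
 - "            saved = [lm.transformer.h[i].attn.output[2].save() for i in range(len(lm.transformer.h))]\n",
 - "    stacked = torch.stack([s.value for s in saved]).diagonal(dim1=-2, dim2=-1, offset=1 - seq_len)\n",
 - "    return stacked.mean(1).mean(-1).cpu()\n",
 - "\n",
 - "# Once distilgpt2 is loaded in the next cell, the analysis is one call per model:\n",
 - "# imshow(induction_scores(distilgpt2, repeated_tokens, seq_len), xaxis=\"Head\", yaxis=\"Layer\")\n",
 - "```"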
- ] - }, - { - "cell_type": "code", - "execution_count": 86, - "metadata": { - "id": "VOAKGrCkchim", - "outputId": "c93cef32-0d8a-47d7-859d-d5b5e4b1a022" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Downloading model.safetensors: 100%|██████████| 353M/353M [00:04<00:00, 72.3MB/s] \n", - "Downloading (…)neration_config.json: 100%|██████████| 124/124 [00:00<00:00, 564kB/s]\n", - "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" - ] - } - ], - "source": [ - "distilgpt2 = engine.LanguageModel('distilgpt2', device_map='cuda:0')\n", - "\n", - "with distilgpt2.generate(max_new_tokens=1, output_attentions=True) as generator:\n", - "\n", - " with generator.invoke(repeated_tokens, output_attentions=True) as invoker:\n", - "\n", - " attn_hidden_states = [distilgpt2.transformer.h[layer_idx].attn.output[2].save() for layer_idx in range(len(distilgpt2.transformer.h))]\n", - "\n", - "attn_hidden_states = torch.stack([hs.value for hs in attn_hidden_states]).diagonal(dim1=-2, dim2=-1, offset=1-seq_len)\n", - "induction_score = attn_hidden_states.mean(1).mean(-1).cpu()\n", - "\n", - "imshow(induction_score, xaxis=\"Head\", yaxis=\"Layer\", title=\"Induction Score by Head in Distil GPT-2\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "s17XfS7_chim" - }, - "source": [ - "\n", - "### An overview of the important open source models in the library\n", - "\n", - "* **GPT-2** - the classic generative pre-trained models from OpenAI\n", - " * Sizes Small (85M), Medium (300M), Large (700M) and XL (1.5B).\n", - " * Trained on ~22B tokens of internet text. ([Open source replication](https://huggingface.co/datasets/openwebtext))\n", - "* **GPT-Neo** - Eleuther's replication of GPT-2\n", - " * Sizes 125M, 1.3B, 2.7B\n", - " * Trained on 300B(ish?) tokens of [the Pile](https://pile.eleuther.ai/) a large and diverse dataset including a bunch of code (and weird stuff)\n", - "* **[OPT](https://ai.facebook.com/blog/democratizing-access-to-large-scale-language-models-with-opt-175b/)** - Meta AI's series of open source models\n", - " * Trained on 180B tokens of diverse text.\n", - " * 125M, 1.3B, 2.7B, 6.7B, 13B, 30B, 66B\n", - "* **GPT-J** - Eleuther's 6B parameter model, trained on the Pile\n", - "* **GPT-NeoX** - Eleuther's 20B parameter model, trained on the Pile\n", - "* **StableLM** - Stability AI's 3B and 7B models, with and without chat and instruction fine-tuning\n", - "* **Stanford CRFM models** - a replication of GPT-2 Small and GPT-2 Medium, trained on 5 different random seeds.\n", - " * Notably, 600 checkpoints were taken during training per model, and these are available in the library with eg `HookedTransformer.from_pretrained(\"stanford-gpt2-small-a\", checkpoint_index=265)`.\n", - "- **BERT** - Google's bidirectional encoder-only transformer.\n", - " - Size Base (108M), trained on English Wikipedia and BooksCorpus.\n", - "\n", - "
" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "x1yD66Uhchim" - }, - "source": [ - "\n", - "### An overview of some interpretability-friendly models I've trained and included\n", - "\n", - "(Feel free to [reach out](mailto:neelnanda27@gmail.com) if you want more details on any of these models)\n", - "\n", - "Each of these models has about ~200 checkpoints taken during training that can also be loaded from TransformerLens, with the `checkpoint_index` argument to `from_pretrained`.\n", - "\n", - "Note that all models are trained with a Beginning of Sequence token, and will likely break if given inputs without that!\n", - "\n", - "* **Toy Models**: Inspired by [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html), I've trained 12 tiny language models, of 1-4L and each of width 512. I think that interpreting these is likely to be far more tractable than larger models, and both serve as good practice and will likely contain motifs and circuits that generalise to far larger models (like induction heads):\n", - " * Attention-Only models (ie without MLPs): attn-only-1l, attn-only-2l, attn-only-3l, attn-only-4l\n", - " * GELU models (ie with MLP, and the standard GELU activations): gelu-1l, gelu-2l, gelu-3l, gelu-4l\n", - " * SoLU models (ie with MLP, and [Anthropic's SoLU activation](https://transformer-circuits.pub/2022/solu/index.html), designed to make MLP neurons more interpretable): solu-1l, solu-2l, solu-3l, solu-4l\n", - " * All models are trained on 22B tokens of data, 80% from C4 (web text) and 20% from Python Code\n", - " * Models of the same layer size were trained with the same weight initialization and data shuffle, to more directly compare the effect of different activation functions.\n", - "* **SoLU** models: A larger scan of models trained with [Anthropic's SoLU activation](https://transformer-circuits.pub/2022/solu/index.html), in the hopes that it makes the MLP neuron interpretability easier.\n", - " * A scan up to GPT-2 Medium size, trained on 30B tokens of the same data as toy models, 80% from C4 and 20% from Python code.\n", - " * solu-6l (40M), solu-8l (100M), solu-10l (200M), solu-12l (340M)\n", - " * An older scan up to GPT-2 Medium size, trained on 15B tokens of [the Pile](https://pile.eleuther.ai/)\n", - " * solu-1l-pile (13M), solu-2l-pile (13M), solu-4l-pile (13M), solu-6l-pile (40M), solu-8l-pile (100M), solu-10l-pile (200M), solu-12l-pile (340M)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hdxiWIGzchim" - }, - "source": [ - "## Other Resources:\n", - "\n", - "* [Concrete Steps to Get Started in Mechanistic Interpretability](https://neelnanda.io/getting-started): A guide I wrote for how to get involved in mechanistic interpretability, and how to learn the basic skills\n", - "* [A Comprehensive Mechanistic Interpretability Explainer](https://neelnanda.io/glossary): An overview of concepts in the field and surrounding ideas in ML and transformers, with long digressions to give context and build intuitions.\n", - "* [Concrete Open Problems in Mechanistic Interpretability](https://neelnanda.io/concrete-open-problems), a doc I wrote giving a long list of open problems in mechanistic interpretability, and thoughts on how to get started on trying to work on them.\n", - " * There's a lot of low-hanging fruit in the field, and I expect that many people reading this could use TransformerLens to usefully make progress on some of these!\n", - "* Other demos:\n", - " * **[Exploratory Analysis 
Demo](https://neelnanda.io/exploratory-analysis-demo)**, a demonstration of my standard toolkit for how to use TransformerLens to explore a mysterious behaviour in a language model.\n", - " * [Interpretability in the Wild](https://github.com/redwoodresearch/Easy-Transformer) a codebase from Arthur Conmy and Alex Variengien at Redwood research using this library to do a detailed and rigorous reverse engineering of the Indirect Object Identification circuit, to accompany their paper\n", - " * Note - this was based on an earlier version of this library, called EasyTransformer. It's pretty similar, but several breaking changes have been made since.\n", - " * A [recorded walkthrough](https://www.youtube.com/watch?v=yo4QvDn-vsU) of me doing research with TransformerLens on whether a tiny model can re-derive positional information, with [an accompanying Colab](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/No_Position_Experiment.ipynb)\n", - "* [Neuroscope](https://neuroscope.io), a website showing the text in the dataset that most activates each neuron in some selected models. Good to explore to get a sense for what kind of features the model tends to represent, and as a \"wiki\" to get some info\n", - " * A tutorial on how to make an [Interactive Neuroscope](https://github.com/neelnanda-io/TransformerLens/blob/main/Hacky-Interactive-Lexoscope.ipynb), where you type in text and see the neuron activations over the text update live." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AK5zZTdQchin" - }, - "source": [ - "## Transformer architecture\n", - "\n", - "HookedTransformer is a somewhat adapted GPT-2 architecture, but is computationally identical. The most significant changes are to the internal structure of the attention heads:\n", - "* The weights (W_K, W_Q, W_V) mapping the residual stream to queries, keys and values are 3 separate matrices, rather than big concatenated one.\n", - "* The weight matrices (W_K, W_Q, W_V, W_O) and activations (keys, queries, values, z (values mixed by attention pattern)) have separate head_index and d_head axes, rather than flattening them into one big axis.\n", - " * The activations all have shape `[batch, position, head_index, d_head]`\n", - " * W_K, W_Q, W_V have shape `[head_index, d_model, d_head]` and W_O has shape `[head_index, d_head, d_model]`\n", - "\n", - "The actual code is a bit of a mess, as there's a variety of Boolean flags to make it consistent with the various different model families in TransformerLens - to understand it and the internal structure, I instead recommend reading the code in [CleanTransformerDemo](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/clean-transformer-demo/Clean_Transformer_Demo.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tgwA_CuEchin" - }, - "source": [ - "### Parameter Names\n", - "\n", - "Here is a list of the parameters and shapes in the model. By convention, all weight matrices multiply on the right (ie `new_activation = old_activation @ weights + bias`).\n", - "\n", - "Reminder of the key hyper-params:\n", - "* `n_layers`: 12. The number of transformer blocks in the model (a block contains an attention layer and an MLP layer)\n", - "* `n_heads`: 12. The number of attention heads per attention layer\n", - "* `d_model`: 768. The residual stream width.\n", - "* `d_head`: 64. The internal dimension of an attention head activation.\n", - "* `d_mlp`: 3072. 
The internal dimension of the MLP layers (ie the number of neurons).\n",
 - "* `d_vocab`: 50257. The number of tokens in the vocabulary.\n",
 - "* `n_ctx`: 1024. The maximum number of tokens in an input prompt.\n"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "8PUN996Fchin"
 - },
 - "source": [
 - "**Transformer Block parameters:**\n",
 - "Replace 0 with the relevant layer index."
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 89,
 - "metadata": {
 - "id": "b73tqyTUchin",
 - "outputId": "756a3064-2672-44ab-bd59-701cf2d06e37"
 - },
 - "outputs": [
 - {
 - "name": "stdout",
 - "output_type": "stream",
 - "text": [
 - "transformer.h.0.ln_1.weight torch.Size([768])\n",
 - "transformer.h.0.ln_1.bias torch.Size([768])\n",
 - "transformer.h.0.attn.c_attn.weight torch.Size([768, 2304])\n",
 - "transformer.h.0.attn.c_attn.bias torch.Size([2304])\n",
 - "transformer.h.0.attn.c_proj.weight torch.Size([768, 768])\n",
 - "transformer.h.0.attn.c_proj.bias torch.Size([768])\n",
 - "transformer.h.0.ln_2.weight torch.Size([768])\n",
 - "transformer.h.0.ln_2.bias torch.Size([768])\n",
 - "transformer.h.0.mlp.c_fc.weight torch.Size([768, 3072])\n",
 - "transformer.h.0.mlp.c_fc.bias torch.Size([3072])\n",
 - "transformer.h.0.mlp.c_proj.weight torch.Size([3072, 768])\n",
 - "transformer.h.0.mlp.c_proj.bias torch.Size([768])\n"
 - ]
 - }
 - ],
 - "source": [
 - "for name, param in model.named_parameters():\n",
 - " if name.startswith(\"transformer.h.0.\"):\n",
 - " print(name, param.shape)"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "WATEJHrachin"
 - },
 - "source": [
 - "**Embedding & Unembedding parameters:**"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 91,
 - "metadata": {
 - "id": "LDq1TOYnchin",
 - "outputId": "028fe888-1da6-4d9d-e027-eab679737877"
 - },
 - "outputs": [
 - {
 - "name": "stdout",
 - "output_type": "stream",
 - "text": [
 - "transformer.wte.weight torch.Size([50257, 768])\n",
 - "transformer.wpe.weight torch.Size([1024, 768])\n",
 - "transformer.ln_f.weight torch.Size([768])\n",
 - "transformer.ln_f.bias torch.Size([768])\n"
 - ]
 - }
 - ],
 - "source": [
 - "for name, param in model.named_parameters():\n",
 - " if not name.startswith(\"transformer.h\"):\n",
 - " print(name, param.shape)"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "f8Et9Fy1chin"
 - },
 - "source": [
 - "### Activation + Hook Names\n",
 - "\n",
 - "Let's get out a list of the activation/hook names in the model and their shapes. In practice, I recommend using the `utils.get_act_name` function to get the names, but this is a useful fallback, and necessary to eg write a name filter function.\n",
 - "\n",
 - "Let's do this by entering in a short, 10 token prompt, and add a hook function to each activation to print its name and shape. To avoid spam, let's just add this to activations in the first block or not in a block.\n",
 - "\n",
 - "Note 1: Each LayerNorm has a hook for the scale factor (ie the standard deviation of the input activations for each token position & batch element) and for the normalized output (ie the input activation with mean 0 and standard deviation 1, but *before* applying scaling or translating with learned weights). LayerNorm is applied every time a layer reads from the residual stream: `ln1` is the LayerNorm before the attention layer in a block, `ln2` the one before the MLP layer, and `ln_final` is the LayerNorm before the unembed.\n",
 - "\n",
 - "Note 2: *Every* activation apart from the attention pattern and attention scores has shape beginning with `[batch, position]`. 
The attention pattern and scores have shape `[batch, head_index, dest_position, source_position]` (the numbers are the same, unless we're using caching)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ppEUngzGchin", - "outputId": "5bc1a732-9582-4adb-dbed-c6cac46338c2" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Num tokens: 1\n", - "hook_embed torch.Size([1, 10, 768])\n", - "hook_pos_embed torch.Size([1, 10, 768])\n", - "blocks.0.hook_resid_pre torch.Size([1, 10, 768])\n", - "blocks.0.ln1.hook_scale torch.Size([1, 10, 1])\n", - "blocks.0.ln1.hook_normalized torch.Size([1, 10, 768])\n", - "blocks.0.attn.hook_q torch.Size([1, 10, 12, 64])\n", - "blocks.0.attn.hook_k torch.Size([1, 10, 12, 64])\n", - "blocks.0.attn.hook_v torch.Size([1, 10, 12, 64])\n", - "blocks.0.attn.hook_attn_scores torch.Size([1, 12, 10, 10])\n", - "blocks.0.attn.hook_pattern torch.Size([1, 12, 10, 10])\n", - "blocks.0.attn.hook_z torch.Size([1, 10, 12, 64])\n", - "blocks.0.hook_attn_out torch.Size([1, 10, 768])\n", - "blocks.0.hook_resid_mid torch.Size([1, 10, 768])\n", - "blocks.0.ln2.hook_scale torch.Size([1, 10, 1])\n", - "blocks.0.ln2.hook_normalized torch.Size([1, 10, 768])\n", - "blocks.0.mlp.hook_pre torch.Size([1, 10, 3072])\n", - "blocks.0.mlp.hook_post torch.Size([1, 10, 3072])\n", - "blocks.0.hook_mlp_out torch.Size([1, 10, 768])\n", - "blocks.0.hook_resid_post torch.Size([1, 10, 768])\n", - "ln_final.hook_scale torch.Size([1, 10, 1])\n", - "ln_final.hook_normalized torch.Size([1, 10, 768])\n" - ] - } - ], - "source": [ - "test_prompt = \"The quick brown fox jumped over the lazy dog\"\n", - "print(\"Num tokens:\", len(model.to_tokens(test_prompt)[0]))\n", - "\n", - "def print_name_shape_hook_function(activation, hook):\n", - " print(hook.name, activation.shape)\n", - "\n", - "not_in_late_block_filter = lambda name: name.startswith(\"blocks.0.\") or not name.startswith(\"blocks\")\n", - "\n", - "model.run_with_hooks(\n", - " test_prompt,\n", - " return_type=None,\n", - " fwd_hooks=[(not_in_late_block_filter, print_name_shape_hook_function)],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IKlgIxhVchio" - }, - "source": [ - "### Folding LayerNorm (For the Curious)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "whdOCecfchio" - }, - "source": [ - "(For the curious - this is an important technical detail that's worth understanding, especially if you have preconceptions about how transformers work, but not necessary to use TransformerLens)\n", - "\n", - "LayerNorm is a normalization technique used by transformers, analogous to BatchNorm but more friendly to massive parallelisation. No one *really* knows why it works, but it seems to improve model numerical stability. Unlike BatchNorm, LayerNorm actually changes the functional form of the model, which makes it a massive pain for interpretability!\n", - "\n", - "Folding LayerNorm is a technique to make it lower overhead to deal with, and the flags `center_writing_weights` and `fold_ln` in `HookedTransformer.from_pretrained` apply this automatically (they default to True). 
These simplify the internal structure without changing the computation the model performs.\n", - "\n", - "Intuitively, LayerNorm acts on each residual stream vector (ie for each batch element and token position) independently, sets its mean to 0 (centering) and its standard deviation to 1 (normalizing) (*across* the residual stream dimension - very weird!), and then applies a learned elementwise scaling and translation to each vector.\n", - "\n", - "Mathematically, centering is a linear map, normalizing is *not* a linear map, and scaling and translation are linear maps.\n", - "* **Centering:** LayerNorm is applied every time a layer reads from the residual stream, so the mean of any residual stream vector can never matter - `center_writing_weights` sets every weight matrix writing to the residual to have zero mean.\n", - "* **Normalizing:** Normalizing is not a linear map, and cannot be factored out. The `hook_scale` hook point lets you access and control for this.\n", - "* **Scaling and Translation:** Scaling and translation are linear maps, and are always followed by another linear map. The composition of two linear maps is another linear map, so we can *fold* the scaling and translation weights into the weights of the subsequent layer, and simplify things without changing the underlying computation.\n", - "\n", - "[See the docs for more details](https://github.com/neelnanda-io/TransformerLens/blob/main/further_comments.md#what-is-layernorm-folding-fold_ln)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "aoxoAJszchio" - }, - "source": [ - "A fun consequence of LayerNorm folding is that it creates a bias across the unembed, a `d_vocab` length vector that is added to the output logits - GPT-2 is not trained with this, but it *is* trained with a final LayerNorm that contains a bias.\n", - "\n", - "Turns out, this LayerNorm bias learns structure of the data that we can only see after folding! In particular, it essentially learns **unigram statistics** - rare tokens get suppressed, common tokens get boosted, by pretty dramatic degrees! 
Let's list the top and bottom 20 - at the top we see common punctuation and words like \" the\" and \" and\", at the bottom we see weird-ass tokens like \" RandomRedditor\":" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "_ANS4ivBchio" - }, - "outputs": [], - "source": [ - "unembed_bias = model.unembed.b_U\n", - "bias_values, bias_indices = unembed_bias.sort(descending=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "mHyajC_Echio", - "outputId": "6cb3bcf9-ffab-4d90-c6b1-17e0f7d3eef5" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Top 20 values\n", - "7.03 ','\n", - "6.98 ' the'\n", - "6.68 ' and'\n", - "6.49 '.'\n", - "6.48 '\\n'\n", - "6.47 ' a'\n", - "6.41 ' in'\n", - "6.25 ' to'\n", - "6.16 ' of'\n", - "6.04 '-'\n", - "6.03 ' ('\n", - "5.88 ' \"'\n", - "5.80 ' for'\n", - "5.72 ' that'\n", - "5.64 ' on'\n", - "5.59 ' is'\n", - "5.52 ' as'\n", - "5.49 ' at'\n", - "5.45 ' with'\n", - "5.44 ' or'\n", - "...\n", - "Bottom 20 values\n", - "-3.82 ' サーティ'\n", - "-3.83 '\\x18'\n", - "-3.83 '\\x14'\n", - "-3.83 ' RandomRedditor'\n", - "-3.83 '龍�'\n", - "-3.83 '�'\n", - "-3.83 '\\x1b'\n", - "-3.83 '�'\n", - "-3.83 '\\x05'\n", - "-3.83 '\\x00'\n", - "-3.83 '\\x06'\n", - "-3.83 '\\x07'\n", - "-3.83 '\\x0c'\n", - "-3.83 '\\x02'\n", - "-3.83 'oreAndOnline'\n", - "-3.84 '\\x11'\n", - "-3.84 '�'\n", - "-3.84 '\\x10'\n", - "-3.84 '�'\n", - "-3.84 '�'\n" - ] - } - ], - "source": [ - "top_k = 20\n", - "print(f\"Top {top_k} values\")\n", - "for i in range(top_k):\n", - " print(f\"{bias_values[i].item():.2f} {repr(model.to_string(bias_indices[i]))}\")\n", - "\n", - "print(\"...\")\n", - "print(f\"Bottom {top_k} values\")\n", - "for i in range(top_k, 0, -1):\n", - " print(f\"{bias_values[-i].item():.2f} {repr(model.to_string(bias_indices[-i]))}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "92xC7NZOchio" - }, - "source": [ - "This can have real consequences for interpretability - for example, this bias favours \" John\" over \" Mary\" by about 1.2, about 1/3 of the effect size of the Indirect Object Identification Circuit! All other things being the same, this makes the John token 3.6x times more likely than the Mary token." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "XKgDLN9echio", - "outputId": "6f06117b-fda0-4b5b-d381-5a276f7ed037" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "John bias: 2.8995\n", - "Mary bias: 1.6034\n", - "Prob ratio bias: 3.6550x\n" - ] - } - ], - "source": [ - "john_bias = model.unembed.b_U[model.to_single_token(' John')]\n", - "mary_bias = model.unembed.b_U[model.to_single_token(' Mary')]\n", - "\n", - "print(f\"John bias: {john_bias.item():.4f}\")\n", - "print(f\"Mary bias: {mary_bias.item():.4f}\")\n", - "print(f\"Prob ratio bias: {torch.exp(john_bias - mary_bias).item():.4f}x\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7LxGgL6ychip" - }, - "source": [ - "# Features\n", - "\n", - "An overview of some other important features of the library. I recommend checking out the [Exploratory Analysis Demo](https://colab.research.google.com/github/neelnanda-io/Easy-Transformer/blob/main/Exploratory_Analysis_Demo.ipynb) for some other important features not mentioned here, and for a demo of what using the library in practice looks like." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ye2pwnBzchip" - }, - "source": [ - "## Dealing with tokens\n", - "\n", - "**Tokenization** is one of the most annoying features of studying language models. We want language models to be able to take in arbitrary text as input, but the transformer architecture needs the inputs to be elements of a fixed, finite vocabulary. The solution to this is **tokens**, a fixed vocabulary of \"sub-words\", that any natural language can be broken down into with a **tokenizer**. This is invertible, and we can recover the original text, called **de-tokenization**.\n", - "\n", - "TransformerLens comes with a range of utility functions to deal with tokenization. Different models can have different tokenizers, so these are all methods on the model.\n", - "\n", - "get_token_position, to_tokens, to_string, to_str_tokens, prepend_bos, to_single_token" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Zz8pwzh3chip" - }, - "source": [ - "The first thing you need to figure out is *how* things are tokenized. `model.to_str_tokens` splits a string into the tokens *as a list of substrings*, and so lets you explore what the text looks like. To demonstrate this, let's use it on this paragraph.\n", - "\n", - "Some observations - there are a lot of arbitrary-ish details in here!\n", - "* The tokenizer splits on spaces, so no token contains two words.\n", - "* Tokens include the preceding space, and whether the first token is a capital letter. `how` and ` how` are different tokens!\n", - "* Common words are single tokens, even if fairly long (` paragraph`) while uncommon words are split into multiple tokens (` token|ized`).\n", - "* Tokens *mostly* split on punctuation characters (eg `*` and `.`), but eg `'s` is a single token." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ldX_temJchip", - "outputId": "275a1a4d-b123-49ab-dbac-f51959ad141f" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['<|endoftext|>', 'The', ' first', ' thing', ' you', ' need', ' to', ' figure', ' out', ' is', ' *', 'how', '*', ' things', ' are', ' token', 'ized', '.', ' `', 'model', '.', 'to', '_', 'str', '_', 't', 'ok', 'ens', '`', ' splits', ' a', ' string', ' into', ' the', ' tokens', ' *', 'as', ' a', ' list', ' of', ' sub', 'strings', '*,', ' and', ' so', ' lets', ' you', ' explore', ' what', ' the', ' text', ' looks', ' like', '.', ' To', ' demonstrate', ' this', ',', ' let', \"'s\", ' use', ' it', ' on', ' this', ' paragraph', '.']\n" - ] - } - ], - "source": [ - "example_text = \"The first thing you need to figure out is *how* things are tokenized. `model.to_str_tokens` splits a string into the tokens *as a list of substrings*, and so lets you explore what the text looks like. To demonstrate this, let's use it on this paragraph.\"\n", - "example_text_str_tokens = model.to_str_tokens(example_text)\n", - "print(example_text_str_tokens)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "D4-Y34k2chip" - }, - "source": [ - "The transformer needs to take in a sequence of integers, not strings, so we need to convert these tokens into integers. `model.to_tokens` does this, and returns a tensor of integers on the model's device (shape `[batch, position]`). It maps a string to a batch of size 1." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Z11OTL6hchip", - "outputId": "c06d1f73-68a8-4d5e-9cb6-42c0b326b98b" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[50256, 464, 717, 1517, 345, 761, 284, 3785, 503, 318,\n", - " 1635, 4919, 9, 1243, 389, 11241, 1143, 13, 4600, 19849,\n", - " 13, 1462, 62, 2536, 62, 83, 482, 641, 63, 30778,\n", - " 257, 4731, 656, 262, 16326, 1635, 292, 257, 1351, 286,\n", - " 850, 37336, 25666, 290, 523, 8781, 345, 7301, 644, 262,\n", - " 2420, 3073, 588, 13, 1675, 10176, 428, 11, 1309, 338,\n", - " 779, 340, 319, 428, 7322, 13]], device='cuda:0')\n" - ] - } - ], - "source": [ - "example_text_tokens = model.to_tokens(example_text)\n", - "print(example_text_tokens)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "eWz57zf7chip" - }, - "source": [ - "`to_tokens` can also take in a list of strings, and return a batch of size `len(strings)`. If the strings are different numbers of tokens, it adds a PAD token to the end of the shorter strings to make them the same length.\n", - "\n", - "(Note: In GPT-2, 50256 signifies both the beginning of sequence, end of sequence and padding token - see the `prepend_bos` section for details)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "nQBTaHXxchiq", - "outputId": "fa374758-6483-4598-bc4b-459996d16907" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[50256, 464, 3797, 3332, 319, 262, 2603, 13, 50256, 50256],\n", - " [50256, 464, 3797, 3332, 319, 262, 2603, 1107, 1327, 13]],\n", - " device='cuda:0')\n" - ] - } - ], - "source": [ - "example_multi_text = [\"The cat sat on the mat.\", \"The cat sat on the mat really hard.\"]\n", - "example_multi_text_tokens = model.to_tokens(example_multi_text)\n", - "print(example_multi_text_tokens)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0_-XIbPbchiq" - }, - "source": [ - "`model.to_single_token` is a convenience function that takes in a string corresponding to a *single* token and returns the corresponding integer. This is useful for eg looking up the logit corresponding to a single token.\n", - "\n", - "For example, let's input `The cat sat on the mat.` to GPT-2, and look at the log prob predicting that the next token is ` The`.\n", - "\n", - "
Technical notes\n", - "\n", - "Note that if we input a string to the model, it's implicitly converted to tokens with `to_tokens`.\n", - "\n", - "Note further that the probabilities have shape `[batch, position, d_vocab]==[1, 8, 50257]`, with a vector of probabilities predicting the next token for *every* token position. GPT-2 uses causal attention, which means heads can only look backwards (equivalently, information can only move forwards in the model), so the predictions at position k are only a function of the first k tokens, and it can't just cheat and look at the (k+1)-th token. This structure lets it generate text more efficiently, and lets it treat every *token* as a training example, rather than every *sequence*.\n", - "
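To make the "every token is a training example" point concrete, here is a rough sketch (assuming `model` is the `HookedTransformer` loaded earlier in this demo; the prompt is just an illustration) that pulls out, for every position, the probability the model assigned to the token that actually came next:

```python
import torch

cat_text = "The cat sat on the mat."
tokens = model.to_tokens(cat_text)      # [1, n_positions], BOS prepended by default
logits = model(tokens)                  # [1, n_positions, d_vocab]
log_probs = logits.log_softmax(dim=-1)

# The prediction at position k is scored against the token at position k+1,
# so drop the final prediction and the first token.
next_token_log_probs = log_probs[:, :-1].gather(
    dim=-1, index=tokens[:, 1:].unsqueeze(-1)
).squeeze(-1)                           # [1, n_positions - 1]

str_tokens = model.to_str_tokens(cat_text)
for pos, (target, lp) in enumerate(zip(str_tokens[1:], next_token_log_probs[0])):
    print(f"position {pos} predicts {target!r} with probability {lp.exp().item():.2%}")
```

Every position gets its own scored prediction, so this single 8 token prompt already yields 7 separate "training examples".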
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Qig1QDHuchiq", - "outputId": "ff747144-6937-47c1-e705-ebe606bdbb15" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Probability tensor shape [batch, position, d_vocab] == torch.Size([1, 8, 50257])\n", - "| The| probability: 11.98%\n" - ] - } - ], - "source": [ - "cat_text = \"The cat sat on the mat.\"\n", - "cat_logits = model(cat_text)\n", - "cat_probs = cat_logits.softmax(dim=-1)\n", - "print(f\"Probability tensor shape [batch, position, d_vocab] == {cat_probs.shape}\")\n", - "\n", - "capital_the_token_index = model.to_single_token(\" The\")\n", - "print(f\"| The| probability: {cat_probs[0, -1, capital_the_token_index].item():.2%}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d5QFH48kchiq" - }, - "source": [ - "`model.to_string` is the inverse of `to_tokens` and maps a tensor of integers to a string or list of strings. It also works on integers and lists of integers.\n", - "\n", - "For example, let's look up token 256 (due to technical details of tokenization, this will be the most common pair of ASCII characters!), and also verify that our tokens above map back to a string." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "A6TYcdpYchiq", - "outputId": "8ebba33e-d363-4b93-a97b-e654eea1d6fc" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Token 256 - the most common pair of ASCII characters: | t|\n", - "De-Tokenizing the example tokens: <|endoftext|>The first thing you need to figure out is *how* things are tokenized. `model.to_str_tokens` splits a string into the tokens *as a list of substrings*, and so lets you explore what the text looks like. To demonstrate this, let's use it on this paragraph.\n" - ] - } - ], - "source": [ - "print(f\"Token 256 - the most common pair of ASCII characters: |{model.to_string(256)}|\")\n", - "# Squeeze means to remove dimensions of length 1.\n", - "# Here, that removes the dummy batch dimension so it's a rank 1 tensor and returns a string\n", - "# Rank 2 tensors map to a list of strings\n", - "print(f\"De-Tokenizing the example tokens: {model.to_string(example_text_tokens.squeeze())}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8dEN09Kechiq" - }, - "source": [ - "A related annoyance of tokenization is that it's hard to figure out how many tokens a string will break into. `model.get_token_position(single_token, tokens)` returns the position of `single_token` in `tokens`. 
`tokens` can be either a string or a tensor of tokens.\n", - "\n", - "Note that position is zero-indexed, it's two (ie third) because there's a beginning of sequence token automatically prepended (see the next section for details)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "SAcy7brpchir", - "outputId": "1bfbf6ce-83c9-456c-c724-054c8f44e297" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "With BOS: 2\n", - "Without BOS: 1\n" - ] - } - ], - "source": [ - "print(\"With BOS:\", model.get_token_position(\" cat\", \"The cat sat on the mat\"))\n", - "print(\"Without BOS:\", model.get_token_position(\" cat\", \"The cat sat on the mat\", prepend_bos=False))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VyjeTxgGchir" - }, - "source": [ - "If there are multiple copies of the token, we can set `mode=\"first\"` to find the first occurence's position and `mode=\"last\"` to find the last" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "LHM4NfHqchir", - "outputId": "d3d7c62c-3b21-4134-a668-6a222fdf0ffc" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "First occurence 2\n", - "Final occurence 13\n" - ] - } - ], - "source": [ - "print(\"First occurence\", model.get_token_position(\n", - " \" cat\",\n", - " \"The cat sat on the mat. The mat sat on the cat.\",\n", - " mode=\"first\"))\n", - "print(\"Final occurence\", model.get_token_position(\n", - " \" cat\",\n", - " \"The cat sat on the mat. The mat sat on the cat.\",\n", - " mode=\"last\"))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ivuNrPOichir" - }, - "source": [ - "In general, tokenization is a pain, and full of gotchas. I highly recommend just playing around with different inputs and their tokenization and getting a feel for it. As another \"fun\" example, let's look at the tokenization of arithmetic expressions - tokens do *not* contain consistent numbers of digits. (This makes it even more impressive that GPT-3 can do arithmetic!)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zhDGE02lchir", - "outputId": "4ab1f7ca-74ba-472a-ea1d-901be754a3c4" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['<|endoftext|>', '23', '42', '+', '2017', '=', '214', '45']\n", - "['<|endoftext|>', '1000', '+', '1', '000000', '=', '9999', '99']\n" - ] - } - ], - "source": [ - "print(model.to_str_tokens(\"2342+2017=21445\"))\n", - "print(model.to_str_tokens(\"1000+1000000=999999\"))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7pl0jfqYchir" - }, - "source": [ - "I also *highly* recommend investigating prompts with easy tokenization when starting out - ideally key words should form a single token, be in the same position in different prompts, have the same total length, etc. Eg study Indirect Object Identification with common English names like ` Tim` rather than ` Ne|el`. Transformers need to spend some parameters in early layers converting multi-token words to a single feature, and then de-converting this in the late layers, and unless this is what you're explicitly investigating, this will make the behaviour you're investigating be messier." 
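As a concrete version of that advice, a quick pre-flight check along these lines can catch awkward prompts before you run an experiment (a sketch; the names and prompts are just examples, and `model` is assumed to be the loaded `HookedTransformer`):

```python
key_words = [" Tim", " Neel", " John", " Mary"]

for word in key_words:
    str_tokens = model.to_str_tokens(word, prepend_bos=False)
    if len(str_tokens) == 1:
        print(f"{word!r} is a single token (id {model.to_single_token(word)})")
    else:
        print(f"{word!r} splits into {str_tokens} - avoid it as a key word")

# Paired prompts should also have the same number of tokens, so the
# interesting tokens sit at the same positions in both.
prompt_a = "When Tim and Mary went to the park, Tim gave the ball to"
prompt_b = "When Tim and Mary went to the park, Mary gave the ball to"
assert model.to_tokens(prompt_a).shape == model.to_tokens(prompt_b).shape
```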
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CheGQBQLchir" - }, - "source": [ - "### Gotcha: `prepend_bos`\n", - "\n", - "Key Takeaway: **If you get weird off-by-one errors, check whether there's an unexpected `prepend_bos`!**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2tSE4szZchis" - }, - "source": [ - "A weirdness you may have noticed in the above is that `to_tokens` and `to_str_tokens` added a weird `<|endoftext|>` to the start of each prompt. TransformerLens does this by default, and it can easily trip up new users. Notably, **this includes `model.forward`** (which is what's implicitly used when you do eg `model(\"Hello World\")`). This is called a **Beginning of Sequence (BOS)** token, and it's a special token used to mark the beginning of the sequence. Confusingly, in GPT-2, the End of Sequence (EOS), Beginning of Sequence (BOS) and Padding (PAD) tokens are all the same, `<|endoftext|>` with index `50256`.\n", - "\n", - "**Gotcha:** You only want to prepend a BOS token at the *start* of a prompt. If you, eg, want to input a question followed by an answer, and want to tokenize these separately, you do *not* want to prepend_bos on the answer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "z1-aWu65chis", - "outputId": "e363c436-b557-4bd8-f511-ebd5da2752c4" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Logits shape by default (with BOS) torch.Size([1, 3, 50257])\n", - "Logits shape with BOS torch.Size([1, 3, 50257])\n", - "Logits shape without BOS - only 2 positions! torch.Size([1, 2, 50257])\n" - ] - } - ], - "source": [ - "print(\"Logits shape by default (with BOS)\", model(\"Hello World\").shape)\n", - "print(\"Logits shape with BOS\", model(\"Hello World\", prepend_bos=True).shape)\n", - "print(\"Logits shape without BOS - only 2 positions!\", model(\"Hello World\", prepend_bos=False).shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FLmwJGpKchis" - }, - "source": [ - "`prepend_bos` is a bit of a hack, and I've gone back and forth on what the correct default here is. The reason I do this is that transformers tend to treat the first token weirdly - this doesn't really matter in training (where all inputs are >1000 tokens), but this can be a big issue when investigating short prompts! The reason for this is that attention patterns are a probability distribution and so need to add up to one, so to simulate being \"off\" they normally look at the first token. Giving them a BOS token lets the heads rest by looking at that, preserving the information in the first \"real\" token.\n", - "\n", - "Further, *some* models are trained to need a BOS token (OPT and my interpretability-friendly models are, GPT-2 and GPT-Neo are not). 
But despite GPT-2 not being trained with this, empirically it seems to make interpretability easier.\n", - "\n", - "(However, if you want to change the default behaviour to *not* prepending a BOS token, pass `default_prepend_bos=False` when you instantiate the model, e.g., `model = HookedTransformer.from_pretrained('gpt2', default_prepend_bos=False)`.)\n", - "\n", - "For example, the model can get much worse at Indirect Object Identification without a BOS (and with a name as the first token):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "UX62qZJbchis", - "outputId": "870c0238-44ae-496f-f428-007756bc58f9" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Logit difference with BOS: 6.754\n", - "Logit difference without BOS: 2.782\n" - ] - } - ], - "source": [ - "ioi_logits_with_bos = model(\"Claire and Mary went to the shops, then Mary gave a bottle of milk to\", prepend_bos=True)\n", - "mary_logit_with_bos = ioi_logits_with_bos[0, -1, model.to_single_token(\" Mary\")].item()\n", - "claire_logit_with_bos = ioi_logits_with_bos[0, -1, model.to_single_token(\" Claire\")].item()\n", - "print(f\"Logit difference with BOS: {(claire_logit_with_bos - mary_logit_with_bos):.3f}\")\n", - "\n", - "ioi_logits_without_bos = model(\"Claire and Mary went to the shops, then Mary gave a bottle of milk to\", prepend_bos=False)\n", - "mary_logit_without_bos = ioi_logits_without_bos[0, -1, model.to_single_token(\" Mary\")].item()\n", - "claire_logit_without_bos = ioi_logits_without_bos[0, -1, model.to_single_token(\" Claire\")].item()\n", - "print(f\"Logit difference without BOS: {(claire_logit_without_bos - mary_logit_without_bos):.3f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jPqENtELchis" - }, - "source": [ - "Though, note that this also illustrates another gotcha - when `Claire` is at the start of a sentence (no preceding space), it's actually *two* tokens, not one, which probably confuses the relevant circuit. (Note - in this test we put `prepend_bos=False`, because we want to analyse the tokenization of a specific string, not to give an input to the model!)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "l8gadRKhchis", - "outputId": "a249d48a-823f-4ed7-b5f3-aee248e5d5f3" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "| Claire| -> [' Claire']\n", - "|Claire| -> ['Cl', 'aire']\n" - ] - } - ], - "source": [ - "print(f\"| Claire| -> {model.to_str_tokens(' Claire', prepend_bos=False)}\")\n", - "print(f\"|Claire| -> {model.to_str_tokens('Claire', prepend_bos=False)}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jXXFDCHrchis" - }, - "source": [ - "## Factored Matrix Class\n", - "\n", - "In transformer interpretability, we often need to analyse low rank factorized matrices - a matrix $M = AB$, where M is `[large, large]`, but A is `[large, small]` and B is `[small, large]`. This is a common structure in transformers, and the `FactoredMatrix` class is a convenient way to work with these. It implements efficient algorithms for various operations on these, such as computing the trace, eigenvalues, Frobenius norm, singular value decomposition, and products with other matrices. It can (approximately) act as a drop-in replacement for the original matrix, and supports leading batch dimensions to the factored matrix.\n", - "\n", - "
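For a sense of scale, here is a tiny back-of-the-envelope sketch (using the GPT-2 Small dimensions that appear in this notebook) of how many numbers the dense product $M = AB$ contains versus its two factors:

```python
# M = A @ B with M [large, large], A [large, small], B [small, large]
def dense_vs_factored(large, small):
    return large * large, 2 * large * small

# GPT-2 Small: head dimension 64, residual stream 768, vocabulary 50257.
print(dense_vs_factored(768, 64))    # (589824, 98304)
print(dense_vs_factored(50257, 64))  # (2525766049, 6432896) - roughly 10 GB vs 26 MB in fp32
```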
Why are low-rank factorized matrices useful for transformer interpretability?\n", - "\n", - "As argued in [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html), an unexpected fact about transformer attention heads is that rather than being best understood as keys, queries and values (and the requisite weight matrices), they're actually best understood as two low rank factorized matrices.\n", - "* **Where to move information from:** $W_{QK} = W_Q W_K^T$, used for determining the attention pattern - what source positions to move information from and what destination positions to move them to.\n", - " * Intuitively, residual stream -> query and residual stream -> key are linear maps, *and* `attention_score = query @ key.T` is a linear map, so the whole thing can be factored into one big bilinear form `residual @ W_QK @ residual.T`\n", - "* **What information to move:** $W_{OV} = W_V W_O$, used to determine what information to copy from the source position to the destination position (weighted by the attention pattern weight from that destination to that source).\n", - " * Intuitively, the residual stream is a `[position, d_model]` tensor (ignoring batch). The attention pattern acts on the *position* dimension (where to move information from and to) and the value and output weights act on the *d_model* dimension - ie *what* information is contained at that source position. So we can factor it all into `attention_pattern @ residual @ W_V @ W_O`, and so only need to care about `W_OV = W_V @ W_O`\n", - "* Note - the internal head dimension is smaller than the residual stream dimension, so the factorization is low rank. (here, `d_model=768` and `d_head=64`)\n", - "
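The bilinear-form point is easy to check numerically. The sketch below compares the two routes to attention scores for one head, ignoring biases and the 1/sqrt(d_head) scaling; it assumes `model.W_Q` and `model.W_K` are the stacked `[n_layers, n_heads, d_model, d_head]` weight tensors exposed by TransformerLens:

```python
import torch

layer, head = 0, 0
W_Q = model.W_Q[layer, head]   # [d_model, d_head]
W_K = model.W_K[layer, head]   # [d_model, d_head]

# A fake residual stream with 10 positions, just to compare the two computations.
resid = torch.randn(10, model.cfg.d_model, device=W_Q.device)

# Route 1: project to queries and keys, then compare every pair of positions.
scores_via_qk = (resid @ W_Q) @ (resid @ W_K).T

# Route 2: one bilinear form through the low rank matrix W_QK = W_Q @ W_K.T.
scores_via_bilinear = resid @ (W_Q @ W_K.T) @ resid.T

print(torch.allclose(scores_via_qk, scores_via_bilinear, atol=1e-4))  # True
```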
" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ee3UuEAXchit" - }, - "source": [ - "### Basic Examples" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NV4pMu_hchit" - }, - "source": [ - "We can use the basic class directly - let's make a factored matrix directly and look at the basic operations:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "6L5jqdJOchit", - "outputId": "d417fe79-1e2f-4e46-ac13-76219871759e" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Norms:\n", - "tensor(5.7439)\n", - "tensor(5.7439)\n", - "Right dimension: 5, Left dimension: 5, Hidden dimension: 2\n" - ] - } - ], - "source": [ - "A = torch.randn(5, 2)\n", - "B = torch.randn(2, 5)\n", - "AB = A @ B\n", - "AB_factor = FactoredMatrix(A, B)\n", - "print(\"Norms:\")\n", - "print(AB.norm())\n", - "print(AB_factor.norm())\n", - "\n", - "print(f\"Right dimension: {AB_factor.rdim}, Left dimension: {AB_factor.ldim}, Hidden dimension: {AB_factor.mdim}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pMEqwNH_chit" - }, - "source": [ - "We can also look at the eigenvalues and singular values of the matrix. Note that, because the matrix is rank 2 but 5 by 5, the final 3 eigenvalues and singular values are zero - the factored class omits the zeros." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "MgA2hK54chit", - "outputId": "3bc10ced-7f81-48ee-fe17-8cd1265995d8" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Eigenvalues:\n", - "tensor([-1.6199e+00+0.0000e+00j, 1.0446e+00+0.0000e+00j,\n", - " 3.7190e-08+0.0000e+00j, -6.5998e-08+1.2859e-07j,\n", - " -6.5998e-08-1.2859e-07j])\n", - "tensor([-1.6199+0.j, 1.0446+0.j])\n", - "\n", - "Singular Values:\n", - "tensor([5.4702e+00, 1.7519e+00, 3.4613e-07, 1.0601e-07, 3.8823e-09])\n", - "tensor([5.4702, 1.7519])\n" - ] - } - ], - "source": [ - "print(\"Eigenvalues:\")\n", - "print(torch.linalg.eig(AB).eigenvalues)\n", - "print(AB_factor.eigenvalues)\n", - "print()\n", - "print(\"Singular Values:\")\n", - "print(torch.linalg.svd(AB).S)\n", - "print(AB_factor.S)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9dlG6qL6chit" - }, - "source": [ - "We can multiply with other matrices - it automatically chooses the smallest possible dimension to factor along (here it's 2, rather than 5)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "NH6tAGFJchit", - "outputId": "a3d9e030-00bc-4ba1-93d2-453e4b7a2818" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Unfactored: torch.Size([5, 300]) tensor(99.9906)\n", - "Factored: torch.Size([5, 300]) tensor(99.9906)\n", - "Right dimension: 300, Left dimension: 5, Hidden dimension: 2\n" - ] - } - ], - "source": [ - "C = torch.randn(5, 300)\n", - "ABC = AB @ C\n", - "ABC_factor = AB_factor @ C\n", - "print(\"Unfactored:\", ABC.shape, ABC.norm())\n", - "print(\"Factored:\", ABC_factor.shape, ABC_factor.norm())\n", - "print(f\"Right dimension: {ABC_factor.rdim}, Left dimension: {ABC_factor.ldim}, Hidden dimension: {ABC_factor.mdim}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KBgr_GFichit" - }, - "source": [ - "If we want to collapse this back to an unfactored matrix, we can use the AB property to get the product:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "OG-e4Elnchiu", - "outputId": 
"690f9589-c56e-46c6-e517-e35f2cb84f65" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(True)\n" - ] - } - ], - "source": [ - "AB_unfactored = AB_factor.AB\n", - "print(torch.isclose(AB_unfactored, AB).all())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "__NIg_npchiu" - }, - "source": [ - "### Medium Example: Eigenvalue Copying Scores\n", - "\n", - "(This is a more involved example of how to use the factored matrix class, skip it if you aren't following)\n", - "\n", - "For a more involved example, let's look at the eigenvalue copying score from [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) of the OV circuit for various heads. The OV Circuit for a head (the factorised matrix $W_OV = W_V W_O$) is a linear map that determines what information is moved from the source position to the destination position. Because this is low rank, it can be thought of as *reading in* some low rank subspace of the source residual stream and *writing to* some low rank subspace of the destination residual stream (with maybe some processing happening in the middle).\n", - "\n", - "A common operation for this will just be to *copy*, ie to have the same reading and writing subspace, and to do minimal processing in the middle. Empirically, this tends to coincide with the OV Circuit having (approximately) positive real eigenvalues. I mostly assert this as an empirical fact, but intuitively, operations that involve mapping eigenvectors to different directions (eg rotations) tend to have complex eigenvalues. And operations that preserve eigenvector direction but negate it tend to have negative real eigenvalues. And \"what happens to the eigenvectors\" is a decent proxy for what happens to an arbitrary vector.\n", - "\n", - "We can get a score for \"how positive real the OV circuit eigenvalues are\" with $\\frac{\\sum \\lambda_i}{\\sum |\\lambda_i|}$, where $\\lambda_i$ are the eigenvalues of the OV circuit. This is a bit of a hack, but it seems to work well in practice." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "h7WG64Kvchiu" - }, - "source": [ - "Let's use FactoredMatrix to compute this for every head in the model! We use the helper `model.OV` to get the concatenated OV circuits for all heads across all layers in the model. This has the shape `[n_layers, n_heads, d_model, d_model]`, where `n_layers` and `n_heads` are batch dimensions and the final two dimensions are factorised as `[n_layers, n_heads, d_model, d_head]` and `[n_layers, n_heads, d_head, d_model]` matrices.\n", - "\n", - "We can then get the eigenvalues for this, where there are separate eigenvalues for each element of the batch (a `[n_layers, n_heads, d_head]` tensor of complex numbers), and calculate the copying score." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "NyO7w_tUchiu", - "outputId": "d42b0c3e-0e40-441b-b90a-8a728f5ebb8a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "FactoredMatrix: Shape(torch.Size([12, 12, 768, 768])), Hidden Dim(64)\n" - ] - } - ], - "source": [ - "OV_circuit_all_heads = model.OV\n", - "print(OV_circuit_all_heads)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "vi4LpBExchiu", - "outputId": "a82a4e64-c31b-489a-cdd8-2118bec4bae0" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([12, 12, 64])\n", - "torch.complex64\n" - ] - } - ], - "source": [ - "OV_circuit_all_heads_eigenvalues = OV_circuit_all_heads.eigenvalues\n", - "print(OV_circuit_all_heads_eigenvalues.shape)\n", - "print(OV_circuit_all_heads_eigenvalues.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "UBr6NntUchiu", - "outputId": "c01951e1-2290-4b94-a181-175476876241" - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "OV_copying_score = OV_circuit_all_heads_eigenvalues.sum(dim=-1).real / OV_circuit_all_heads_eigenvalues.abs().sum(dim=-1)\n", - "imshow(utils.to_numpy(OV_copying_score), xaxis=\"Head\", yaxis=\"Layer\", title=\"OV Copying Score for each head in GPT-2 Small\", zmax=1.0, zmin=-1.0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ej_fDPXNchiu" - }, - "source": [ - "Head 11 in Layer 11 (L11H11) has a high copying score, and if we plot the eigenvalues they look approximately as expected." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "xJ3brDVfchiu", - "outputId": "ec202a23-79d7-4cd9-ccd4-40cedbfd8308" - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "scatter(x=OV_circuit_all_heads_eigenvalues[-1, -1, :].real, y=OV_circuit_all_heads_eigenvalues[-1, -1, :].imag, title=\"Eigenvalues of Head L11H11 of GPT-2 Small\", xaxis=\"Real\", yaxis=\"Imaginary\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4iaZ2V_Lchiv" - }, - "source": [ - "We can even look at the full OV circuit, from the input tokens to output tokens: $W_E W_V W_O W_U$. This is a `[d_vocab, d_vocab]==[50257, 50257]` matrix, so absolutely enormous, even for a single head. But with the FactoredMatrix class, we can compute the full eigenvalue copying score of every head in a few seconds." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0X0B-Z-Achiv", - "outputId": "c88641a2-794f-46d7-971b-e73f12b46eea" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "FactoredMatrix: Shape(torch.Size([12, 12, 50257, 50257])), Hidden Dim(64)\n" - ] - } - ], - "source": [ - "full_OV_circuit = model.embed.W_E @ OV_circuit_all_heads @ model.unembed.W_U\n", - "print(full_OV_circuit)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sCIQpyWOchiv", - "outputId": "8a681263-c77f-41dc-93c3-dc9b63bc1f98" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([12, 12, 64])\n", - "torch.complex64\n" - ] - } - ], - "source": [ - "full_OV_circuit_eigenvalues = full_OV_circuit.eigenvalues\n", - "print(full_OV_circuit_eigenvalues.shape)\n", - "print(full_OV_circuit_eigenvalues.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "xs_9bto4chiv", - "outputId": "563c263a-fcbf-4b93-d378-8963b84238e2" - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "full_OV_copying_score = full_OV_circuit_eigenvalues.sum(dim=-1).real / full_OV_circuit_eigenvalues.abs().sum(dim=-1)\n", - "imshow(utils.to_numpy(full_OV_copying_score), xaxis=\"Head\", yaxis=\"Layer\", title=\"OV Copying Score for each head in GPT-2 Small\", zmax=1.0, zmin=-1.0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Md96Q2gfchiv" - }, - "source": [ - "Interestingly, these are highly (but not perfectly!) correlated. I'm not sure what to read from this, or what's up with the weird outlier heads!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kG0vZzO2chiv", - "outputId": "ed38307a-5716-4952-94ba-7c1cba4a5e1d" - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "scatter(x=full_OV_copying_score.flatten(), y=OV_copying_score.flatten(), hover_name=[f\"L{layer}H{head}\" for layer in range(12) for head in range(12)], title=\"OV Copying Score for each head in GPT-2 Small\", xaxis=\"Full OV Copying Score\", yaxis=\"OV Copying Score\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "70DBpdO1chiv", - "outputId": "5a00e3cc-2b2e-4927-d10a-b93bdd9c9b0d" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Token 256 - the most common pair of ASCII characters: | t|\n", - "De-Tokenizing the example tokens: <|endoftext|>The first thing you need to figure out is *how* things are tokenized. `model.to_str_tokens` splits a string into the tokens *as a list of substrings*, and so lets you explore what the text looks like. To demonstrate this, let's use it on this paragraph.\n" - ] - } - ], - "source": [ - "print(f\"Token 256 - the most common pair of ASCII characters: |{model.to_string(256)}|\")\n", - "# Squeeze means to remove dimensions of length 1.\n", - "# Here, that removes the dummy batch dimension so it's a rank 1 tensor and returns a string\n", - "# Rank 2 tensors map to a list of strings\n", - "print(f\"De-Tokenizing the example tokens: {model.to_string(example_text_tokens.squeeze())}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GymQEk3Ichiv" - }, - "source": [ - "## Generating Text" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5sahkeQAchiw" - }, - "source": [ - "TransformerLens also has basic text generation functionality, which can be useful for generally exploring what the model is capable of (thanks to Ansh Radhakrishnan for adding this!). This is pretty rough functionality, and where possible I recommend using more established libraries like HuggingFace for this." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "referenced_widgets": [ - "9c541056beba477db0a272faec9231b5" - ] - }, - "id": "-NW8KQLHchiw", - "outputId": "595f004c-f82f-4639-e384-ac06760f8263" - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9c541056beba477db0a272faec9231b5", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/50 [00:00\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "from transformer_lens.loading_from_pretrained import get_checkpoint_labels\n", - "for model_name in [\"attn-only-2l\", \"solu-12l\", \"stanford-gpt2-small-a\"]:\n", - " checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(model_name)\n", - " line(checkpoint_labels, xaxis=\"Checkpoint Index\", yaxis=f\"Checkpoint Value ({checkpoint_label_type})\", title=f\"Checkpoint Values for {model_name} (Log scale)\", log_y=True, markers=True)\n", - "for model_name in [\"solu-1l-pile\", \"solu-6l-pile\"]:\n", - " checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(model_name)\n", - " line(checkpoint_labels, xaxis=\"Checkpoint Index\", yaxis=f\"Checkpoint Value ({checkpoint_label_type})\", title=f\"Checkpoint Values for {model_name} (Linear scale)\", log_y=False, markers=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1sb8QnHqchix" - }, - "source": [ - "### Example: Induction Head Phase Transition" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yIkm5DjZchiy" - }, - "source": [ - "One of the more interesting results analysing circuit formation during training is the [induction head phase transition](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html). They find a pretty dramatic shift in models during training - there's a brief period where models go from not having induction heads to having them, which leads to the models suddenly becoming much better at in-context learning (using far back tokens to predict the next token, eg over 500 words back). This is enough of a big deal that it leads to a visible *bump* in the loss curve, where the model's rate of improvement briefly increases." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LVQ4Fc6jchiy" - }, - "source": [ - "As a brief demonstration of the existence of the phase transition, let's load some checkpoints of a two layer model, and see whether they have induction heads. An easy test, as we used above, is to give the model a repeated sequence of random tokens, and to check how good its loss is on the second half. `evals.induction_loss` is a rough util that runs this test on a model.\n", - "(Note - this is deliberately a rough, non-rigorous test for the purposes of demonstration, eg `evals.induction_loss` by default just runs it on 4 sequences of 384 tokens repeated twice. These results totally don't do the paper justice - go check it out if you want to see the full results!)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nHs5S9ZKchiy" - }, - "source": [ - "In the interests of time and memory, let's look at a handful of checkpoints (chosen to be around the phase change), indices `[10, 25, 35, 60, -1]`. These are roughly 22M, 200M, 500M, 1.6B and 21.8B tokens through training, respectively. (I generally recommend looking things up based on indices, rather than checkpoint value!)." 
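For intuition, the test itself is simple enough to write by hand. This is a rough approximation of that evaluation, not the actual `evals.induction_loss` implementation; it assumes a loaded `HookedTransformer` as `model`, a tokenizer with a BOS token, and the sampling range described above:

```python
import torch

seq_len, batch_size = 384, 4

# A batch of random token sequences, each repeated once: [BOS, seq, seq].
random_tokens = torch.randint(0, 20000, (batch_size, seq_len))
bos = torch.full((batch_size, 1), model.tokenizer.bos_token_id)
repeated_tokens = torch.cat([bos, random_tokens, random_tokens], dim=1).to(model.cfg.device)

logits = model(repeated_tokens)
log_probs = logits[:, :-1].log_softmax(dim=-1)
correct_log_probs = log_probs.gather(-1, repeated_tokens[:, 1:, None])[..., 0]

# Loss on the second half only: the model has already seen every one of these
# tokens once, so a model with induction heads should do far better than chance.
induction_loss = -correct_log_probs[:, seq_len:].mean()
print(f"Induction loss on the repeated half: {induction_loss.item():.3f}")
```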
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Z42tKzTLchiy" - }, - "outputs": [], - "source": [ - "from transformer_lens import evals\n", - "# We use the two layer model with SoLU activations, chosen fairly arbitrarily as being both small (so fast to download and keep in memory) and pretty good at the induction task.\n", - "model_name = \"solu-2l\"\n", - "# We can load a model from a checkpoint by specifying the checkpoint_index, -1 means the final checkpoint\n", - "checkpoint_indices = [10, 25, 35, 60, -1]\n", - "checkpointed_models = []\n", - "tokens_trained_on = []\n", - "induction_losses = []" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tQ-oTk77chiy" - }, - "source": [ - "We load the models, cache them in a list, and" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "PQqrP04zchiy", - "outputId": "76551392-8213-4689-9b1e-6596c8c5d14d" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loaded pretrained model solu-2l into HookedTransformer\n", - "Loaded pretrained model solu-2l into HookedTransformer\n", - "Loaded pretrained model solu-2l into HookedTransformer\n", - "Loaded pretrained model solu-2l into HookedTransformer\n", - "Loaded pretrained model solu-2l into HookedTransformer\n" - ] - } - ], - "source": [ - "for index in checkpoint_indices:\n", - " # Load the model from the relevant checkpoint by index\n", - " model_for_this_checkpoint = HookedTransformer.from_pretrained(model_name, checkpoint_index=index)\n", - " checkpointed_models.append(model_for_this_checkpoint)\n", - "\n", - " tokens_seen_for_this_checkpoint = model_for_this_checkpoint.cfg.checkpoint_value\n", - " tokens_trained_on.append(tokens_seen_for_this_checkpoint)\n", - "\n", - " induction_loss_for_this_checkpoint = evals.induction_loss(model_for_this_checkpoint).item()\n", - " induction_losses.append(induction_loss_for_this_checkpoint)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "c3k7L7gDchiy" - }, - "source": [ - "We can plot this, and see there's a sharp shift from ~200-500M tokens trained on (note the log scale on the x axis). Interestingly, this is notably earlier than the phase transition in the paper, I'm not sure what's up with that.\n", - "\n", - "(To contextualise the numbers, the tokens in the random sequence are uniformly chosen from the first 20,000 tokens (out of ~48,000 total), so random performance is at least $\\ln(20000)\\approx 10$. A naive strategy like \"randomly choose a token that's already appeared in the first half of the sequence (384 elements)\" would get $\\ln(384)\\approx 5.95$, so the model is doing pretty well here.)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "-5caXkZ9chiy", - "outputId": "863a379b-9a90-401a-e3ad-cc16b6561b67" - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "line(induction_losses, x=tokens_trained_on, xaxis=\"Tokens Trained On\", yaxis=\"Induction Loss\", title=\"Induction Loss over training: solu-2l\", markers=True, log_x=True)" - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "eb812820b5094695c8a581672e17220e30dd2c15d704c018326e3cc2e1a566f1" - } - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/examples/tl/attention.ipynb b/examples/tl/attention.ipynb deleted file mode 100644 index e31651a..0000000 --- a/examples/tl/attention.ipynb +++ /dev/null @@ -1,85 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from engine import LanguageModel\n", - "from engine.fx.Proxy import Proxy\n", - "from engine import util\n", - "\n", - "model = LanguageModel('gpt2',device_map='cuda:0')\n", - "\n", - "gpt2_text = \"Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets.\"\n", - "\n", - "with model.generate(max_new_tokens=1, output_attentions=True) as generator:\n", - "\n", - " with generator.invoke(gpt2_text, output_attentions=True) as invoker:\n", - "\n", - " tokens = invoker.tokens\n", - "\n", - " attn_hidden_states = [model.transformer.h[layer_idx].attn.output[2][0].save() for layer_idx in range(len(model.transformer.h))]\n", - "\n", - "attn_hidden_states = util.apply(attn_hidden_states, lambda x : x.value, Proxy)\n", - "\n", - "import circuitsvis as cv\n", - "\n", - "cv.attention.attention_patterns(tokens=tokens, attention=attn_hidden_states[0])\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/tl/attention.py b/examples/tl/attention.py deleted file mode 100644 index 3bad709..0000000 --- a/examples/tl/attention.py +++ /dev/null @@ -1,21 +0,0 @@ -from engine import LanguageModel -from engine.fx.Proxy import Proxy -from engine import util - -model = LanguageModel('gpt2', device_map='cuda:0') - -gpt2_text = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets." 
- -with model.generate(max_new_tokens=1, output_attentions=True) as generator: - - with generator.invoke(gpt2_text, output_attentions=True) as invoker: - - tokens = invoker.tokens - - attn_hidden_states = [model.transformer.h[layer_idx].attn.output[2][0].save() for layer_idx in range(len(model.transformer.h))] - -attn_hidden_states = util.apply(attn_hidden_states, lambda x : x.value, Proxy) - -import circuitsvis as cv - -cv.attention.attention_patterns(tokens=tokens, attention=attn_hidden_states[0]) diff --git a/examples/tl/induction.py b/examples/tl/induction.py deleted file mode 100644 index e69de29..0000000 diff --git a/examples/tl/patching.py b/examples/tl/patching.py deleted file mode 100644 index 8368ad4..0000000 --- a/examples/tl/patching.py +++ /dev/null @@ -1,72 +0,0 @@ -from engine import Model -from engine.fx.Proxy import Proxy -from engine import util -import torch - -model = Model('gpt2', device_map='cuda:0') - -clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to" -corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to" - -correct_index = model.tokenizer(" John")['input_ids'][0] -incorrect_index = model.tokenizer(" Mary")['input_ids'][0] - - -with model.generate(max_new_tokens=1) as generator: - - with generator.invoke(clean_prompt) as invoker: - - clean_tokens = invoker.tokens - - clean_hs = [model.transformer.h[layer_idx].output[0] for layer_idx in range(len(model.transformer.h))] - - clean_logits = model.lm_head.output - - clean_logit_diff = (clean_logits[0, -1, correct_index] - clean_logits[0, -1, incorrect_index]).save() - - with generator.invoke(corrupted_prompt) as invoker: - - corrupted_tokens = invoker.tokens - - corrupted_logits = model.lm_head.output - - corrupted_logit_diff = (corrupted_logits[0, -1, correct_index] - corrupted_logits[0, -1, incorrect_index]).save() - - ioi_patching_results = [] - - for layer_idx in range(len(model.transformer.h)): - - _ioi_patching_results = [] - - for token_idx in range(len(clean_tokens)): - - with generator.invoke(corrupted_prompt) as invoker: - - model.transformer.h[layer_idx].output[0].t[token_idx] = clean_hs[layer_idx].t[token_idx] - - patched_logits = model.lm_head.output - - patched_logit_diff = patched_logits[0, -1, correct_index] - patched_logits[0, -1, incorrect_index] - - patched_result = (patched_logit_diff - corrupted_logit_diff)/(clean_logit_diff - corrupted_logit_diff) - - _ioi_patching_results.append(patched_result.save()) - - ioi_patching_results.append(_ioi_patching_results) - - -print(f"Clean logit difference: {clean_logit_diff.value:.3f}") -print(f"Corrupted logit difference: {corrupted_logit_diff.value:.3f}") - -ioi_patching_results = util.apply(ioi_patching_results, lambda x : x.value, Proxy) -ioi_patching_results = util.apply(ioi_patching_results, lambda x : x.item(), torch.Tensor) - -import plotly.express as px - -token_labels = [f"{token}_{index}" for index, token in enumerate(clean_tokens)] - -fig = px.imshow(ioi_patching_results, color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":"Position", "y":"Layer"}, x=token_labels, title="Normalized Logit Difference After Patching Residual Stream on the IOI Task") - -fig.write_image("patching.png") - -breakpoint() \ No newline at end of file diff --git a/server/modeling/Config.py b/modeling/Config.py similarity index 100% rename from server/modeling/Config.py rename to modeling/Config.py diff --git a/server/modeling/__init__.py b/modeling/__init__.py similarity index 100% 
rename from server/modeling/__init__.py rename to modeling/__init__.py diff --git a/server/processors/ModelProcessor.py b/processors/ModelProcessor.py similarity index 100% rename from server/processors/ModelProcessor.py rename to processors/ModelProcessor.py diff --git a/server/processors/Processor.py b/processors/Processor.py similarity index 100% rename from server/processors/Processor.py rename to processors/Processor.py diff --git a/server/processors/RequestProcessor.py b/processors/RequestProcessor.py similarity index 100% rename from server/processors/RequestProcessor.py rename to processors/RequestProcessor.py diff --git a/server/processors/SignalProcessor.py b/processors/SignalProcessor.py similarity index 100% rename from server/processors/SignalProcessor.py rename to processors/SignalProcessor.py diff --git a/server/processors/__init__.py b/processors/__init__.py similarity index 100% rename from server/processors/__init__.py rename to processors/__init__.py diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index fed528d..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,3 +0,0 @@ -[build-system] -requires = ["setuptools"] -build-backend = "setuptools.build_meta" diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 1d22d4c..0000000 --- a/setup.cfg +++ /dev/null @@ -1,25 +0,0 @@ -[metadata] -name = engine -version = 0.1.0 -author = Jaden Fiotto-Kaufman -author_email = jadenfk@outlook.com -long_description = file: README.md -long_description_content_type = text/markdown -[options] -include_package_data = True -package_dir = - = engine -install_requires = - transformers - protobuf - python-socketio[client] - tokenizers<0.14 - pydantic - torch - sentencepiece - torchvision - accelerate - diffusers - -[options.package_data] -example = config.yaml \ No newline at end of file diff --git a/server/util.py b/util.py similarity index 100% rename from server/util.py rename to util.py