# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List

import torch
from jaxtyping import Float, Int


@dataclass
class ModelInfo:
    name: str

    # Not the actual number of parameters, but rather the order of magnitude
    n_params_estimate: int

    n_layers: int
    n_heads: int
    d_model: int
    d_vocab: int


class TransparentLlm(ABC):
"""
An abstract stateful interface for a language model. The model is supposed to be
loaded at the class initialization.
The internal state is the resulting tensors from the last call of the `run` method.
Most of the methods could return values based on the state, but some may do cheap
computations based on them.
"""

    @abstractmethod
    def model_info(self) -> ModelInfo:
        """
        Gives general info about the model. This method must be available before
        any call to `run`.
        """
        pass

    @abstractmethod
    def run(self, sentences: List[str]) -> None:
        """
        Run inference on the given sentences as a single batch and store all the
        necessary information in the internal state.
        """
        pass

    @abstractmethod
    def batch_size(self) -> int:
        """
        The size of the batch used for the last call to `run`.
        """
        pass

    @abstractmethod
    def tokens(self) -> Int[torch.Tensor, "batch pos"]:
        pass

    @abstractmethod
    def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
        pass

    @abstractmethod
    def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
        pass

    @abstractmethod
    def unembed(
        self,
        t: Float[torch.Tensor, "d_model"],
        normalize: bool,
    ) -> Float[torch.Tensor, "vocab"]:
        """
        Project the given vector (for example, the state of the residual stream
        for a layer and token) into the output vocabulary.

        normalize: whether to apply the final normalization before unembedding.
        Setting it to True and applying it to the output of the last layer gives
        the output of the model.
        """
        pass
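
    # Illustrative check (assumes a concrete implementation `model` after a call
    # to `run`): unembedding the final residual state with normalization should
    # reproduce the model's output logits.
    #
    #   last = model.model_info().n_layers - 1
    #   resid = model.residual_out(last)[0, -1]         # [d_model]
    #   out = model.unembed(resid, normalize=True)      # [vocab]
    #   assert torch.allclose(out, model.logits()[0, -1], atol=1e-4)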

    # ================= Methods related to the residual stream =================

    @abstractmethod
    def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream before entering the layer. For example,
        when layer == 0 these must be the embedded tokens (including positional
        embedding).
        """
        pass

    @abstractmethod
    def residual_after_attn(
        self, layer: int
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream after attention, but before the FFN in
        the given layer.
        """
        pass

    @abstractmethod
    def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream after the given layer. This is
        equivalent to the next layer's input.
        """
        pass
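
    # Illustrative invariant (assumes a concrete implementation `model` after a
    # call to `run`): the stream threads through the layers, so each layer's
    # output is the next layer's input.
    #
    #   for layer in range(model.model_info().n_layers - 1):
    #       assert torch.allclose(model.residual_out(layer),
    #                             model.residual_in(layer + 1))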

    # ================ Methods related to the feed-forward layer ===============

    @abstractmethod
    def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The output of the FFN layer, before it gets merged into the residual
        stream.
        """
        pass

    @abstractmethod
    def decomposed_ffn_out(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "hidden d_model"]:
        """
        A collection of vectors added to the residual stream by each neuron. It
        should be the same as the neuron activations multiplied by the neuron
        outputs.
        """
        pass
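
    # Illustrative consistency check (assumes a concrete implementation `model`
    # after a call to `run`): summing the per-neuron contributions should
    # reproduce the FFN output at the same position.
    #
    #   parts = model.decomposed_ffn_out(batch_i=0, layer=3, pos=5)  # [hidden, d_model]
    #   whole = model.ffn_out(layer=3)[0, 5]                         # [d_model]
    #   assert torch.allclose(parts.sum(dim=0), whole, atol=1e-4)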

    @abstractmethod
    def neuron_activations(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "d_ffn"]:
        """
        The content of the hidden layer right after the activation function was
        applied.
        """
        pass

    @abstractmethod
    def neuron_output(
        self,
        layer: int,
        neuron: int,
    ) -> Float[torch.Tensor, "d_model"]:
        """
        Return the value that the given neuron adds to the residual stream. It is
        a raw vector from the model parameters; no activation is involved.
        """
        pass

    # ==================== Methods related to the attention ====================

    @abstractmethod
    def attention_matrix(
        self, batch_i: int, layer: int, head: int
    ) -> Float[torch.Tensor, "query_pos key_pos"]:
        """
        Return a lower-triangular attention matrix.
        """
        pass
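
    # Illustrative properties (assumes a concrete implementation `model` of a
    # causal decoder with standard softmax attention): no weight on future
    # positions, and each row sums to one.
    #
    #   attn = model.attention_matrix(batch_i=0, layer=0, head=0)
    #   assert torch.equal(attn, attn.tril())
    #   assert torch.allclose(attn.sum(dim=-1), torch.ones(attn.shape[0]))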

    @abstractmethod
    def attention_output(
        self,
        batch_i: int,
        layer: int,
        pos: int,
        head: int,
    ) -> Float[torch.Tensor, "d_model"]:
        """
        Return what the given head at the given layer and position added to the
        residual stream.
        """
        pass

    @abstractmethod
    def decomposed_attn(
        self, batch_i: int, layer: int
    ) -> Float[torch.Tensor, "source target head d_model"]:
        """
        Here
        - source: index of the token from the previous layer,
        - target: index of the token on the current layer.
        The decomposed attention tells which vector from the source
        representation was used to contribute to the target representation.
        """
        pass
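
    # Illustrative consistency check (assumes a concrete implementation `model`
    # after a call to `run`): summing the decomposition over source tokens should
    # reproduce a head's contribution to a given target position.
    #
    #   decomp = model.decomposed_attn(batch_i=0, layer=2)  # [source, target, head, d_model]
    #   head_sum = decomp[:, 5, 7, :].sum(dim=0)            # head 7, target pos 5
    #   assert torch.allclose(
    #       head_sum,
    #       model.attention_output(batch_i=0, layer=2, pos=5, head=7),
    #       atol=1e-4,
    #   )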