able to specify for each block whether to take input or output
lucidrains committed Jan 19, 2024
1 parent 7c89b22 commit 96c827b
Showing 2 changed files with 20 additions and 13 deletions.
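What this change enables, as a hedged usage sketch: `hidden_position` in `AugmentParams` and the renamed `anchor_hidden_position` argument can now be either a single 'input'/'output' value or one value per transformer block. Only the names `AugmentParams`, `hidden_position`, `anchor_hidden_position` and `augment_llms_params` appear in the diff below; the import paths, model configs, and the rest of the constructor call are assumptions for illustration, not the library's confirmed API.

```python
# Hypothetical usage sketch -- import paths and constructor details assumed.
from x_transformers import TransformerWrapper, Decoder
from CALM_pytorch import CALM, AugmentParams   # import path assumed from package layout

anchor_llm = TransformerWrapper(
    num_tokens = 20000, max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

augment_llm = TransformerWrapper(
    num_tokens = 20000, max_seq_len = 1024,
    attn_layers = Decoder(dim = 256, depth = 6, heads = 8)
)

calm = CALM(
    anchor_llm = anchor_llm,
    augment_llms_params = [
        AugmentParams(
            model = augment_llm,
            # previously a single 'input' or 'output' applied to every block;
            # after this commit it may also be one entry per transformer block
            hidden_position = ('input',) * 3 + ('output',) * 3,
        )
    ],
    anchor_hidden_position = 'output',   # a single value still broadcasts to all anchor blocks
    # any other required arguments (e.g. layer connections) are omitted here
)
```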
31 changes: 19 additions & 12 deletions CALM_pytorch/CALM.py
@@ -201,7 +201,7 @@ def forward(self, _, inp, out):
@dataclass
class AugmentParams:
model: Module
- hidden_position: Union[Literal['input'], Literal['output']] = 'output'
+ hidden_position: SingularOrMany(Union[Literal['input'], Literal['output']]) = 'output'
transformer_blocks: Optional[List[Module]] = None
extract_blocks_fn: Optional[Callable[[Module], List[Module]]] = None
input_shape: Optional[Tuple[int, ...]] = None
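A note on the annotation change above: `SingularOrMany` is defined elsewhere in CALM.py and is not shown in this diff. Presumably it widens a type so the field accepts either one value or a tuple of values, one per block. A minimal sketch of that assumption:

```python
# Hedged sketch, assuming SingularOrMany(T) simply means "T or a tuple of T";
# the real helper lives elsewhere in CALM.py and may differ.
from typing import Literal, Tuple, Union

def SingularOrMany(inner_type):
    return Union[inner_type, Tuple[inner_type, ...]]

HiddenPosition = SingularOrMany(Union[Literal['input'], Literal['output']])
# both 'output' and ('input', 'output') would satisfy this annotation
```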
@@ -223,7 +223,7 @@ def __init__(
),
anchor_extract_blocks_fn: Callable[[Module], List[Module]] = None,
anchor_transformer_blocks: Optional[List[Module]] = None,
- anchor_get_hidden_position: Union[Literal['input'], Literal['output']] = 'output',
+ anchor_hidden_position: SingularOrMany(Union[Literal['input'], Literal['output']]) = 'output',
pad_id: int = -1
):
super().__init__()
@@ -255,14 +255,21 @@ def __init__(
if not exists(anchor_transformer_blocks):
get_anchor_blocks_fn = x_transformer_blocks if isinstance(anchor_llm, TransformerWrapper) else anchor_extract_blocks_fn
anchor_transformer_blocks = get_anchor_blocks_fn(self.anchor_llm)
+ anchor_hidden_position = cast_tuple(anchor_hidden_position, len(anchor_transformer_blocks))
+ assert len(anchor_transformer_blocks) == len(anchor_hidden_position)

for params in augment_llms_params:
if exists(params.transformer_blocks):
continue

extract = default(params.extract_blocks_fn, x_transformer_blocks if isinstance(params.model, TransformerWrapper) else None)

assert exists(extract)

params.transformer_blocks = extract(params.model)
+ params.hidden_position = cast_tuple(params.hidden_position, len(params.transformer_blocks))

+ assert len(params.hidden_position) == len(params.transformer_blocks)

# extract all forward outputs from all transformer blocks
# for sanitizing the input (making sure transformer blocks are ordered by execution)
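The `cast_tuple` calls added in the hunk above normalize a single position value into one position per transformer block before the length assertions. A hedged sketch of what `cast_tuple` presumably does (the actual helper is defined elsewhere in the repo and may differ slightly):

```python
# Hedged sketch of the assumed cast_tuple helper: broadcast a lone value
# to a tuple of the requested length, leave tuples untouched.
def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else (t,) * length

# e.g. cast_tuple('output', 3)                      -> ('output', 'output', 'output')
#      cast_tuple(('input', 'output', 'output'), 3) -> unchanged
```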
@@ -320,7 +327,7 @@ def __init__(

anchor_to_augment_outputs = []

- for connection, augment_outputs in zip(self.connections, augments_outputs):
+ for connection, params, augment_outputs in zip(self.connections, augment_llms_params, augments_outputs):

one_num_augment_blocks = len(augment_outputs)

@@ -329,11 +336,11 @@
assert all([1 <= i <= len(anchor_outputs) for i in anchor_layer_indices]), 'you specified anchor llm layers outside of actual number of layers'
assert all([1 <= i <= len(augment_outputs) for i in augment_layer_indices]), 'you specified augment llm layers outside of actual number of layers'

- one_anchor_outputs = [anchor_outputs[i - 1] for i in anchor_layer_indices]
- one_augment_outputs = [augment_outputs[i - 1] for i in augment_layer_indices]
+ one_anchor_outputs_and_positions = [(anchor_outputs[i - 1], anchor_hidden_position[i - 1]) for i in anchor_layer_indices]
+ one_augment_outputs_and_positions = [(augment_outputs[i - 1], params.hidden_position[i - 1]) for i in augment_layer_indices]

anchor_to_augment_outputs.append(
- (one_anchor_outputs, one_augment_outputs)
+ (one_anchor_outputs_and_positions, one_augment_outputs_and_positions)
)

# function for getting output or input dimension
@@ -346,28 +353,28 @@ def get_hidden_dim(hook_output: Tuple[Module, Tensor, Tensor], position: Union[L

# instantiate cross attentions

- for (one_anchor_outputs, one_augment_outputs), params in zip(anchor_to_augment_outputs, augment_llms_params):
+ for (one_anchor_outputs_and_positions, one_augment_outputs_and_positions), params in zip(anchor_to_augment_outputs, augment_llms_params):

augment_llm, augment_position = params.model, params.hidden_position

# number of cross attention for one augmentation llm

- num_cross_attns = min(len(one_augment_outputs), len(one_anchor_outputs))
+ num_cross_attns = min(len(one_augment_outputs_and_positions), len(one_anchor_outputs_and_positions))

# get anchor dims

- anchor_dims = [get_hidden_dim(one_anchor_output, anchor_get_hidden_position) for one_anchor_output in one_anchor_outputs]
- augment_dims = [get_hidden_dim(one_augment_output, augment_position) for one_augment_output in one_augment_outputs]
+ anchor_dims = [get_hidden_dim(one_anchor_output, one_anchor_position) for one_anchor_output, one_anchor_position in one_anchor_outputs_and_positions]
+ augment_dims = [get_hidden_dim(one_augment_output, one_augment_position) for one_augment_output, one_augment_position in one_augment_outputs_and_positions]

# cross attentions for one augmentation llm

recorders = []
one_augment_llm_cross_attns = ModuleList([])

- for dim_anchor, dim_augment, _ in zip(anchor_dims, augment_dims, range(num_cross_attns)):
+ for dim_anchor, dim_augment, (_, anchor_position), (_, augment_position) in zip(anchor_dims, augment_dims, one_anchor_outputs_and_positions, one_augment_outputs_and_positions):
recorder = Recorder(forward_hook_get_hidden = augment_position)
recorders.append(recorder)
- one_augment_llm_cross_attns.append(CrossAttentionBlock(dim = dim_anchor, dim_context = dim_augment, recorder = recorder, forward_hook_get_hidden = anchor_get_hidden_position, **attn_kwargs))
+ one_augment_llm_cross_attns.append(CrossAttentionBlock(dim = dim_anchor, dim_context = dim_augment, recorder = recorder, forward_hook_get_hidden = anchor_position, **attn_kwargs))

# connect the two models

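Downstream, each (hook output, position) pair decides which hidden tensor a cross attention reads. The diff only shows the signature of `get_hidden_dim` in the hunk header above; the sketch below is a hedged reconstruction consistent with that signature, assuming the forward hook yields (module, input, output) and `position` selects which tensor's feature dimension to use.

```python
# Hedged sketch based only on the signature shown in the hunk header above;
# the body is an assumption about how the feature dimension is selected.
from typing import Literal, Tuple, Union

from torch import Tensor
from torch.nn import Module

def get_hidden_dim(
    hook_output: Tuple[Module, Tensor, Tensor],
    position: Union[Literal['input'], Literal['output']]
) -> int:
    _, inp, out = hook_output          # (module, input hidden, output hidden)
    hidden = inp if position == 'input' else out
    return hidden.shape[-1]            # feature / model dimension
```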
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'CALM-Pytorch',
packages = find_packages(exclude=[]),
- version = '0.1.0',
+ version = '0.1.1',
license='MIT',
description = 'CALM - Pytorch',
author = 'Phil Wang',
