diff --git a/CALM_pytorch/CALM.py b/CALM_pytorch/CALM.py
index 100ccd8..50c1f28 100644
--- a/CALM_pytorch/CALM.py
+++ b/CALM_pytorch/CALM.py
@@ -201,7 +201,7 @@ def forward(self, _, inp, out):
 @dataclass
 class AugmentParams:
     model: Module
-    hidden_position: Union[Literal['input'], Literal['output']] = 'output'
+    hidden_position: SingularOrMany(Union[Literal['input'], Literal['output']]) = 'output'
     transformer_blocks: Optional[List[Module]] = None
     extract_blocks_fn: Optional[Callable[[Module], List[Module]]] = None
     input_shape: Optional[Tuple[int, ...]] = None
@@ -223,7 +223,7 @@ def __init__(
         ),
         anchor_extract_blocks_fn: Callable[[Module], List[Module]] = None,
         anchor_transformer_blocks: Optional[List[Module]] = None,
-        anchor_get_hidden_position: Union[Literal['input'], Literal['output']] = 'output',
+        anchor_hidden_position: SingularOrMany(Union[Literal['input'], Literal['output']]) = 'output',
         pad_id: int = -1
     ):
         super().__init__()
@@ -255,14 +255,21 @@ def __init__(
         if not exists(anchor_transformer_blocks):
             get_anchor_blocks_fn = x_transformer_blocks if isinstance(anchor_llm, TransformerWrapper) else anchor_extract_blocks_fn
             anchor_transformer_blocks = get_anchor_blocks_fn(self.anchor_llm)
 
+        anchor_hidden_position = cast_tuple(anchor_hidden_position, len(anchor_transformer_blocks))
+        assert len(anchor_transformer_blocks) == len(anchor_hidden_position)
 
         for params in augment_llms_params:
             if exists(params.transformer_blocks):
                 continue
 
             extract = default(params.extract_blocks_fn, x_transformer_blocks if isinstance(params.model, TransformerWrapper) else None)
+            assert exists(extract)
+
             params.transformer_blocks = extract(params.model)
+            params.hidden_position = cast_tuple(params.hidden_position, len(params.transformer_blocks))
+
+            assert len(params.hidden_position) == len(params.transformer_blocks)
 
         # extract all forward outputs from all transformer blocks
         # for sanitizing the input (making sure transformer blocks are ordered by execution)
@@ -320,7 +327,7 @@ def __init__(
 
         anchor_to_augment_outputs = []
 
-        for connection, augment_outputs in zip(self.connections, augments_outputs):
+        for connection, params, augment_outputs in zip(self.connections, augment_llms_params, augments_outputs):
 
             one_num_augment_blocks = len(augment_outputs)
 
@@ -329,11 +336,11 @@ def __init__(
             assert all([1 <= i <= len(anchor_outputs) for i in anchor_layer_indices]), 'you specified anchor llm layers outside of actual number of layers'
             assert all([1 <= i <= len(augment_outputs) for i in augment_layer_indices]), 'you specified augment llm layers outside of actual number of layers'
 
-            one_anchor_outputs = [anchor_outputs[i - 1] for i in anchor_layer_indices]
-            one_augment_outputs = [augment_outputs[i - 1] for i in augment_layer_indices]
+            one_anchor_outputs_and_positions = [(anchor_outputs[i - 1], anchor_hidden_position[i - 1]) for i in anchor_layer_indices]
+            one_augment_outputs_and_positions = [(augment_outputs[i - 1], params.hidden_position[i - 1]) for i in augment_layer_indices]
 
             anchor_to_augment_outputs.append(
-                (one_anchor_outputs, one_augment_outputs)
+                (one_anchor_outputs_and_positions, one_augment_outputs_and_positions)
             )
 
         # function for getting output or input dimension
@@ -346,28 +353,28 @@ def get_hidden_dim(hook_output: Tuple[Module, Tensor, Tensor], position: Union[L
 
         # instantiate cross attentions
 
-        for (one_anchor_outputs, one_augment_outputs), params in zip(anchor_to_augment_outputs, augment_llms_params):
+        for (one_anchor_outputs_and_positions, one_augment_outputs_and_positions), params in zip(anchor_to_augment_outputs, augment_llms_params):
 
             augment_llm, augment_position = params.model, params.hidden_position
 
             # number of cross attention for one augmentation llm
 
-            num_cross_attns = min(len(one_augment_outputs), len(one_anchor_outputs))
+            num_cross_attns = min(len(one_augment_outputs_and_positions), len(one_anchor_outputs_and_positions))
 
             # get anchor dims
 
-            anchor_dims = [get_hidden_dim(one_anchor_output, anchor_get_hidden_position) for one_anchor_output in one_anchor_outputs]
-            augment_dims = [get_hidden_dim(one_augment_output, augment_position) for one_augment_output in one_augment_outputs]
+            anchor_dims = [get_hidden_dim(one_anchor_output, one_anchor_position) for one_anchor_output, one_anchor_position in one_anchor_outputs_and_positions]
+            augment_dims = [get_hidden_dim(one_augment_output, one_augment_position) for one_augment_output, one_augment_position in one_augment_outputs_and_positions]
 
             # cross attentions for one augmentation llm
 
             recorders = []
             one_augment_llm_cross_attns = ModuleList([])
 
-            for dim_anchor, dim_augment, _ in zip(anchor_dims, augment_dims, range(num_cross_attns)):
+            for dim_anchor, dim_augment, (_, anchor_position), (_, augment_position) in zip(anchor_dims, augment_dims, one_anchor_outputs_and_positions, one_augment_outputs_and_positions):
                 recorder = Recorder(forward_hook_get_hidden = augment_position)
                 recorders.append(recorder)
-                one_augment_llm_cross_attns.append(CrossAttentionBlock(dim = dim_anchor, dim_context = dim_augment, recorder = recorder, forward_hook_get_hidden = anchor_get_hidden_position, **attn_kwargs))
+                one_augment_llm_cross_attns.append(CrossAttentionBlock(dim = dim_anchor, dim_context = dim_augment, recorder = recorder, forward_hook_get_hidden = anchor_position, **attn_kwargs))
 
             # connect the two models
 
diff --git a/setup.py b/setup.py
index b61830e..ca2ef92 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'CALM-Pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.0',
+  version = '0.1.1',
   license='MIT',
   description = 'CALM - Pytorch',
   author = 'Phil Wang',
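In short, the patch lets `hidden_position` be either a single 'input'/'output' value or a tuple with one entry per extracted transformer block, for both the anchor LLM and each augment LLM; `cast_tuple` broadcasts the singular case and the new asserts check the lengths. Below is a minimal usage sketch of that reading, not part of the diff: the toy x-transformers models and the top-level keyword names not shown in the hunks (such as `augment_llms`) are assumed from the repository README and may differ.

    from x_transformers import TransformerWrapper, Decoder
    from CALM_pytorch import CALM, AugmentParams

    # toy anchor and augment transformers (sizes are illustrative)
    anchor_llm = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(dim = 512, depth = 12, heads = 8)
    )

    augment_llm = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(dim = 256, depth = 6, heads = 8)
    )

    calm = CALM(
        anchor_llm = anchor_llm,
        augment_llms = AugmentParams(
            model = augment_llm,
            hidden_position = 'input'        # a single value is broadcast to every extracted block;
                                             # a tuple with one 'input'/'output' entry per block also works
        ),
        anchor_hidden_position = 'output'    # likewise singular, or one entry per anchor block
    )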