Skip to content

Commit

Permalink
one more tiny step
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 21, 2024
1 parent 5bd82b7 commit 0df9dbb
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
19 changes: 16 additions & 3 deletions CALM_pytorch/CALM.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,13 @@ def __call__(self, *args):
hidden = self.get_output_fn(*args)
self.output = hidden.detach()

def pop_saved(self):
def release_saved_(self):
    # Drop the recorded hidden state so it cannot be reused and its
    # memory can be reclaimed. Trailing underscore marks an in-place
    # mutation of the recorder's state, following the file's convention.
    self.output = None

def pop_saved_(self):
    """Return the recorded hidden state, clearing the recorder as a side effect.

    Asserts that something was actually recorded before handing it back.
    """
    saved = self.output
    assert exists(saved)
    # delegate the clearing to release_saved_ so the reset logic lives in one place
    self.release_saved_()
    return saved

# cross attention wrapper class
Expand Down Expand Up @@ -204,7 +207,7 @@ def forward(self, *hook_args):
x = get_block_output_from_hook_outputs(self.forward_hook_get_hidden, *hook_args)

if self.release_recorder_output:
context = self.recorder.pop_saved()
context = self.recorder.pop_saved_()
else:
context = self.recorder.output

Expand Down Expand Up @@ -441,6 +444,16 @@ def load_state_dict(self, pkg, strict = False):
def parameters(self):
    # Expose only the cross attention blocks' parameters.
    # NOTE(review): presumably the wrapped anchor/augment LLMs are frozen and
    # only the cross attentions are trained — confirm against the trainer code.
    return self.cross_attns.parameters()

def set_release_recorder_output(self, release_recorder_output: bool):
    """Set, on every cross attention block, whether its recorder output is
    popped (and released) on forward rather than merely read."""
    cross_attn_blocks = (m for m in self.modules() if isinstance(m, CrossAttentionBlock))
    for block in cross_attn_blocks:
        block.release_recorder_output = release_recorder_output

def clear_recorded_augment_hiddens(self):
    """Release every recorded augment hidden state held by the cross
    attention blocks' recorders, freeing the associated memory."""
    for submodule in self.modules():
        if not isinstance(submodule, CrossAttentionBlock):
            continue
        submodule.recorder.release_saved_()

@beartype
def forward(
self,
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ loss.backward()

- [ ] take care of caching the augment hiddens when sampling. forget about anchor kv cache for now
- [x] logic for not releasing the saved output from recorder, for inference
- [ ] use a contextmanager for managing cross attention block state for popping the saved output from the recorder
- [x] managing cross attention block state for popping the saved output from the recorder
- [ ] move the augmentation forwards into one shared method, and carve out a sampling method for the anchor

- [ ] show an example of giving the LLM the ability to hear as well, using <a href="https://github.com/lucidrains/audiolm-pytorch">hubert or wav2vec</a> wrappers
- [ ] handle a wrapper or function that takes in the sequence and prompt length, and auto derives the inputs to CALM
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'CALM-Pytorch',
packages = find_packages(exclude=[]),
version = '0.1.6',
version = '0.1.7',
license='MIT',
description = 'CALM - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 0df9dbb

Please sign in to comment.