Skip to content

Commit

Permalink
make sure writeable memories can be returned and passed back in
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 9, 2024
1 parent 6716ce2 commit a45fbf8
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
15 changes: 13 additions & 2 deletions pi_zero_pytorch/pi_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,7 @@ def forward(
reward_tokens: Float['b d'] | None = None,
internal_state_tokens: Float['b ns d'] | None = None,
external_states: tuple[Float['b ...']] | None = None,
record_and_return_memory_tokens = False,
past_recurrent_memory_tokens: Float['b {self._nm} d'] | None = None,
return_actions_flow = False,
return_state_keys_values = False,
Expand Down Expand Up @@ -1039,6 +1040,9 @@ def forward(

# take care of previous memory tokens

assert self.has_recurrent_memories or not exists(past_recurrent_memory_tokens), 'you are asking for memories to be read, but `num_recurrent_memory_tokens` is 0'
assert self.has_recurrent_memories or not record_and_return_memory_tokens, 'you are asking for memories to be written, but `num_recurrent_memory_tokens` is 0'

if not exists(past_recurrent_memory_tokens):
past_recurrent_memory_tokens = visual_tokens.new_empty((batch, 0, self.dim))

Expand Down Expand Up @@ -1187,9 +1191,13 @@ def forward(
pred_actions_flow = self.actions_to_pred_flow(actions)

if return_actions_flow:
if not return_state_keys_values:

if not return_state_keys_values and not record_and_return_memory_tokens:
return pred_actions_flow

if not return_state_keys_values:
return pred_actions_flow, written_memory_tokens

return pred_actions_flow, state_cached_keys_values

flow_loss = self.zero
Expand Down Expand Up @@ -1223,7 +1231,10 @@ def forward(
flow_loss * self.flow_loss_weight
)

return total_loss, loss_breakdown
if not record_and_return_memory_tokens:
return total_loss, loss_breakdown

return total_loss, loss_breakdown, written_memory_tokens

# fun

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "pi-zero-pytorch"
version = "0.0.41"
version = "0.0.42"
description = "π0 in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit a45fbf8

Please sign in to comment.