Commit d344532
memory branch is conditional on hyperparameter
lucidrains committed Dec 9, 2024
1 parent cb9ae55 commit d344532
Showing 2 changed files with 13 additions and 12 deletions.
23 changes: 12 additions & 11 deletions pi_zero_pytorch/pi_zero.py
@@ -966,6 +966,9 @@ def forward(
 
         # take care of maybe recurrent memory tokens
 
+        assert self.has_recurrent_memories or not exists(past_recurrent_memory_tokens), 'you are asking for memories to be read, but `num_recurrent_memory_tokens` is 0'
+        assert self.has_recurrent_memories or not record_and_return_memory_tokens, 'you are asking for memories to be written, but `num_recurrent_memory_tokens` is 0'
+
         if not exists(past_recurrent_memory_tokens):
             past_recurrent_memory_tokens = actions.new_empty((batch, 0, self.dim))
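With this hunk, the two assertions run near the top of `forward`, before any memory handling, rather than after the state tokens are built. A minimal sketch of the guard pattern, under stated assumptions: `Model`, `dim`, and the return value are illustrative stand-ins, while `exists`, `has_recurrent_memories`, and the two flags mirror the diff.

import torch
from torch import nn

def exists(v):
    return v is not None

class Model(nn.Module):
    def __init__(self, dim, num_recurrent_memory_tokens = 0):
        super().__init__()
        self.dim = dim
        # the hyperparameter the memory branch is now conditional on
        self.has_recurrent_memories = num_recurrent_memory_tokens > 0

    def forward(
        self,
        actions,                                     # (batch, seq, dim)
        past_recurrent_memory_tokens = None,
        record_and_return_memory_tokens = False
    ):
        batch = actions.shape[0]

        # fail fast if the caller asks for memory reads or writes
        # on a model constructed without recurrent memories
        assert self.has_recurrent_memories or not exists(past_recurrent_memory_tokens), 'you are asking for memories to be read, but `num_recurrent_memory_tokens` is 0'
        assert self.has_recurrent_memories or not record_and_return_memory_tokens, 'you are asking for memories to be written, but `num_recurrent_memory_tokens` is 0'

        # default to zero-length memories so downstream concat logic stays uniform
        if not exists(past_recurrent_memory_tokens):
            past_recurrent_memory_tokens = actions.new_empty((batch, 0, self.dim))

        return past_recurrent_memory_tokens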
@@ -1050,11 +1053,6 @@ def forward(
         else:
             external_state_tokens = visual_tokens.new_empty((batch, 0, self.dim))
 
-        # take care of previous memory tokens
-
-        assert self.has_recurrent_memories or not exists(past_recurrent_memory_tokens), 'you are asking for memories to be read, but `num_recurrent_memory_tokens` is 0'
-        assert self.has_recurrent_memories or not record_and_return_memory_tokens, 'you are asking for memories to be written, but `num_recurrent_memory_tokens` is 0'
-
         # concat visual rep with language
 
         state_tokens, inverse_packed_states = pack_with_inverse([
@@ -1158,11 +1156,13 @@ def forward(
 
                 action_tokens = ff_ada_layerscale(action_tokens, time_cond)
 
-                memory_tokens, unpack_memory = pack_with_inverse(memory_tokens, 'b * d')
+                if self.has_recurrent_memories:
+                    memory_tokens, unpack_memory = pack_with_inverse(memory_tokens, 'b * d')
 
-                memory_tokens = memories_ff(memory_tokens) + memory_tokens
+                    memory_tokens = memories_ff(memory_tokens) + memory_tokens
 
-                memory_tokens = unpack_memory(memory_tokens)
+                    memory_tokens = unpack_memory(memory_tokens)
 
         else:
 
             for (
Expand Down Expand Up @@ -1192,11 +1192,12 @@ def forward(

action_tokens = ff_ada_layerscale(action_tokens, time_cond)

memory_tokens, unpack_memory = pack_with_inverse(memory_tokens, 'b * d')
if self.has_recurrent_memories:
memory_tokens, unpack_memory = pack_with_inverse(memory_tokens, 'b * d')

memory_tokens = memories_ff(memory_tokens) + memory_tokens
memory_tokens = memories_ff(memory_tokens) + memory_tokens

memory_tokens = unpack_memory(memory_tokens)
memory_tokens = unpack_memory(memory_tokens)

if not inferencing:
# unpack and unembed to predictions
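The same gating is applied in both branches of the forward pass above (the regular path and the cached-inference path). Below is a minimal sketch of the pack, residual feedforward, unpack pattern being gated; `pack_with_inverse` is reconstructed here as a thin wrapper over einops' pack/unpack (a common helper in this codebase, though its exact body may differ), and `memories_ff`, `dim`, and the memory tensors are illustrative.

import torch
from torch import nn
from einops import pack, unpack

def pack_with_inverse(tensors, pattern):
    # pack a list of tensors along the wildcard axis and return a closure
    # that splits a transformed result back into the original pieces
    packed, packed_shapes = pack(tensors, pattern)

    def inverse(out):
        return unpack(out, packed_shapes, pattern)

    return packed, inverse

dim = 512
has_recurrent_memories = True   # i.e. num_recurrent_memory_tokens > 0

memories_ff = nn.Sequential(
    nn.Linear(dim, dim * 4),
    nn.GELU(),
    nn.Linear(dim * 4, dim)
)

past_memories = torch.randn(2, 4, dim)   # memories read in from the previous step
next_memories = torch.randn(2, 4, dim)   # memories to be written this step

# with `num_recurrent_memory_tokens = 0` this branch is skipped entirely,
# so no feedforward runs over zero-length (or nonexistent) memory tokens
if has_recurrent_memories:
    memory_tokens, unpack_memory = pack_with_inverse([past_memories, next_memories], 'b * d')

    memory_tokens = memories_ff(memory_tokens) + memory_tokens

    past_memories, next_memories = unpack_memory(memory_tokens)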
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "pi-zero-pytorch"
-version = "0.0.44"
+version = "0.0.45"
 description = "π0 in Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
