diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py
index ec61804..bd91e39 100644
--- a/pi_zero_pytorch/pi_zero.py
+++ b/pi_zero_pytorch/pi_zero.py
@@ -966,6 +966,9 @@ def forward(
 
         # take care of maybe recurrent memory tokens
 
+        assert self.has_recurrent_memories or not exists(past_recurrent_memory_tokens), 'you are asking for memories to be read, but `num_recurrent_memory_tokens` is 0'
+        assert self.has_recurrent_memories or not record_and_return_memory_tokens, 'you are asking for memories to be written, but `num_recurrent_memory_tokens` is 0'
+
        if not exists(past_recurrent_memory_tokens):
             past_recurrent_memory_tokens = actions.new_empty((batch, 0, self.dim))
 
@@ -1050,11 +1053,6 @@ def forward(
         else:
             external_state_tokens = visual_tokens.new_empty((batch, 0, self.dim))
 
-        # take care of previous memory tokens
-
-        assert self.has_recurrent_memories or not exists(past_recurrent_memory_tokens), 'you are asking for memories to be read, but `num_recurrent_memory_tokens` is 0'
-        assert self.has_recurrent_memories or not record_and_return_memory_tokens, 'you are asking for memories to be written, but `num_recurrent_memory_tokens` is 0'
-
         # concat visual rep with language
 
         state_tokens, inverse_packed_states = pack_with_inverse([
@@ -1158,11 +1156,13 @@ def forward(
 
                 action_tokens = ff_ada_layerscale(action_tokens, time_cond)
 
-                memory_tokens, unpack_memory = pack_with_inverse(memory_tokens, 'b * d')
+                if self.has_recurrent_memories:
+                    memory_tokens, unpack_memory = pack_with_inverse(memory_tokens, 'b * d')
+
+                    memory_tokens = memories_ff(memory_tokens) + memory_tokens
 
-                memory_tokens = memories_ff(memory_tokens) + memory_tokens
+                    memory_tokens = unpack_memory(memory_tokens)
 
-                memory_tokens = unpack_memory(memory_tokens)
 
         else:
             for (
@@ -1192,11 +1192,12 @@ def forward(
 
                 action_tokens = ff_ada_layerscale(action_tokens, time_cond)
 
-                memory_tokens, unpack_memory = pack_with_inverse(memory_tokens, 'b * d')
+                if self.has_recurrent_memories:
+                    memory_tokens, unpack_memory = pack_with_inverse(memory_tokens, 'b * d')
 
-                memory_tokens = memories_ff(memory_tokens) + memory_tokens
+                    memory_tokens = memories_ff(memory_tokens) + memory_tokens
 
-                memory_tokens = unpack_memory(memory_tokens)
+                    memory_tokens = unpack_memory(memory_tokens)
 
         if not inferencing:
             # unpack and unembed to predictions
diff --git a/pyproject.toml b/pyproject.toml
index 4f220d4..30d32a4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "pi-zero-pytorch"
-version = "0.0.44"
+version = "0.0.45"
 description = "π0 in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }