From 29068b5480d274e6bb3b7036f16132c91aad3dbb Mon Sep 17 00:00:00 2001
From: Phil Wang <lucidrains@gmail.com>
Date: Mon, 2 Dec 2024 06:00:30 -0800
Subject: [PATCH] address https://github.com/lucidrains/pi-zero-pytorch/issues/8

---
 pi_zero_pytorch/pi_zero.py | 78 ++++++++++++--------------------------
 pyproject.toml             |  2 +-
 2 files changed, 25 insertions(+), 55 deletions(-)

diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py
index e6cf2b0..ee7e8f6 100644
--- a/pi_zero_pytorch/pi_zero.py
+++ b/pi_zero_pytorch/pi_zero.py
@@ -70,13 +70,8 @@ def create_pizero_attn_mask(
     prefix_causal_length,
-    mask: Bool['b n'],
-    internal_state_offset_and_len: tuple[int, int] | None = None
+    mask: Bool['b n']
 ):
-
-    state_offset, state_len = default(internal_state_offset_and_len, (0, 0))
-    state_left, state_right = state_offset, state_offset + state_len
-
     # the pi-zero attention is a triangular causal mask, but bidirectional attention for the actions at the very right hand side

     def mask_fn(batch_index, head_index, query_index, key_index):
@@ -88,12 +83,7 @@ def mask_fn(batch_index, head_index, query_index, key_index):
             query_index >= prefix_causal_length
         )

-        bidirectional_internal_state_mask = (
-            state_left <= key_index and key_index < state_right and
-            state_left <= query_index and query_index < state_right
-        )
-
-        return (key_mask and causal_mask) or bidirectional_action_mask or bidirectional_internal_state_mask
+        return (key_mask and causal_mask) or bidirectional_action_mask

     return mask_fn

@@ -359,7 +349,6 @@ def forward(
         actions,
         rotary_emb = None,
         mask: Bool['b n'] | None = None,
-        internal_state_offset_and_len: tuple[int, int] | None = None,
         actions_value_residual: Tensor | None = None,
         return_keys_values = False,
         flex_attn_fn: Callable | None = None
@@ -408,11 +397,6 @@ def forward(
         causal_mask[..., seq_len:, seq_len:] = False # actions have bidirectional attention, lining up with Transfusion paper

-        if exists(internal_state_offset_and_len):
-            offset, length = internal_state_offset_and_len
-            state_slice = slice(offset, offset + length)
-            causal_mask[..., state_slice, state_slice] = False
-
         sim = sim.masked_fill(causal_mask, max_neg_value(sim))

         attn = sim.softmax(dim = -1)

@@ -976,13 +960,28 @@ def forward(
         else:
             memory_tokens = actions.new_empty((batch, 0, self.dim))

-        # pack into [action registers] [actions] [memory tokens (write)]
+        # joint state + additional internal states

-        action_tokens, inverse_pack_action_registers = pack_with_inverse([action_register_tokens, action_tokens, memory_tokens], 'b * d')
+        joint_state_tokens = self.to_joint_state_tokens(joint_state)

-        action_with_registers_length = action_tokens.shape[-2]
+        # additional internal state tokens
+
+        if not exists(internal_state_tokens):
+            internal_state_tokens = joint_state_tokens.new_empty((batch, 0, self.dim_internal_state))

-        internal_state_offset_and_len = None
+        internal_state_tokens = self.to_internal_state_tokens(internal_state_tokens)
+
+        # pack into [action registers] [internal + joint states] [actions] [memory tokens (write)]
+
+        action_tokens, inverse_pack_action_registers = pack_with_inverse([
+            action_register_tokens,
+            joint_state_tokens,
+            internal_state_tokens,
+            action_tokens,
+            memory_tokens
+        ], 'b * d')
+
+        action_with_registers_length = action_tokens.shape[-2]

         if not inferencing:
             # language
@@ -1015,10 +1014,6 @@ def forward(

         visual_tokens = self.maybe_to_image_tokens(visual_tokens)

-        # joint state
-
-        joint_state_tokens = self.to_joint_state_tokens(joint_state)
-
         # maybe reward tokens

         if not exists(reward_tokens):
@@ -1029,13 +1024,6 @@ def forward(
         if self.training and random() < self.reward_tokens_dropout_prob:
             reward_tokens = reward_tokens[:, 0:0]

-        # additional internal state tokens
-
-        if not exists(internal_state_tokens):
-            internal_state_tokens = joint_state_tokens.new_empty((batch, 0, self.dim_internal_state))
-
-        internal_state_tokens = self.to_internal_state_tokens(internal_state_tokens)
-
         # additional external states

         if exists(external_states):
@@ -1050,21 +1038,6 @@ def forward(
         if not exists(past_recurrent_memory_tokens):
             past_recurrent_memory_tokens = visual_tokens.new_empty((batch, 0, self.dim))

-        # allow joint and internal states to have bidirectional attention
-
-        internal_state_len = joint_state_tokens.shape[-2] + internal_state_tokens.shape[-2]
-
-        internal_state_offset = (
-            external_state_tokens.shape[-2] +
-            visual_tokens.shape[-2] +
-            language_tokens.shape[-2]
-        )
-
-        internal_state_offset_and_len = (
-            internal_state_offset,
-            internal_state_len
-        )
-
         # concat visual rep with language

         state_tokens, inverse_packed_states = pack_with_inverse([
@@ -1072,8 +1045,6 @@ def forward(
             external_state_tokens,
             visual_tokens,
             language_tokens,
-            joint_state_tokens,
-            internal_state_tokens,
             reward_tokens
         ], 'b * d')

@@ -1094,7 +1065,7 @@ def forward(

         # rotary embeddings

-        seq = torch.cumsum(mask.float(), dim = -1)
+        seq = mask.float().cumsum(dim = -1)

         rotary_emb = self.rotary_emb(seq)
         rotary_emb = rearrange(rotary_emb, 'b n d -> b 1 n d')
@@ -1112,7 +1083,6 @@ def forward(
                 create_pizero_attn_mask(
                     prefix_length,
                     mask = mask,
-                    internal_state_offset_and_len = internal_state_offset_and_len
                 ),
                 Q_LEN = seq_len,
                 KV_LEN = seq_len,
@@ -1147,7 +1117,7 @@ def forward(

             action_tokens = attn_ada_rmsnorm(action_tokens, time_cond)

-            (state_attn_out, actions_attn_out), (state_keys, state_values, action_keys, action_values) = attn(state_tokens, action_tokens, rotary_emb = rotary_emb, flex_attn_fn = flex_attn_fn, actions_value_residual = actions_value_residual, mask = mask, internal_state_offset_and_len = internal_state_offset_and_len, return_keys_values = True)
+            (state_attn_out, actions_attn_out), (state_keys, state_values, action_keys, action_values) = attn(state_tokens, action_tokens, rotary_emb = rotary_emb, flex_attn_fn = flex_attn_fn, actions_value_residual = actions_value_residual, mask = mask, return_keys_values = True)

             state_cached_keys_values.append((state_keys, state_values))

@@ -1200,7 +1170,7 @@ def forward(

         tokens = self.final_norm_softclamp(tokens)

-        action_register_tokens, action_tokens, written_memory_tokens = inverse_pack_action_registers(action_tokens)
+        *_, action_tokens, written_memory_tokens = inverse_pack_action_registers(action_tokens)

         action_tokens = self.final_norm_softclamp(action_tokens)

diff --git a/pyproject.toml b/pyproject.toml
index 95d0ca5..f9faa15 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "pi-zero-pytorch"
-version = "0.0.34"
+version = "0.0.36"
 description = "π0 in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }