diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py index b89687c..16ee241 100644 --- a/pi_zero_pytorch/pi_zero.py +++ b/pi_zero_pytorch/pi_zero.py @@ -53,16 +53,17 @@ from torch.nn.attention.flex_attention import flex_attention, create_block_mask flex_attention = torch.compile(flex_attention) -def create_pizero_attn_mask(prefix_causal_length): +def create_pizero_attn_mask(prefix_causal_length, mask: Bool['b n']): # the pi-zero attention is a triangular causal mask, but bidirectional attention for the actions at the very right hand side - def inner(batch_index, head_index, query_index, key_index): + def mask_fn(batch_index, head_index, query_index, key_index): return ( + mask[batch_index, key_index] and # variable length states query_index >= key_index and # causal key_index >= prefix_causal_length # bidirectional ) - return inner + return mask_fn def softclamp_score_mod(value): def identity(score, b, h, q, k): @@ -84,12 +85,17 @@ def exists(v): def default(v, d): return v if exists(v) else d +# tensor helpers + def softclamp(t, value): if value <= 0.: return t return (t / value).tanh() * value +def max_neg_value(t): + return -torch.finfo(t.dtype).max + def pack_with_inverse(t, pattern): packed, packed_shape = pack(t, pattern) @@ -142,6 +148,7 @@ def forward_actions_with_cached_state( actions, cached_state_keys_values: tuple[Tensor, Tensor], rotary_emb = None, + mask: Bool['b n'] | None = None, actions_value_residual: Tensor | None = None, return_keys_values = False, flex_attn_fn: Callable | None = None @@ -159,7 +166,7 @@ def forward_actions_with_cached_state( k, v = tuple(torch.cat(tensors, dim = -2) for tensors in zip((mk, mv), (ak, av))) if exists(rotary_emb): - q = apply_rotary_emb(rotary_emb, q) + q = apply_rotary_emb(rotary_emb, q, freqs_seq_dim = -2) k = apply_rotary_emb(rotary_emb, k) elif exists(self.rotary_emb): @@ -176,6 +183,9 @@ def forward_actions_with_cached_state( sim = softclamp(sim, self.softclamp_value) + if exists(mask): + sim = einx.where('b j, b h i j, -> b h i j', mask, sim, max_neg_value(sim)) + attn = sim.softmax(dim = -1) out = einsum(attn, v, 'b h i j, b h j d -> b h i d') @@ -196,6 +206,7 @@ def forward( multimodal_seq, actions, rotary_emb = None, + mask: Bool['b n'] | None = None, actions_value_residual: Tensor | None = None, return_keys_values = False, flex_attn_fn: Callable | None = None @@ -238,9 +249,12 @@ def forward( causal_mask = torch.ones(sim.shape[-2:], dtype = torch.bool, device = device).triu(1) + if exists(mask): + causal_mask = einx.logical_or('b j, i j -> b 1 i j', ~mask, causal_mask) + causal_mask[..., seq_len:] = False # actions have bidirectional attention, lining up with Transfusion paper - sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + sim = sim.masked_fill(causal_mask, max_neg_value(sim)) attn = sim.softmax(dim = -1) @@ -320,6 +334,10 @@ def __init__( self.to_beta = LinearNoBias(dim_cond, dim) def forward(self, actions, cond): + + if cond.ndim == 2: + cond = rearrange(cond, 'b d -> b 1 d') + normed = self.norm(actions) gamma = self.to_gamma(cond) beta = self.to_beta(cond) @@ -340,6 +358,10 @@ def __init__( self.to_adaln_zero_gamma = adaln_zero_gamma_linear def forward(self, actions, cond): + + if cond.ndim == 2: + cond = rearrange(cond, 'b d -> b 1 d') + gamma = self.to_adaln_zero_gamma(cond) return actions * gamma.sigmoid() @@ -366,6 +388,7 @@ def __init__( num_action_register_tokens = 4, attn_kwargs: dict = dict(), ff_kwargs: dict = dict(), + lm_pad_id = -1, lm_loss_weight = 1., flow_loss_weight = 1., direction_loss_weight = 0., @@ -447,6 +470,10 @@ def __init__( self.state_to_logits = LinearNoBias(dim, num_tokens) self.actions_to_pred_flow = LinearNoBias(dim, dim_action_input) + # the language token id padding id, for fine-tuning as well as taking care of the masking on top of causal mask + + self.lm_pad_id = lm_pad_id + # loss related self.lm_loss_weight = lm_loss_weight @@ -465,6 +492,13 @@ def __init__( def device(self): return next(self.parameters()).device + @beartype + def pretrained_vlm_weights( + self, + weights: dict[str, Tensor] + ): + raise NotImplementedError + @torch.inference_mode() def sample( self, @@ -474,9 +508,10 @@ def sample( trajectory_length: int, reward_tokens = None, steps = 18, - batch_size = 1, show_pbar = True ): + batch_size = token_ids.shape[0] + was_training = self.training self.eval() @@ -562,7 +597,7 @@ def forward( flow = actions - noise padded_times = rearrange(times, 'b -> b 1 1') - actions = noise * (1. - padded_times) + padded_times * actions + actions = noise.lerp(actions, padded_times) # actions @@ -617,34 +652,13 @@ def forward( state_tokens, inverse_packed_states = pack_with_inverse([visual_tokens, language_tokens, joint_state_tokens, reward_tokens], 'b * d') - # prepare maybe flex attention - - flex_attn_fn = None - - if self.use_flex_attn and state_tokens.is_cuda: - - block_mask = None - - if not inferencing: - prefix_length = state_tokens.shape[-2] - seq_len = prefix_length + action_tokens.shape[-2] - - block_mask = create_block_mask( - create_pizero_attn_mask(prefix_length), - Q_LEN = seq_len, - KV_LEN = seq_len, - device = state_tokens.device - ) + # take care of masking for variable lengthed states, starting with the language tokens - score_mod_fn = softclamp_score_mod(self.attn_softclamp_value) + # which then leads to proper rotary embeddings - flex_attn_fn = partial( - flex_attention, - block_mask = block_mask, - score_mod = score_mod - ) + command_length = token_ids.shape[-1] - # prepare rotary embeddings + language_mask = token_ids != self.lm_pad_id action_with_registers_length = action_tokens.shape[-2] @@ -654,12 +668,39 @@ def forward( state_length = state_tokens.shape[-2] total_seq_length = action_with_registers_length + state_length + mask = F.pad(language_mask, (state_length - command_length - 1, 1 + action_with_registers_length), value = True) # assume fixed number of images for now, but address variable length modality states later # rotary embeddings - seq = torch.arange(total_seq_length, device = self.device) + seq = torch.cumsum(mask.float(), dim = -1) rotary_emb = self.rotary_emb(seq) + rotary_emb = rearrange(rotary_emb, 'b n d -> b 1 n d') + + # prepare maybe flex attention + + flex_attn_fn = None + + if not inferencing and self.use_flex_attn and state_tokens.is_cuda: + + prefix_length = state_tokens.shape[-2] + seq_len = prefix_length + action_tokens.shape[-2] + + block_mask = create_block_mask( + create_pizero_attn_mask(prefix_length, mask = mask), + Q_LEN = seq_len, + KV_LEN = seq_len, + device = state_tokens.device + ) + + score_mod_fn = softclamp_score_mod(self.attn_softclamp_value) + + flex_attn_fn = partial( + flex_attention, + block_mask = block_mask, + score_mod = score_mod + ) + # state keys and values for caching during inference cached_state_key_values_iter = iter(default(cached_state_keys_values, [])) @@ -680,7 +721,7 @@ def forward( action_tokens = attn_ada_rmsnorm(action_tokens, time_cond) - (state_attn_out, actions_attn_out), (state_keys, state_values, action_keys, action_values) = attn(state_tokens, action_tokens, rotary_emb = rotary_emb, flex_attn_fn = flex_attn_fn, actions_value_residual = actions_value_residual, return_keys_values = True) + (state_attn_out, actions_attn_out), (state_keys, state_values, action_keys, action_values) = attn(state_tokens, action_tokens, rotary_emb = rotary_emb, flex_attn_fn = flex_attn_fn, actions_value_residual = actions_value_residual, mask = mask, return_keys_values = True) state_cached_keys_values.append((state_keys, state_values)) @@ -708,7 +749,7 @@ def forward( action_tokens = attn_ada_rmsnorm(action_tokens, time_cond) - actions_attn_out, (state_keys, state_values, action_keys, action_values) = attn.forward_actions_with_cached_state(action_tokens, cached_state_keys_values = next(cached_state_key_values_iter), rotary_emb = rotary_emb, return_keys_values = True) + actions_attn_out, (state_keys, state_values, action_keys, action_values) = attn.forward_actions_with_cached_state(action_tokens, cached_state_keys_values = next(cached_state_key_values_iter), rotary_emb = rotary_emb, mask = mask, return_keys_values = True) state_cached_keys_values.append((state_keys, state_values)) @@ -768,7 +809,8 @@ def forward( language_loss = F.cross_entropy( rearrange(language_logits[:, :-1], 'b n l -> b l n'), - labels + labels, + ignore_index = self.lm_pad_id ) # loss breakdown diff --git a/pyproject.toml b/pyproject.toml index d6ca494..86fb6b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pi-zero-pytorch" -version = "0.0.12" +version = "0.0.14" description = "π0 in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }