
Commit

complete joint attention of memories at the park
lucidrains committed Dec 9, 2024
1 parent 4b08c06 commit 97ed762
Showing 1 changed file with 22 additions and 4 deletions.
pi_zero_pytorch/pi_zero.py: 22 additions & 4 deletions
@@ -223,7 +223,7 @@ def __init__(
         self.accept_memories = accept_memories
 
         self.mem_rmsnorm = nn.RMSNorm(dim) if accept_memories else None
-        self.to_mem_qkv = LinearNoBias(dim, 3.* dim_inner) if accept_memories else None
+        self.to_mem_qkv = LinearNoBias(dim, 3 * dim_inner) if accept_memories else None
         self.to_mem_out = LinearNoBias(dim_inner, dim) if accept_memories else None
 
         # action parameters
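
Note: the one-character fix matters because nn.Linear needs an integer number of output features, and 3.* dim_inner produces a float. Below is a minimal sketch of what the fused memory q/k/v projection amounts to, assuming LinearNoBias is a bias-free nn.Linear (all sizes are illustrative, not from the repo):

import torch
from torch import nn

# fused projection for the memory tokens: one bias-free linear produces
# queries, keys and values, later chunked into three along the feature dim
dim, heads, dim_head = 512, 8, 64
dim_inner = heads * dim_head

to_mem_qkv = nn.Linear(dim, 3 * dim_inner, bias = False)   # stand-in for LinearNoBias

memories = torch.randn(2, 4, dim)                           # (batch, num memory tokens, dim)
mem_q, mem_k, mem_v = to_mem_qkv(memories).chunk(3, dim = -1)

assert mem_q.shape == (2, 4, dim_inner)
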
@@ -391,11 +391,17 @@ def forward(
         assert not (self.accept_memories ^ exists(memories))
 
         if exists(memories):
-            memories = self.mem_rmsnorm(memories)
             memories, unpack_memories = pack_with_inverse(memories, 'b * d')
+            memories = self.mem_rmsnorm(memories)
             mqkv = self.to_mem_qkv(memories)
             mqkv_read, mqkv_write = unpack_memories(mqkv, 'b * d')
+
+            mqr, mkr, mvr, mqw, mkw, mvw = tuple(self.split_heads(t) for t in (*mqkv_read.chunk(3, dim = -1), *mqkv_write.chunk(3, dim = -1)))
+
+            k = torch.cat((mkr, k, mkw), dim = -2)
+            v = torch.cat((mvr, v, mvw), dim = -2)
+            q, attn_output_unpack_memories = pack_with_inverse((mqr, q, mqw), 'b h * d')
 
         # rotary embedding
 
         if exists(rotary_emb):
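
Note: in the attention forward pass, the read memories contribute keys and values in front of the main tokens and the write memories behind them, while the three query streams are packed into one tensor so the output can be split apart afterwards. A sketch of the same bookkeeping using einops.pack directly (the repo's pack_with_inverse helper returns an inverse function rather than a shape list; all sizes below are made up):

import torch
from einops import pack

# illustrative sizes: batch 2, 8 heads, dim_head 64, 2 read memories, 3 write memories, 10 main tokens
b, h, d = 2, 8, 64
mqr, mkr, mvr = (torch.randn(b, h, 2, d) for _ in range(3))   # read memory q / k / v
mqw, mkw, mvw = (torch.randn(b, h, 3, d) for _ in range(3))   # write memory q / k / v
q, k, v = (torch.randn(b, h, 10, d) for _ in range(3))        # main sequence q / k / v

# keys and values are ordered [read memories] [main tokens] [write memories]
k = torch.cat((mkr, k, mkw), dim = -2)
v = torch.cat((mvr, v, mvw), dim = -2)

# queries are packed in the same order, remembering the split sizes
q, packed_shapes = pack((mqr, q, mqw), 'b h * d')

assert q.shape == (b, h, 15, d) and k.shape == (b, h, 15, d)
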
@@ -436,6 +442,11 @@ def forward(
 
         out = out * gates
 
+        # split out memories
+
+        if self.accept_memories:
+            mem_read_out, out, mem_write_out = attn_output_unpack_memories(out)
+
         # merge attention heads
 
         out = self.merge_heads(out)
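
Note: after attention, the packed output is split back into the read-memory, main, and write-memory streams (in the commit, via the attn_output_unpack_memories inverse saved above), and only the main stream continues through merge_heads and the output projection. A self-contained sketch with einops.unpack, using stand-in tensors:

import torch
from einops import pack, unpack

b, h, d = 2, 8, 64
read_q_out, main_out, write_q_out = (torch.randn(b, h, n, d) for n in (2, 10, 3))

# stand-in for the attention output over the packed [read mem] [main] [write mem] queries
out, packed_shapes = pack((read_q_out, main_out, write_q_out), 'b h * d')

# the inverse operation recovers the three streams
mem_read_out, out, mem_write_out = unpack(out, packed_shapes, 'b h * d')

assert out.shape == (b, h, 10, d)
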
@@ -651,7 +662,7 @@ def __init__(
             is_first_block = i == 0
 
             layers.append(ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads, learned_value_action_residual_mix = not is_first_block, **attn_kwargs),
+                Attention(dim = dim, dim_head = dim_head, heads = heads, accept_memories = self.has_recurrent_memories, learned_value_action_residual_mix = not is_first_block, **attn_kwargs),
                 SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs),
                 SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, rmsnorm = False, **ff_kwargs),
                 SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs) if self.has_recurrent_memories else None
@@ -1013,6 +1024,8 @@ def forward(
 
         memory_tokens = (past_recurrent_memory_tokens, write_memory_tokens)
 
+        mem_length = past_recurrent_memory_tokens.shape[-2] + write_memory_tokens.shape[-2]
+
         # pack into [action registers] [internal + joint states] [actions]
 
         action_tokens, inverse_pack_action_registers = pack_with_inverse([
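
Note: mem_length is simply the number of past (read) memory tokens plus the number of freshly written ones, presumably used for downstream bookkeeping such as the mask padding added further below. A trivial check with made-up sizes:

import torch

# made-up sizes: 4 past memory tokens and 4 write memory tokens, model dim 512
past_recurrent_memory_tokens = torch.randn(1, 4, 512)
write_memory_tokens = torch.randn(1, 4, 512)

mem_length = past_recurrent_memory_tokens.shape[-2] + write_memory_tokens.shape[-2]
assert mem_length == 8
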
@@ -1098,6 +1111,10 @@ def forward(
 
         mask = F.pad(language_mask, (state_length - command_length, action_with_registers_length), value = True) # assume fixed number of images for now, but address variable length modality states later
 
+        # memory
+
+        mask = F.pad(mask, (past_recurrent_memory_tokens.shape[-2], write_memory_tokens.shape[-2]), value = True)
+
         # rotary embeddings
 
         seq = mask.float().cumsum(dim = -1)
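
Note: the attention mask is padded with True on both sides so the past (read) memory tokens on the left and the write memory tokens on the right stay attendable, mirroring the key/value concat order inside the attention layer. A short sketch with illustrative lengths:

import torch
import torch.nn.functional as F

# illustrative: a boolean mask over 10 main tokens, 4 past and 4 write memory tokens
mask = torch.ones(2, 10, dtype = torch.bool)
past_mem_length, write_mem_length = 4, 4

# memory tokens are always attendable, so pad with True on the left (read) and right (write)
mask = F.pad(mask, (past_mem_length, write_mem_length), value = True)
assert mask.shape == (2, 18)
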
@@ -1159,7 +1176,8 @@ def forward(
                 flex_attn_fn = flex_attn_fn,
                 actions_value_residual = actions_value_residual,
                 mask = mask,
-                return_keys_values = True
+                return_keys_values = True,
+                memories = memory_tokens
             )
 
             state_cached_keys_values.append((state_keys, state_values))
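
Note: the attention call now also receives memory_tokens, the (past, write) pair assembled earlier. As a hedged sketch, not the library's API: if the written memories are meant to be read back on the next call, the pair would typically be threaded across steps like this (all names below are illustrative):

import torch

dim, num_mem = 512, 4

# hypothetical recurrence: memories written at one step are read at the next
past_memories = torch.zeros(1, num_mem, dim)       # nothing to read on the first step

for step in range(3):
    write_memories = torch.randn(1, num_mem, dim)  # stand-in for the tokens the model writes this step
    memory_tokens = (past_memories, write_memories)

    # ... run the transformer with memory_tokens, which reads the past and fills the write tokens ...

    past_memories = write_memories                 # written memories become next step's past memories
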
