
Commit

final norm for the writeable memories
lucidrains committed Dec 9, 2024
1 parent a45fbf8 commit d35d035
Showing 2 changed files with 10 additions and 3 deletions.
11 changes: 9 additions & 2 deletions pi_zero_pytorch/pi_zero.py
@@ -616,6 +616,8 @@ def __init__(
         self.memory_tokens = nn.Parameter(torch.zeros(num_recurrent_memory_tokens, dim))
         nn.init.normal_(self.memory_tokens, std = 0.02)

+        self.final_norm_write_memories = nn.RMSNorm(dim) if self.has_recurrent_memories else None
+
         # attention and feedforward

         layers = []
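
The added module is registered only when recurrent memories are enabled, and left as None otherwise. A minimal sketch of that conditional-norm pattern, assuming PyTorch 2.4+ for nn.RMSNorm (dim and has_recurrent_memories below stand in for the model's own attributes):

import torch
from torch import nn

dim = 512
has_recurrent_memories = True

# register the norm only when memories exist, mirroring the added line above
final_norm_write_memories = nn.RMSNorm(dim) if has_recurrent_memories else None

memories = torch.randn(2, 4, dim)  # (batch, memory tokens, dim)
if final_norm_write_memories is not None:
    memories = final_norm_write_memories(memories)  # unit RMS per memory token
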
@@ -627,7 +629,7 @@ def __init__(
             layers.append(ModuleList([
                 Attention(dim = dim, dim_head = dim_head, heads = heads, num_recurrent_memory_tokens = num_recurrent_memory_tokens, learned_value_action_residual_mix = not is_first_block, **attn_kwargs),
                 SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs),
-                SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs)
+                SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs),
             ]))

             cond_layers.append(ModuleList([
@@ -1178,7 +1180,12 @@ def forward(

         action_tokens = self.final_norm_softclamp(action_tokens)

-        # projection
+        # writeable memories norm
+
+        if self.has_recurrent_memories:
+            written_memory_tokens = self.final_norm_write_memories(written_memory_tokens)
+
+        # final actions norm

         actions = self.final_actions_norm(action_tokens)

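The forward-pass change normalizes the written memory tokens before they are handed back out, so the recurrent state keeps a stable scale from one step to the next. A hedged toy sketch of that idea, not the library's API (step and the mean-pool update are illustrative stand-ins):

import torch
from torch import nn

dim, num_mem = 512, 4
final_norm_write_memories = nn.RMSNorm(dim)

def step(tokens, read_memories):
    # stand-in for one transformer pass: attention would read the
    # memories and write updated ones; a mean-pool update suffices here
    written_memory_tokens = read_memories + tokens.mean(dim = 1, keepdim = True)
    # the commit's change: norm the written memories before they recur
    return final_norm_write_memories(written_memory_tokens)

memories = torch.zeros(1, num_mem, dim)
for _ in range(3):
    tokens = torch.randn(1, 16, dim)
    memories = step(tokens, memories)

# each memory token now has unit RMS, i.e. L2 norm near sqrt(dim)
print(memories.norm(dim = -1).mean())
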
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "pi-zero-pytorch"
-version = "0.0.42"
+version = "0.0.43"
 description = "π0 in Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
