allow for additional robot internal states to be encoded and attended to
lucidrains committed Nov 19, 2024
1 parent 433fe3a commit fd5238a
Showing 2 changed files with 29 additions and 3 deletions.
30 changes: 28 additions & 2 deletions pi_zero_pytorch/pi_zero.py
@@ -37,6 +37,7 @@
 # na - seq of actions
 # nt - seq of text tokens
 # nv - seq of visual tokens
+# ns - seq of additional internal state tokens
 # d - dimension
 # da - action dimension
 # djs - joint state dimension
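For orientation: these one-letter names are the axes used in the file's shape annotations, so the new `ns` axis surfaces in forward() below as Float['b ns d']. A quick sketch of the tensor that annotation describes (the feature size here is made up):

import torch

batch, ns = 2, 4
dim_internal_state = 33  # hypothetical per-token feature size

# Float['b ns d']: (batch, seq of additional internal state tokens, feature dim)
internal_state_tokens = torch.randn(batch, ns, dim_internal_state)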
@@ -446,6 +447,7 @@ def __init__(
     depth = 12,
     dim_head = 64,
     heads = 8,
+    dim_internal_state: int | None = None,
     use_flex_attn = False,
     ff_expand_factor = 4.,
     attn_softclamp_value = 50.,
@@ -487,15 +489,24 @@ def __init__(

     self.token_emb = nn.Embedding(num_tokens, dim)
 
+    # internal states
+
     self.to_joint_state_tokens = nn.Linear(dim_joint_state, dim)
 
+    self.dim_internal_state = default(dim_internal_state, dim)
+    self.to_internal_state_tokens = nn.Linear(dim_internal_state, dim) if exists(dim_internal_state) else nn.Identity()
+
     # actions
 
     self.dim_action_input = dim_action_input
 
     self.action_register_tokens = nn.Parameter(torch.zeros(num_action_register_tokens, dim))
     nn.init.normal_(self.action_register_tokens, std = 0.02)
 
     self.to_action_tokens = nn.Linear(dim_action_input, dim)
 
     # time conditioning
 
     self.to_time_cond = nn.Sequential(
         RandomFourierEmbed(dim),
         nn.Linear(dim, dim_time_cond),
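The constructor wiring above has two modes. When dim_internal_state is given, incoming internal state features are linearly projected to the model width; when it is left as None, to_internal_state_tokens is an nn.Identity(), so tokens passed to forward() must already be of width dim (the fallback that self.dim_internal_state picks up via default). A minimal sketch of the projection path in isolation, outside the class:

import torch
from torch import nn

dim = 512
dim_internal_state = 33  # hypothetical, e.g. gripper force plus end-effector velocities

# mirrors self.to_internal_state_tokens when dim_internal_state is provided
to_internal_state_tokens = nn.Linear(dim_internal_state, dim)

internal_state = torch.randn(2, 4, dim_internal_state)  # (batch, ns, feature dim)
tokens = to_internal_state_tokens(internal_state)
assert tokens.shape == (2, 4, dim)  # at model width, ready to be packed with the other state tokens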
@@ -703,10 +714,11 @@ def forward(
     self,
     images: Float['b nv d'] | Float['b c h w'] | Float['b c f h w'], # vision
     token_ids: Int['b nt'], # language
-    joint_state: Float['b djs'], # joint state
+    joint_state: Float['b djs'],                              # joint state
     actions: Float['b na da'] | None = None, # action
     times: Float['b'] = None,
     reward_tokens: Float['b d'] | None = None,
+    internal_state_tokens: Float['b ns d'] | None = None,
     return_actions_flow = False,
     return_state_keys_values = False,
     cached_state_keys_values: list[tuple[Tensor, Tensor]] | None = None,
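Put together with the constructor change, a training call supplying the new keyword might look like the sketch below. It loosely follows the repository README (the π0 class and the first four constructor arguments); the constructor accepts more options than shown, the argument values are illustrative, and the exact return signature may vary by version:

import torch
from pi_zero_pytorch import π0

model = π0(
    num_tokens = 20_000,
    dim = 512,
    dim_action_input = 6,
    dim_joint_state = 12,
    dim_internal_state = 33   # the argument added in this commit
)

vision = torch.randn(1, 1024, 512)              # Float['b nv d'], pre-embedded visual tokens
commands = torch.randint(0, 20_000, (1, 1024))  # Int['b nt']
joint_state = torch.randn(1, 12)                # Float['b djs']
actions = torch.randn(1, 32, 6)                 # Float['b na da']
internal_state = torch.randn(1, 4, 33)          # Float['b ns d'], the new optional input

loss = model(
    vision,
    commands,
    joint_state,
    actions,
    internal_state_tokens = internal_state
)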
@@ -796,9 +808,23 @@ def forward(
     if self.training and random() < self.reward_tokens_dropout_prob:
         reward_tokens = reward_tokens[:, 0:0]
 
+    # additional internal state tokens
+
+    if not exists(internal_state_tokens):
+        internal_state_tokens = joint_state_tokens.new_empty((batch, 0, self.dim_internal_state))
+
+    internal_state_tokens = self.to_internal_state_tokens(internal_state_tokens)
+
     # concat visual rep with language
 
-    state_tokens, inverse_packed_states = pack_with_inverse([visual_tokens, language_tokens, joint_state_tokens, reward_tokens], 'b * d')
+    state_tokens, inverse_packed_states = pack_with_inverse([
+        visual_tokens,
+        language_tokens,
+        joint_state_tokens,
+        internal_state_tokens,
+        reward_tokens
+    ], 'b * d')
 
     # take care of masking for variable lengthed states, starting with the language tokens

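pack_with_inverse used above is, to all appearances, a thin wrapper over einops' pack/unpack that also hands back a function for splitting the fused sequence into its original groups. The behavior it relies on, including the zero-length tensor that new_empty produces when no internal state is passed, can be sketched with einops directly:

import torch
from einops import pack, unpack

b, d = 2, 512
visual   = torch.randn(b, 196, d)
language = torch.randn(b, 32, d)
joints   = torch.randn(b, 1, d)
internal = torch.randn(b, 0, d)   # zero-length, as when internal_state_tokens is None
reward   = torch.randn(b, 1, d)

# 'b * d' concatenates every group along the wildcard (sequence) axis
state_tokens, packed_shapes = pack([visual, language, joints, internal, reward], 'b * d')
assert state_tokens.shape == (b, 196 + 32 + 1 + 0 + 1, d)

# unpacking recovers the groups, the empty one included
groups = unpack(state_tokens, packed_shapes, 'b * d')
assert groups[3].shape == (b, 0, d)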
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "pi-zero-pytorch"
-version = "0.0.22"
+version = "0.0.23"
 description = "π0 in Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
