diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py index 69f993d..ec61804 100644 --- a/pi_zero_pytorch/pi_zero.py +++ b/pi_zero_pytorch/pi_zero.py @@ -53,9 +53,11 @@ # vision and language tokens are autoregressive causal mask, actions, interal states + joint bidirectional amongst own tokens, but still autoregressive with respect to other tokens # [state token groups] [action token groups] -> [autoregressive masking] [bidirectional] - # [external state] [visual tokens] [language tokens] [maybe reward / condition token] [action registers] [joint state + internal state] [actions] +# for an attempt to introduce recurrence, all tokens above can be flanked by read and write memory tokens +# [read memory tokens] [...] [write memory tokens] + # constants LinearNoBias = partial(nn.Linear, bias = False)