diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py index 16ee241..a0b32ed 100644 --- a/pi_zero_pytorch/pi_zero.py +++ b/pi_zero_pytorch/pi_zero.py @@ -29,6 +29,7 @@ # b - batch # n - sequence +# i, j - attention row and col sequence # na - seq of actions # nt - seq of text tokens # nv - seq of visual tokens