diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py
index 5a404a6..7982dc3 100644
--- a/pi_zero_pytorch/pi_zero.py
+++ b/pi_zero_pytorch/pi_zero.py
@@ -420,6 +420,7 @@ def sample(
         token_ids,
         joint_states,
         trajectory_length: int,
+        reward_tokens = None,
         steps = 18,
         batch_size = 1,
         show_pbar = True
@@ -442,6 +443,7 @@ def ode_fn(timestep, denoised_actions):
                 joint_states,
                 denoised_actions,
                 times = timestep,
+                reward_tokens = reward_tokens,
                 cached_state_keys_values = cached_state_kv,
                 return_actions_flow = True,
                 return_state_keys_values = True
diff --git a/pyproject.toml b/pyproject.toml
index eece4aa..658413c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "pi-zero-pytorch"
-version = "0.0.6"
+version = "0.0.8"
 description = "π0 in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }