diff --git a/q_transformer/q_transformer.py b/q_transformer/q_transformer.py
index 2337a1b..16dcf61 100644
--- a/q_transformer/q_transformer.py
+++ b/q_transformer/q_transformer.py
@@ -241,7 +241,10 @@ def q_learn(
         actions: TensorType['b', int],
         next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
         reward: TensorType['b', float],
-        done: TensorType['b', bool]
+        done: TensorType['b', bool],
+        *,
+        monte_carlo_return = None
+
     ) -> Tuple[Tensor, QIntermediates]:
         # 'next' stands for the very next time step (whether state, q, actions etc)
 
@@ -259,6 +262,8 @@ def q_learn(
 
         q_next = self.ema_model(next_states, instructions).amax(dim = -1)
 
+        q_next = q_next.clamp(min = default(monte_carlo_return, -1e4))
+
         # Bellman's equation. most important line of code, hopefully done correctly
 
         q_target = reward + not_terminal * (γ * q_next)
@@ -279,7 +284,10 @@ def n_step_q_learn(
         actions: TensorType['b', 't', int],
         next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
         rewards: TensorType['b', 't', float],
-        dones: TensorType['b', 't', bool]
+        dones: TensorType['b', 't', bool],
+        *,
+        monte_carlo_return = None
+
     ) -> Tuple[Tensor, QIntermediates]:
         """
         einops
@@ -323,6 +331,8 @@ def n_step_q_learn(
 
         q_next = self.ema_model(next_states, instructions).amax(dim = -1)
 
+        q_next = q_next.clamp(min = default(monte_carlo_return, -1e4))
+
         # prepare rewards and discount factors across timesteps
 
         rewards, _ = pack([rewards, q_next], 'b *')
@@ -345,16 +355,26 @@ def n_step_q_learn(
 
     def learn(
         self,
-        *args
+        *args,
+        min_reward: float = 0.,
+        monte_carlo_return: Optional[float] = None
     ):
         _, _, actions, *_ = args
 
+        # q-learn kwargs
+
+        q_learn_kwargs = dict(
+            monte_carlo_return = monte_carlo_return
+        )
+
         # main q-learning loss, whether single or n-step
 
         if self.n_step_q_learning:
-            td_loss, q_intermediates = self.n_step_q_learn(*args)
+            td_loss, q_intermediates = self.n_step_q_learn(*args, **q_learn_kwargs)
+            num_timesteps = actions.shape[1]
         else:
-            td_loss, q_intermediates = self.q_learn(*args)
+            td_loss, q_intermediates = self.q_learn(*args, **q_learn_kwargs)
+            num_timesteps = 1
 
         # calculate conservative regularization
         # section 4.2 in paper, eq 2
@@ -372,7 +392,7 @@ def learn(
         q_actions_not_taken = q_preds[~dataset_action_mask.bool()]
         q_actions_not_taken = rearrange(q_actions_not_taken, '(b t a) -> b t a', b = batch, a = num_non_dataset_actions)
 
-        conservative_reg_loss = (q_actions_not_taken ** 2).sum() / num_non_dataset_actions
+        conservative_reg_loss = ((q_actions_not_taken - (min_reward * num_timesteps)) ** 2).sum() / num_non_dataset_actions
 
         # total loss
 
@@ -383,7 +403,12 @@ def learn(
 
         return loss, loss_breakdown
 
-    def forward(self):
+    def forward(
+        self,
+        *,
+        monte_carlo_return: Optional[float] = None,
+        min_reward: float = 0.
+    ):
         step = self.step.item()
 
         replay_buffer_iter = cycle(self.dataloader)
@@ -401,7 +426,11 @@ def forward(self):
 
             with self.accelerator.autocast():
 
-                loss, (td_loss, conservative_reg_loss) = self.learn(*next(replay_buffer_iter))
+                loss, (td_loss, conservative_reg_loss) = self.learn(
+                    *next(replay_buffer_iter),
+                    min_reward = min_reward,
+                    monte_carlo_return = monte_carlo_return
+                )
 
                 self.accelerator.backward(loss)
 
diff --git a/setup.py b/setup.py
index c47a0d9..98db7ae 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'q-transformer',
   packages = find_packages(exclude=[]),
-  version = '0.0.11',
+  version = '0.0.12',
   license='MIT',
   description = 'Q-Transformer',
   author = 'Phil Wang',
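
For reference, a minimal usage sketch of the two new keyword arguments introduced above (not part of the patch itself). It assumes an already-constructed trainer instance named q_learner whose forward is the training loop modified in this diff; the numeric values are purely illustrative.

    # hypothetical usage sketch; q_learner is an assumed trainer instance
    # whose forward() is the loop patched in this diff

    q_learner(
        monte_carlo_return = 100.,  # optional lower bound clamped onto q_next before the Bellman target
        min_reward = 0.             # value (scaled by num_timesteps) that Q-values of non-dataset actions are regularized toward
    )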