Commit 5013ee7

allow for one to customize the min reward for the conservative reg loss, even though in the paper they set this to 0. also allow for one to pass in the monte carlo return (not even sure what this is, but following the pseudocode in the paper)
lucidrains committed Nov 28, 2023
1 parent 1e263e3 commit 5013ee7
Showing 2 changed files with 38 additions and 9 deletions.
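The new options surface as keyword arguments on the learner's learn / forward methods shown in the diff below. A minimal, hypothetical usage sketch follows; the QLearner name and its constructor arguments are assumptions for illustration, only min_reward and monte_carlo_return come from this commit:

q_learner = QLearner(model, dataset = replay_dataset)  # assumed constructor, not part of this commit

q_learner(
    min_reward = 0.,             # value the q-values of non-dataset actions are regressed toward (the paper uses 0)
    monte_carlo_return = None    # optional lower bound on the bootstrapped next value, per the paper's pseudocode
)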
45 changes: 37 additions & 8 deletions q_transformer/q_transformer.py
@@ -241,7 +241,10 @@ def q_learn(
actions: TensorType['b', int],
next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
reward: TensorType['b', float],
done: TensorType['b', bool]
done: TensorType['b', bool],
*,
monte_carlo_return = None

) -> Tuple[Tensor, QIntermediates]:
# 'next' stands for the very next time step (whether state, q, actions etc)

@@ -259,6 +262,8 @@ def q_learn(

q_next = self.ema_model(next_states, instructions).amax(dim = -1)

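# following the pseudocode in the paper, lower bound the bootstrapped next value by the monte carlo return, if one is given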
q_next = q_next.clamp(min = default(monte_carlo_return, -1e4))

# Bellman's equation. most important line of code, hopefully done correctly

q_target = reward + not_terminal * (γ * q_next)
@@ -279,7 +284,10 @@ def n_step_q_learn(
actions: TensorType['b', 't', int],
next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
rewards: TensorType['b', 't', float],
dones: TensorType['b', 't', bool]
dones: TensorType['b', 't', bool],
*,
monte_carlo_return = None

) -> Tuple[Tensor, QIntermediates]:
"""
einops
@@ -323,6 +331,8 @@ def n_step_q_learn(

q_next = self.ema_model(next_states, instructions).amax(dim = -1)

q_next = q_next.clamp(min = default(monte_carlo_return, -1e4))

# prepare rewards and discount factors across timesteps

rewards, _ = pack([rewards, q_next], 'b *')
Expand All @@ -345,16 +355,26 @@ def n_step_q_learn(

def learn(
self,
*args
*args,
min_reward: float = 0.,
monte_carlo_return: Optional[float] = None
):
_, _, actions, *_ = args

# q-learn kwargs

q_learn_kwargs = dict(
monte_carlo_return = monte_carlo_return
)

# main q-learning loss, whether single or n-step

if self.n_step_q_learning:
td_loss, q_intermediates = self.n_step_q_learn(*args)
td_loss, q_intermediates = self.n_step_q_learn(*args, **q_learn_kwargs)
num_timesteps = actions.shape[1]
else:
td_loss, q_intermediates = self.q_learn(*args)
td_loss, q_intermediates = self.q_learn(*args, **q_learn_kwargs)
num_timesteps = 1

# calculate conservative regularization
# section 4.2 in paper, eq 2
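# q-values of actions not in the dataset are regressed toward min_reward accumulated over the timesteps, rather than the paper's hardcoded 0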
@@ -372,7 +392,7 @@ def learn(
q_actions_not_taken = q_preds[~dataset_action_mask.bool()]
q_actions_not_taken = rearrange(q_actions_not_taken, '(b t a) -> b t a', b = batch, a = num_non_dataset_actions)

conservative_reg_loss = (q_actions_not_taken ** 2).sum() / num_non_dataset_actions
conservative_reg_loss = ((q_actions_not_taken - (min_reward * num_timesteps)) ** 2).sum() / num_non_dataset_actions

# total loss

@@ -383,7 +403,12 @@ def learn(

return loss, loss_breakdown

def forward(self):
def forward(
self,
*,
monte_carlo_return: Optional[float] = None,
min_reward: float = 0.
):
step = self.step.item()

replay_buffer_iter = cycle(self.dataloader)
Expand All @@ -401,7 +426,11 @@ def forward(self):

with self.accelerator.autocast():

loss, (td_loss, conservative_reg_loss) = self.learn(*next(replay_buffer_iter))
loss, (td_loss, conservative_reg_loss) = self.learn(
*next(replay_buffer_iter),
min_reward = min_reward,
monte_carlo_return = monte_carlo_return
)

self.accelerator.backward(loss)

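To make the two mechanisms above concrete, here is a small self-contained sketch on toy tensors; shapes, values, and variable names are illustrative only and not the repository's API:

import torch

batch, num_actions = 4, 8
discount = 0.98

rewards      = torch.randn(batch)
not_terminal = torch.ones(batch)
q_next       = torch.randn(batch)

monte_carlo_return = None

# 1. lower bound the bootstrapped next value by the monte carlo return, if given
#    (a very negative default leaves q_next untouched)

q_next = q_next.clamp(min = monte_carlo_return if monte_carlo_return is not None else -1e4)

# bellman target, as in q_learn above

q_target = rewards + not_terminal * (discount * q_next)

# 2. conservative regularization with a customizable minimum reward:
#    q-values of actions absent from the dataset are pushed toward
#    min_reward * num_timesteps instead of 0

min_reward, num_timesteps = 0., 1
num_non_dataset_actions = num_actions - 1

q_actions_not_taken = torch.randn(batch, num_timesteps, num_non_dataset_actions)

conservative_reg_loss = ((q_actions_not_taken - (min_reward * num_timesteps)) ** 2).sum() / num_non_dataset_actions

With monte_carlo_return left as None and min_reward at 0., both pieces reduce to the behavior before this commit.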
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'q-transformer',
packages = find_packages(exclude=[]),
version = '0.0.11',
version = '0.0.12',
license='MIT',
description = 'Q-Transformer',
author = 'Phil Wang',