
Commit e2772a2

complete the conservative regularization loss, align some hyperparameters

lucidrains committed Nov 28, 2023
1 parent 76d9c07 commit e2772a2
Showing 3 changed files with 63 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -14,9 +14,9 @@ Implementation of <a href="https://qtransformer.github.io/">Q-Transformer</a>, S
- [x] offer batchnorm-less variant of maxvit, as done in SOTA weather model metnet3
- [x] add optional deep dueling architecture
- [x] add n-step Q learning
- [x] build the conservative regularization

- [ ] build out main proposal in paper (autoregressive discrete actions until last action, reward given only on last)
- [ ] figure out the conservative regularization, read prior work
- [ ] add double Q + pessimism support
- [ ] improvise a cross attention variant instead of concatenating previous actions? (could have wrong intuition here)
- [ ] see if the main idea in this paper is applicable to language models <a href="https://github.com/lucidrains/llama-qrlhf">here</a>
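A note on the conservative regularization item checked off above: as it is implemented in the learn method further down in this diff (notation mine, not the paper's), the penalty pushes the Q-values of every action bin that was not taken in the dataset toward zero, averaged over the non-dataset bins, and is then combined with the TD loss:

\mathcal{L}_{\text{cons}} = \frac{1}{A - 1} \sum_{b,\,t} \; \sum_{a \neq a_{\text{data}}} Q_\theta(s_{b,t}, a)^2
\qquad
\mathcal{L} = \tfrac{1}{2}\,\mathcal{L}_{\text{TD}} + \tfrac{1}{2}\, w \, \mathcal{L}_{\text{cons}}

where A is the number of action bins and w is conservative_reg_loss_weight (the paper's recommended value of 1. is the default).
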
72 changes: 61 additions & 11 deletions q_transformer/q_transformer.py
@@ -28,11 +28,17 @@
# constants

QIntermediates = namedtuple('QIntermediates', [
'q_pred_all_actions',
'q_pred',
'q_next',
'q_target'
])

Losses = namedtuple('Losses', [
'td_loss',
'conservative_reg_loss'
])

# helpers

def exists(val):
@@ -89,12 +95,13 @@ def __init__(
shuffle = True
),
q_target_ema_kwargs: dict = dict(
beta = 0.999,
beta = 0.99,
update_after_step = 10,
update_every = 5
),
n_step_q_learning = False,
discount_factor_gamma = 0.99,
discount_factor_gamma = 0.98,
conservative_reg_loss_weight = 1., # they claim 1. is best in paper
optimizer_kwargs: dict = dict(),
checkpoint_folder = './checkpoints',
checkpoint_every = 1000,
@@ -106,6 +113,7 @@ def __init__(

self.discount_factor_gamma = discount_factor_gamma
self.n_step_q_learning = n_step_q_learning
self.conservative_reg_loss_weight = conservative_reg_loss_weight

self.register_buffer('discount_matrix', None, persistent = False)

@@ -243,7 +251,8 @@ def q_learn(
# first make a prediction with online q robotic transformer
# select out the q-values for the action that was taken

q_pred = batch_select_indices(self.model(states, instructions), actions)
q_pred_all_actions = self.model(states, instructions)
q_pred = batch_select_indices(q_pred_all_actions, actions)

# use an exponentially smoothed copy of model for the future q target. more stable than setting q_target to q_eval after each batch
# the max Q value is taken, as the optimal action is implicitly the one with the highest Q score
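
As a reference for this hunk, here is a minimal sketch of the single-step pattern the comment above describes. batch_select_indices is defined elsewhere in the repository, so the gather below is only my assumed stand-in for its semantics (pick the Q-value of the taken action per batch element), and the Bellman target line illustrates the comment rather than reproducing the repository's exact code; next_states, rewards and dones are assumed inputs.

import torch
import torch.nn.functional as F

def select_taken_action(q_values, actions):
    # q_values: (batch, num_action_bins), actions: (batch,) integer bins
    # assumed stand-in for batch_select_indices: pick the Q-value of the action that was taken
    return q_values.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

def single_step_td_loss(model, ema_model, states, instructions, actions, next_states, rewards, dones, gamma = 0.98):
    q_pred_all_actions = model(states, instructions)             # (batch, num_action_bins)
    q_pred = select_taken_action(q_pred_all_actions, actions)    # (batch,)

    with torch.no_grad():
        # the EMA copy supplies the target; the max over bins is the value of the implicit greedy action
        q_next = ema_model(next_states, instructions).amax(dim = -1)

    q_target = rewards + gamma * (1. - dones.float()) * q_next
    return F.mse_loss(q_pred, q_target)
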
@@ -281,6 +290,7 @@ def n_step_q_learn(
h - height
w - width
t - timesteps
a - action bins
q - q values
"""

@@ -307,7 +317,8 @@

actions = rearrange(actions, 'b t -> (b t)')

q_pred = batch_select_indices(self.model(states, repeated_instructions), actions)
q_pred_all_actions = self.model(states, repeated_instructions)
q_pred = batch_select_indices(q_pred_all_actions, actions)
q_pred = unpack_one(q_pred, time_ps, '*')

q_next = self.ema_model(next_states, instructions).amax(dim = -1)
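
The n-step path above packs time into the batch so every timestep is scored in one forward pass; the discount_matrix buffer registered in __init__ is presumably what folds the intermediate rewards into the n-step return. A hedged sketch of that idea follows; it is my construction (terminal flags are ignored for brevity), not necessarily the repository's exact formulation.

import torch

def build_discount_matrix(timesteps, gamma = 0.98):
    # upper-triangular matrix with gamma^(j - i) at (i, j) and zeros below the diagonal
    steps = torch.arange(timesteps)
    powers = steps[None, :] - steps[:, None]
    return torch.where(powers >= 0, gamma ** powers.clamp(min = 0), torch.zeros(()))

def n_step_q_target(rewards, q_next, gamma = 0.98):
    # rewards: (batch, timesteps), q_next: (batch,) bootstrap value from the EMA model at the final next state
    t = rewards.shape[-1]
    discount = build_discount_matrix(t, gamma)                      # (t, t)
    returns = rewards @ discount.t()                                # for each i: sum_j gamma^(j - i) * r_j
    bootstrap = (gamma ** (t - torch.arange(t))) * q_next[:, None]  # discounted bootstrap for each timestep
    return returns + bootstrap
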
@@ -326,7 +337,50 @@

loss = F.mse_loss(q_pred, q_target)

return loss, QIntermediates(q_pred, q_next, q_target)
# prepare q prediction

q_pred_all_actions = unpack_one(q_pred_all_actions, time_ps, '* a')

return loss, QIntermediates(q_pred_all_actions, q_pred, q_next, q_target)

def learn(
self,
*args
):
_, _, actions, *_ = args

# main q-learning loss, whether single or n-step

if self.n_step_q_learning:
td_loss, q_intermediates = self.n_step_q_learn(*args)
else:
td_loss, q_intermediates = self.q_learn(*args)

# calculate conservative regularization

batch = actions.shape[0]

q_preds = q_intermediates.q_pred_all_actions
num_action_bins = q_preds.shape[-1]
num_non_dataset_actions = num_action_bins - 1

actions = rearrange(actions, '... -> ... 1')

dataset_action_mask = torch.zeros_like(q_preds).scatter_(-1, actions, torch.ones_like(q_preds))

q_actions_not_taken = q_preds[~dataset_action_mask.bool()]
q_actions_not_taken = rearrange(q_actions_not_taken, '(b t a) -> b t a', b = batch, a = num_non_dataset_actions)

conservative_reg_loss = (q_actions_not_taken ** 2).sum() / num_non_dataset_actions

# total loss

loss = 0.5 * td_loss + \
0.5 * conservative_reg_loss * self.conservative_reg_loss_weight

loss_breakdown = Losses(td_loss, conservative_reg_loss)

return loss, loss_breakdown

def forward(self):
step = self.step.item()
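
To see the new regularizer in isolation, here is a self-contained sketch of the masking trick from learn above, run on dummy tensors; the shapes and names are placeholders of mine, while the scatter/mask/penalty lines mirror the diff.

import torch
from einops import rearrange

batch, timesteps, num_action_bins = 4, 3, 256

q_preds = torch.randn(batch, timesteps, num_action_bins)           # Q-values for every action bin
actions = torch.randint(0, num_action_bins, (batch, timesteps))    # bins actually taken in the dataset

num_non_dataset_actions = num_action_bins - 1

# one-hot mask marking the dataset action in every (batch, time) slot
dataset_action_mask = torch.zeros_like(q_preds).scatter_(-1, rearrange(actions, '... -> ... 1'), torch.ones_like(q_preds))

# keep only the bins that were NOT taken and push their Q-values toward zero
q_actions_not_taken = q_preds[~dataset_action_mask.bool()]
q_actions_not_taken = rearrange(q_actions_not_taken, '(b t a) -> b t a', b = batch, a = num_non_dataset_actions)

conservative_reg_loss = (q_actions_not_taken ** 2).sum() / num_non_dataset_actions   # sums over batch and time, averages over bins

With the default conservative_reg_loss_weight of 1., this enters the total objective as 0.5 * td_loss + 0.5 * conservative_reg_loss, as in the learn method above.
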
@@ -345,16 +399,12 @@ def forward(self):
# main q-learning algorithm

with self.accelerator.autocast():
data = next(replay_buffer_iter)

if self.n_step_q_learning:
loss, _ = self.n_step_q_learn(*data)
else:
loss, _ = self.q_learn(*data)
loss, (td_loss, conservative_reg_loss) = self.learn(*next(replay_buffer_iter))

self.accelerator.backward(loss)

self.print(f'loss: {loss.item():.3f}')
self.print(f'td loss: {td_loss.item():.3f}')

# take optimizer step

2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'q-transformer',
packages = find_packages(exclude=[]),
version = '0.0.9',
version = '0.0.10',
license='MIT',
description = 'Q-Transformer',
author = 'Phil Wang',
