discounting was not done correctly for n-step
lucidrains committed Nov 27, 2023
1 parent 82a99cd commit 52a5a4b
Showing 3 changed files with 18 additions and 10 deletions.
2 changes: 1 addition & 1 deletion q_transformer/mocks.py
@@ -52,6 +52,6 @@ def __getitem__(self, _):
        action = torch.randint(0, self.num_action_bins + 1, self.time_shape)
        next_state = torch.randn(3, *self.video_shape)
        reward = torch.randint(0, 2, self.time_shape)
-        done = torch.randint(0, 2, self.time_shape, dtype = torch.bool)
+        done = torch.zeros(self.time_shape, dtype = torch.bool)

        return instruction, state, action, next_state, reward, done
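
The substantive fix lives in q_transformer.py below. The new get_discount_matrix method builds an upper-triangular matrix whose (r, t) entry is γ^(t − r), so each row discounts rewards relative to its own starting timestep rather than relative to the start of the window. A standalone sketch of that construction follows; γ = 0.9 and timestep = 4 are arbitrary illustration values, not taken from the repo:

import torch

gamma = 0.9     # arbitrary discount factor, for illustration only
timestep = 4    # arbitrary window length, for illustration only

# same construction as get_discount_matrix in the diff below:
# entry (r, t) holds gamma ** (t - r) for t >= r, and zero below the diagonal
timestep_arange = torch.arange(timestep)
powers = timestep_arange[None, :] - timestep_arange[:, None]
discount_matrix = torch.triu(gamma ** powers)

# rows discount relative to their own timestep:
# [[1.,  0.9,  0.81,  0.729],
#  [0.,  1.,   0.9,   0.81 ],
#  [0.,  0.,   1.,    0.9  ],
#  [0.,  0.,   0.,    1.   ]]
print(discount_matrix)

Slicing a cached copy of this matrix is what lets n_step_q_learn weight each reward once per target row, instead of reusing a single window-level vector of γ powers.
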
24 changes: 16 additions & 8 deletions q_transformer/q_transformer.py
@@ -58,10 +58,6 @@ def cycle(dl):

# tensor helpers

-def reverse_cumsum(t):
-    cumsum = t.cumsum(dim = -1)
-    return t - cumsum + cumsum[..., -1:]
-
def batch_select_indices(t, indices):
    batch, single_index = t.shape[0], indices.ndim == 1
    batch_arange = torch.arange(batch, device = indices.device)
@@ -115,6 +111,8 @@ def __init__(
        self.discount_factor_gamma = discount_factor_gamma
        self.n_step_q_learning = n_step_q_learning

+        self.register_buffer('discount_matrix', None, persistent = False)
+
        # online (evaluated) Q model

        self.model = model
@@ -259,6 +257,17 @@ def q_learn(

        return loss, QIntermediates(q_pred, q_next, q_target)

+    def get_discount_matrix(self, timestep):
+        if exists(self.discount_matrix) and self.discount_matrix.shape[-1] >= timestep:
+            return self.discount_matrix[:timestep, :timestep]
+
+        timestep_arange = torch.arange(timestep, device = self.accelerator.device)
+        powers = (timestep_arange[None, :] - timestep_arange[:, None])
+        discount_matrix = torch.triu(self.discount_factor_gamma ** powers)
+
+        self.register_buffer('discount_matrix', discount_matrix, persistent = False)
+        return self.discount_matrix
+
    def n_step_q_learn(
        self,
        instructions: Tuple[str],
@@ -312,12 +321,11 @@ def n_step_q_learn(

        rewards, _ = pack([rewards, q_next], 'b *')

-        powers = torch.arange(num_timesteps + 1, device = device)
-        γ = γ ** powers
+        γ = self.get_discount_matrix(num_timesteps + 1)[:-1, :]

-        # Bellman's equation
+        # account for discounting using the discount matrix

-        q_target = reverse_cumsum(not_terminal * rewards * γ)[..., :-1]
+        q_target = einsum('b t, r t -> b r', not_terminal * rewards, γ)

        # have transformer learn to predict above Q target

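As a quick sanity check of the change above: multiplying the packed rewards by the sliced discount matrix via einsum should reproduce the explicit n-step return Σ_{t ≥ r} γ^(t−r) · reward_t for every target row r, which the removed reverse_cumsum path (it discounted by γ^t measured from the window start) did not. A minimal sketch, assuming the einsum in the diff is torch.einsum and dropping the not_terminal masking; γ, the window length, and the reward values are made up for illustration:

import torch
from torch import einsum   # assumption: the einsum used in the diff is torch.einsum

gamma = 0.9          # arbitrary discount factor, for illustration only
num_timesteps = 3    # arbitrary window length, for illustration only

# toy (batch = 1) rewards with a bootstrapped value packed on the end,
# mirroring `rewards, _ = pack([rewards, q_next], 'b *')` in n_step_q_learn
rewards = torch.tensor([[1., 2., 3., 4.]])   # last entry stands in for q_next

# discount matrix as built by get_discount_matrix(num_timesteps + 1), minus its last row
t = torch.arange(num_timesteps + 1)
discount = torch.triu(gamma ** (t[None, :] - t[:, None]))[:-1, :]

# corrected target: q_target[b, r] = sum over t >= r of gamma^(t - r) * rewards[b, t]
q_target = einsum('b t, r t -> b r', rewards, discount)

# explicit per-row n-step return for comparison
expected = torch.tensor([[
    sum(gamma ** (j - r) * rewards[0, j].item() for j in range(r, num_timesteps + 1))
    for r in range(num_timesteps)
]])

assert torch.allclose(q_target, expected)

# the removed path amounted to q_target[b, r] = sum over t >= r of gamma^t * rewards[b, t],
# leaving every row r > 0 scaled down by an extra factor of gamma^r, which is the bug this commit fixes
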
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'q-transformer',
  packages = find_packages(exclude=[]),
-  version = '0.0.7',
+  version = '0.0.8',
  license='MIT',
  description = 'Q-Transformer',
  author = 'Phil Wang',
