Add C51 for DQN and DQfD #115
@@ -17,14 +17,14 @@
import gym
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import wandb

from algorithms.common.abstract.agent import AbstractAgent
from algorithms.common.buffer.priortized_replay_buffer import PrioritizedReplayBuffer
from algorithms.common.buffer.replay_buffer import NStepTransitionBuffer
import algorithms.common.helper_functions as common_utils
import algorithms.dqn.utils as dqn_utils

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

@@ -153,30 +153,29 @@ def _get_dqn_loss(
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return element-wise dqn loss and Q-values."""
        states, actions, rewards, next_states, dones = experiences[:5]

        q_values = self.dqn(states)
        next_q_values = self.dqn(next_states)
        next_target_q_values = self.dqn_target(next_states)

        curr_q_value = q_values.gather(1, actions.long().unsqueeze(1))
        next_q_value = next_target_q_values.gather(  # Double DQN
            1, next_q_values.argmax(1).unsqueeze(1)
        batch_size = self.hyper_params["BATCH_SIZE"]

        proj_dist = dqn_utils.projection_distribution(
            model=self.dqn,
            target_model=self.dqn_target,
            batch_size=batch_size,
            next_states=next_states,
            rewards=rewards,
            dones=dones,
            v_min=self.dqn.v_min,
            v_max=self.dqn.v_max,
            atom_size=self.dqn.atom_size,
            gamma=gamma,
        )

        # G_t = r + gamma * v(s_{t+1}) if state != Terminal
        #     = r otherwise
        masks = 1 - dones
        target = rewards + gamma * next_q_value * masks
        target = target.to(device)
        dist, q_values = self.dqn.get_dist_q(states)
        log_p = torch.log(dist[range(batch_size), actions.long()])

        # calculate dq loss
        dq_loss_element_wise = F.smooth_l1_loss(
            curr_q_value, target.detach(), reduction="none"
        )
        dq_loss_element_wise = -(proj_dist * log_p).sum(1)

        return dq_loss_element_wise, q_values

    def update_model(self) -> Tuple[torch.Tensor, ...]:
    def update_model(self) -> torch.Tensor:
        """Train the model after each episode."""
        # 1 step loss
        experiences_1 = self.memory.sample(self.beta)

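This hunk replaces the scalar smooth-L1 TD loss with the categorical (C51) loss: a cross-entropy between the projected target distribution and the predicted distribution of the taken action. For reference, the standard C51 per-sample loss (Bellemare et al., 2017), not text taken from this PR, is

L(\theta) = -\sum_{i=0}^{N_{\mathrm{atoms}}-1} m_i \log p_i(s_t, a_t; \theta)

where m corresponds to proj_dist (the target distribution projected onto the fixed support) and p(s_t, a_t; \theta) to dist[range(batch_size), actions] from get_dist_q, so the element-wise loss is exactly -(proj_dist * log_p).sum(1) above.
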
@@ -216,7 +215,7 @@ def update_model(self) -> Tuple[torch.Tensor, ...]:
        common_utils.soft_update(self.dqn, self.dqn_target, tau)

        # update priorities in PER
        loss_for_prior = dq_loss_element_wise.detach().cpu().numpy().squeeze()
        loss_for_prior = dq_loss_element_wise.detach().cpu().numpy()
        new_priorities = loss_for_prior + self.hyper_params["PER_EPS"]
        self.memory.update_priorities(indices, new_priorities)

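The .squeeze() can be dropped here because the new element-wise loss, -(proj_dist * log_p).sum(1), is already a 1-D tensor of shape (batch_size,), whereas the previous smooth-L1 loss had shape (batch_size, 1).
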
@@ -252,7 +251,7 @@ def write_log(self, i: int, loss: np.ndarray, score: int):
        """Write log about loss and score"""
        print(
            "[INFO] episode %d, episode step: %d, total step: %d, total score: %d\n"
            "epsilon: %f, loss: %f, avg q-value: %f at %s\n"
            "epsilon: %f, loss: %f, avg_q_value: %f at %s\n"
Comment: Why no f-string?
Reply: I didn't use an f-string because f-strings are not compatible with Python versions lower than 3.6.
            % (
                i,
                self.episode_step,

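For context on the f-string discussion above, both formatting styles produce the same output; only the %-style runs on Python versions below 3.6. The values below are placeholders for illustration, not values from this PR:

epsilon, loss = 0.05, 0.1  # placeholder values

# %-formatting: works on Python 2.7 and all 3.x versions
print("epsilon: %f, loss: %f" % (epsilon, loss))

# f-string: same output, but requires Python >= 3.6
print(f"epsilon: {epsilon:f}, loss: {loss:f}")
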
@@ -6,7 +6,7 @@
"""

from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict
from typing import Any, Callable, DefaultDict, Dict, Tuple

import torch
import torch.nn as nn

@@ -51,8 +51,8 @@ def __init__(
        self.value_layer.bias.data.uniform_(-init_w, init_w)

    def _forward_dueling(self, x: torch.Tensor) -> torch.Tensor:
        adv_x = self.advantage_hidden_layer(x)
        val_x = self.value_hidden_layer(x)
        adv_x = self.hidden_activation(self.advantage_hidden_layer(x))
        val_x = self.hidden_activation(self.value_hidden_layer(x))

        advantage = self.advantage_layer(adv_x)
        value = self.value_layer(val_x)

@@ -70,6 +70,73 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class CategoricalDuelingMLP(MLP):
    """Multilayer perceptron with dueling construction."""

    def __init__(
        self,
        input_size: int,
        action_size: int,
        hidden_sizes: list,
        atom_size: int = 51,
        v_min: int = -10,
        v_max: int = 10,
        hidden_activation: Callable = F.relu,
Comment: Suggested change
Reply: I will open an issue for it. Thanks.
Reply: I opened an issue: #117
        init_w: float = 3e-3,
    ):
        """Initialization."""
        super(CategoricalDuelingMLP, self).__init__(
            input_size=input_size,
            output_size=action_size,
            hidden_sizes=hidden_sizes,
            hidden_activation=hidden_activation,
            use_output_layer=False,
        )
        in_size = hidden_sizes[-1]
        self.action_size = action_size
        self.atom_size = atom_size
        self.output_size = action_size * atom_size
        self.v_min, self.v_max = v_min, v_max

        # set advantage layer
        self.advantage_hidden_layer = nn.Linear(in_size, in_size)
        self.advantage_layer = nn.Linear(in_size, self.output_size)
        self.advantage_layer.weight.data.uniform_(-init_w, init_w)
        self.advantage_layer.bias.data.uniform_(-init_w, init_w)

        # set value layer
        self.value_hidden_layer = nn.Linear(in_size, in_size)
        self.value_layer = nn.Linear(in_size, self.atom_size)
        self.value_layer.weight.data.uniform_(-init_w, init_w)
        self.value_layer.bias.data.uniform_(-init_w, init_w)

    def get_dist_q(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get distribution for atoms."""
        action_size, atom_size = self.action_size, self.atom_size

        x = super(CategoricalDuelingMLP, self).forward(x)
        adv_x = self.hidden_activation(self.advantage_hidden_layer(x))
        val_x = self.hidden_activation(self.value_hidden_layer(x))

        advantage = self.advantage_layer(adv_x).view(-1, action_size, atom_size)
        value = self.value_layer(val_x).view(-1, 1, atom_size)
        advantage_mean = advantage.mean(dim=1, keepdim=True)

        q_atoms = value + advantage - advantage_mean
        dist = F.softmax(q_atoms, dim=2)

        support = torch.linspace(self.v_min, self.v_max, self.atom_size).to(device)
        q = torch.sum(dist * support, dim=2)

        return dist, q

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward method implementation."""
        _, q = self.get_dist_q(x)

        return q


class LateFusionDuelingMLP(DuelingMLP):
    """DuelingMLP with late input fusion.

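A minimal usage sketch of the new network. The import path matches the one used by the new utils module below; the sizes are illustrative assumptions, not values from this PR:

import torch

from algorithms.dqn.networks import CategoricalDuelingMLP

# toy sizes: 4-dim state, 2 actions, 51 atoms on [-10, 10] (the defaults)
net = CategoricalDuelingMLP(input_size=4, action_size=2, hidden_sizes=[128, 128])

states = torch.zeros(32, 4)        # a batch of 32 states
dist, q = net.get_dist_q(states)   # dist: (32, 2, 51) softmax over atoms, q: (32, 2)
q_only = net(states)               # forward() returns just the Q-values
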
@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
"""Utility functions for DQN.

This module has DQN util functions.

- Author: Curt Park
- Contact: [email protected]
"""

import torch

from algorithms.dqn.networks import CategoricalDuelingMLP

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def projection_distribution(
    model: CategoricalDuelingMLP,
    target_model: CategoricalDuelingMLP,
    batch_size: int,
    next_states: torch.Tensor,
    rewards: torch.Tensor,
    dones: torch.Tensor,
    v_min: int,
    v_max: int,
    atom_size: int,
    gamma: float,
) -> torch.Tensor:
    """Get projection distribution (C51) to calculate dqn loss."""
    support = torch.linspace(v_min, v_max, atom_size).to(device)
    delta_z = float(v_max - v_min) / (atom_size - 1)

    with torch.no_grad():
        next_actions = model.get_dist_q(next_states)[1].argmax(1)
        next_dist = target_model.get_dist_q(next_states)[0]
        next_dist = next_dist[range(batch_size), next_actions]

        t_z = rewards + (1 - dones) * gamma * support
        t_z = t_z.clamp(min=v_min, max=v_max)
        b = (t_z - v_min) / delta_z
        l = b.floor().long()  # noqa: E741
        u = b.ceil().long()

        # Fix disappearing probability mass when l = b = u (b is int)
        # taken from https://github.com/Kaixhin/Rainbow
        l[(u > 0) * (l == u)] -= 1  # noqa: E741
        u[(l < (atom_size - 1)) * (l == u)] += 1  # noqa: E741

        offset = (
            torch.linspace(0, (batch_size - 1) * atom_size, batch_size)
            .long()
            .unsqueeze(1)
            .expand(batch_size, atom_size)
            .to(device)
        )

        proj_dist = torch.zeros(next_dist.size(), device=device)
        proj_dist.view(-1).index_add_(
            0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
        )
        proj_dist.view(-1).index_add_(
            0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1)
        )

    return proj_dist
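This function follows the standard C51 projection step (Bellemare et al., 2017); the formula below is that reference, not text quoted from the PR. Each target atom \hat{T}z_j = r + (1 - done)\gamma z_j is clamped to [v_min, v_max], mapped to the fractional index b = (\hat{T}z_j - v_min)/\Delta z, and its probability mass is split between the neighbouring support indices l = floor(b) and u = ceil(b):

(\Phi \hat{T} Z)_i = \sum_j \Big[ 1 - \frac{\big| [\hat{T} z_j]_{v_{\min}}^{v_{\max}} - z_i \big|}{\Delta z} \Big]_0^1 \, p_j(s_{t+1}, a^*)

The two index_add_ calls accumulate exactly these split masses into proj_dist, with a^* chosen by the online model (Double-DQN style) and p_j taken from target_model.
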
Comment: Create a key instead of using str?
Reply: I am looking for a way not to use strings as keys, like an enum in C. It would be better if I didn't have to make a new .py file just to define keys.
Reply: I made an issue: #116
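One possible direction for the key/str question above, as a sketch only; the HyperParamKey enum and the example values are assumptions, not code from this PR or issue #116 (only the BATCH_SIZE and PER_EPS names appear in the diff):

from enum import Enum


class HyperParamKey(Enum):
    """Hyper-parameter keys as an enum, so a typo fails loudly instead of silently."""

    BATCH_SIZE = "BATCH_SIZE"
    PER_EPS = "PER_EPS"


# hypothetical replacement for hyper_params["BATCH_SIZE"] lookups
hyper_params = {HyperParamKey.BATCH_SIZE: 64, HyperParamKey.PER_EPS: 1e-6}
batch_size = hyper_params[HyperParamKey.BATCH_SIZE]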