-
Notifications
You must be signed in to change notification settings - Fork 661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Parallel Q-Networks algorithm (PQN) #472
Merged
Merged
Changes from 8 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
6e89a5c
wip
roger-creus e9f158f
wip
roger-creus 21226d7
Run pre-commit run --all-files
roger-creus fba8994
each env has its own probability of random.random() < epsilon -- i.e.…
roger-creus b6a12c1
added pqn with lstm. works on breakout :)
roger-creus bfe1205
Use buffer for values instead of re-computing in Q(lambda). Speed up …
roger-creus cc66288
Merge remote-tracking branch 'upstream/master'
roger-creus 10334fa
default rollout hyperparameters in PQN are now equal to PPO. same for…
roger-creus c8acf70
Added coumentation. Fixed links to the docs for PQN
roger-creus d9c300b
Finished PQN documentation by adding commands to run the benchmarks a…
roger-creus File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,247 @@ | ||
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy | ||
import os | ||
import random | ||
import time | ||
from dataclasses import dataclass | ||
|
||
import gymnasium as gym | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.optim as optim | ||
import tyro | ||
from torch.utils.tensorboard import SummaryWriter | ||
|
||
|
||
@dataclass | ||
class Args: | ||
exp_name: str = os.path.basename(__file__)[: -len(".py")] | ||
"""the name of this experiment""" | ||
seed: int = 1 | ||
"""seed of the experiment""" | ||
torch_deterministic: bool = True | ||
"""if toggled, `torch.backends.cudnn.deterministic=False`""" | ||
cuda: bool = True | ||
"""if toggled, cuda will be enabled by default""" | ||
track: bool = False | ||
"""if toggled, this experiment will be tracked with Weights and Biases""" | ||
wandb_project_name: str = "cleanRL" | ||
"""the wandb's project name""" | ||
wandb_entity: str = None | ||
"""the entity (team) of wandb's project""" | ||
capture_video: bool = False | ||
"""whether to capture videos of the agent performances (check out `videos` folder)""" | ||
|
||
# Algorithm specific arguments | ||
env_id: str = "CartPole-v1" | ||
"""the id of the environment""" | ||
total_timesteps: int = 500000 | ||
"""total timesteps of the experiments""" | ||
learning_rate: float = 2.5e-4 | ||
"""the learning rate of the optimizer""" | ||
num_envs: int = 4 | ||
"""the number of parallel game environments""" | ||
num_steps: int = 128 | ||
"""the number of steps to run for each environment per update""" | ||
num_minibatches: int = 4 | ||
"""the number of mini-batches""" | ||
update_epochs: int = 4 | ||
"""the K epochs to update the policy""" | ||
anneal_lr: bool = True | ||
"""Toggle learning rate annealing""" | ||
gamma: float = 0.99 | ||
"""the discount factor gamma""" | ||
start_e: float = 1 | ||
"""the starting epsilon for exploration""" | ||
end_e: float = 0.05 | ||
"""the ending epsilon for exploration""" | ||
exploration_fraction: float = 0.5 | ||
"""the fraction of `total_timesteps` it takes from start_e to end_e""" | ||
max_grad_norm: float = 10.0 | ||
"""the maximum norm for the gradient clipping""" | ||
q_lambda: float = 0.65 | ||
"""the lambda for Q(lambda)""" | ||
|
||
|
||
def make_env(env_id, seed, idx, capture_video, run_name): | ||
def thunk(): | ||
if capture_video and idx == 0: | ||
env = gym.make(env_id, render_mode="rgb_array") | ||
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") | ||
else: | ||
env = gym.make(env_id) | ||
env = gym.wrappers.RecordEpisodeStatistics(env) | ||
env.action_space.seed(seed) | ||
|
||
return env | ||
|
||
return thunk | ||
|
||
|
||
def layer_init(layer, std=np.sqrt(2), bias_const=0.0): | ||
torch.nn.init.orthogonal_(layer.weight, std) | ||
torch.nn.init.constant_(layer.bias, bias_const) | ||
return layer | ||
|
||
|
||
# ALGO LOGIC: initialize agent here: | ||
class QNetwork(nn.Module): | ||
def __init__(self, env): | ||
super().__init__() | ||
|
||
self.network = nn.Sequential( | ||
layer_init(nn.Linear(np.array(env.single_observation_space.shape).prod(), 120)), | ||
nn.LayerNorm(120), | ||
nn.ReLU(), | ||
layer_init(nn.Linear(120, 84)), | ||
nn.LayerNorm(84), | ||
nn.ReLU(), | ||
layer_init(nn.Linear(84, env.single_action_space.n)), | ||
) | ||
|
||
def forward(self, x): | ||
return self.network(x) | ||
|
||
|
||
def linear_schedule(start_e: float, end_e: float, duration: int, t: int): | ||
slope = (end_e - start_e) / duration | ||
return max(slope * t + start_e, end_e) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = tyro.cli(Args) | ||
args.batch_size = int(args.num_envs * args.num_steps) | ||
args.minibatch_size = int(args.batch_size // args.num_minibatches) | ||
args.num_iterations = args.total_timesteps // args.batch_size | ||
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" | ||
if args.track: | ||
import wandb | ||
|
||
wandb.init( | ||
project=args.wandb_project_name, | ||
entity=args.wandb_entity, | ||
sync_tensorboard=True, | ||
config=vars(args), | ||
name=run_name, | ||
monitor_gym=True, | ||
save_code=True, | ||
) | ||
writer = SummaryWriter(f"runs/{run_name}") | ||
writer.add_text( | ||
"hyperparameters", | ||
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), | ||
) | ||
|
||
# TRY NOT TO MODIFY: seeding | ||
random.seed(args.seed) | ||
np.random.seed(args.seed) | ||
torch.manual_seed(args.seed) | ||
torch.backends.cudnn.deterministic = args.torch_deterministic | ||
|
||
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") | ||
|
||
# env setup | ||
envs = gym.vector.SyncVectorEnv( | ||
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] | ||
) | ||
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" | ||
|
||
# agent setup | ||
q_network = QNetwork(envs).to(device) | ||
optimizer = optim.RAdam(q_network.parameters(), lr=args.learning_rate) | ||
|
||
# storage setup | ||
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) | ||
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) | ||
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) | ||
dones = torch.zeros((args.num_steps, args.num_envs)).to(device) | ||
values = torch.zeros((args.num_steps, args.num_envs)).to(device) | ||
|
||
# TRY NOT TO MODIFY: start the game | ||
global_step = 0 | ||
start_time = time.time() | ||
next_obs, _ = envs.reset(seed=args.seed) | ||
next_obs = torch.Tensor(next_obs).to(device) | ||
next_done = torch.zeros(args.num_envs).to(device) | ||
|
||
for iteration in range(1, args.num_iterations + 1): | ||
# Annealing the rate if instructed to do so. | ||
if args.anneal_lr: | ||
frac = 1.0 - (iteration - 1.0) / args.num_iterations | ||
lrnow = frac * args.learning_rate | ||
optimizer.param_groups[0]["lr"] = lrnow | ||
|
||
for step in range(0, args.num_steps): | ||
global_step += args.num_envs | ||
obs[step] = next_obs | ||
dones[step] = next_done | ||
|
||
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step) | ||
random_actions = torch.randint(0, envs.single_action_space.n, (args.num_envs,)).to(device) | ||
with torch.no_grad(): | ||
q_values = q_network(next_obs) | ||
max_actions = torch.argmax(q_values, dim=1) | ||
values[step] = q_values[torch.arange(args.num_envs), max_actions].flatten() | ||
|
||
explore = torch.rand((args.num_envs,)).to(device) < epsilon | ||
action = torch.where(explore, random_actions, max_actions) | ||
actions[step] = action | ||
|
||
# TRY NOT TO MODIFY: execute the game and log data. | ||
next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy()) | ||
next_done = np.logical_or(terminations, truncations) | ||
rewards[step] = torch.tensor(reward).to(device).view(-1) | ||
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device) | ||
|
||
if "final_info" in infos: | ||
for info in infos["final_info"]: | ||
if info and "episode" in info: | ||
print(f"global_step={global_step}, episodic_return={info['episode']['r']}") | ||
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) | ||
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) | ||
|
||
# Compute Q(lambda) targets | ||
with torch.no_grad(): | ||
returns = torch.zeros_like(rewards).to(device) | ||
for t in reversed(range(args.num_steps)): | ||
if t == args.num_steps - 1: | ||
next_value, _ = torch.max(q_network(next_obs), dim=-1) | ||
nextnonterminal = 1.0 - next_done | ||
returns[t] = rewards[t] + args.gamma * next_value * nextnonterminal | ||
else: | ||
nextnonterminal = 1.0 - dones[t + 1] | ||
next_value = values[t + 1] | ||
returns[t] = rewards[t] + args.gamma * ( | ||
args.q_lambda * returns[t + 1] + (1 - args.q_lambda) * next_value * nextnonterminal | ||
) | ||
|
||
# flatten the batch | ||
b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) | ||
b_actions = actions.reshape((-1,) + envs.single_action_space.shape) | ||
b_returns = returns.reshape(-1) | ||
|
||
# Optimizing the Q-network | ||
b_inds = np.arange(args.batch_size) | ||
for epoch in range(args.update_epochs): | ||
np.random.shuffle(b_inds) | ||
for start in range(0, args.batch_size, args.minibatch_size): | ||
end = start + args.minibatch_size | ||
mb_inds = b_inds[start:end] | ||
|
||
old_val = q_network(b_obs[mb_inds]).gather(1, b_actions[mb_inds].unsqueeze(-1).long()).squeeze() | ||
loss = F.mse_loss(b_returns[mb_inds], old_val) | ||
|
||
# optimize the model | ||
optimizer.zero_grad() | ||
loss.backward() | ||
nn.utils.clip_grad_norm_(q_network.parameters(), args.max_grad_norm) | ||
optimizer.step() | ||
|
||
writer.add_scalar("losses/td_loss", loss, global_step) | ||
writer.add_scalar("losses/q_values", old_val.mean().item(), global_step) | ||
print("SPS:", int(global_step / (time.time() - start_time))) | ||
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) | ||
|
||
envs.close() | ||
writer.close() |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs to be updated