Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Parallel Q-Networks algorithm (PQN) #472

Merged
merged 10 commits into from
Nov 14, 2024
32 changes: 32 additions & 0 deletions benchmark/pqn.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
poetry install
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \
--command "poetry run python cleanrl/pqn.py --no_cuda --track" \
--num-seeds 3 \
--workers 9 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 10 \
--slurm-template-path benchmark/cleanrl_1gpu.slurm_template

poetry install -E envpool
poetry run python -m cleanrl_utils.benchmark \
--env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 Pong-v5 MsPacman-v5 \
--command "poetry run python cleanrl/pqn_atari_envpool.py --track" \
--num-seeds 3 \
--workers 9 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 10 \
--slurm-template-path benchmark/cleanrl_1gpu.slurm_template

poetry install -E envpool
poetry run python -m cleanrl_utils.benchmark \
--env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 Pong-v5 MsPacman-v5 \
--command "poetry run python cleanrl/pqn_atari_envpool_lstm.py --track" \
--num-seeds 3 \
--workers 9 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 10 \
--slurm-template-path benchmark/cleanrl_1gpu.slurm_template
50 changes: 50 additions & 0 deletions benchmark/pqn_plot.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@

python -m openrlbenchmark.rlops \
--filters '?we=rogercreus&wpn=cleanRL&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
'pqn?tag=pr-472&cl=CleanRL PQN' \
--env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \
--no-check-empty-runs \
--pc.ncols 3 \
--pc.ncols-legend 2 \
--output-filename benchmark/cleanrl/pqn \
--scan-history

python -m openrlbenchmark.rlops \
--filters '?we=rogercreus&wpn=cleanRL&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
'pqn_atari_envpool?tag=pr-472&cl=CleanRL PQN' \
--env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 Pong-v5 MsPacman-v5 \
--no-check-empty-runs \
--pc.ncols 3 \
--pc.ncols-legend 3 \
--rliable \
--rc.score_normalization_method maxmin \
--rc.normalized_score_threshold 1.0 \
--rc.sample_efficiency_plots \
--rc.sample_efficiency_and_walltime_efficiency_method Median \
--rc.performance_profile_plots \
--rc.aggregate_metrics_plots \
--rc.sample_efficiency_num_bootstrap_reps 10 \
--rc.performance_profile_num_bootstrap_reps 10 \
--rc.interval_estimates_num_bootstrap_reps 10 \
--output-filename static/0compare \
--scan-history

python -m openrlbenchmark.rlops \
--filters '?we=rogercreus&wpn=cleanRL&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
'pqn_atari_envpool_lstm?tag=pr-472&cl=CleanRL PQN' \
--env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 MsPacman-v5 \
--no-check-empty-runs \
--pc.ncols 3 \
--pc.ncols-legend 3 \
--rliable \
--rc.score_normalization_method maxmin \
--rc.normalized_score_threshold 1.0 \
--rc.sample_efficiency_plots \
--rc.sample_efficiency_and_walltime_efficiency_method Median \
--rc.performance_profile_plots \
--rc.aggregate_metrics_plots \
--rc.sample_efficiency_num_bootstrap_reps 10 \
--rc.performance_profile_num_bootstrap_reps 10 \
--rc.interval_estimates_num_bootstrap_reps 10 \
--output-filename static/0compare \
--scan-history
247 changes: 247 additions & 0 deletions cleanrl/pqn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/pqn/#pqnpy
import os
import random
import time
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from torch.utils.tensorboard import SummaryWriter


@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
track: bool = False
"""if toggled, this experiment will be tracked with Weights and Biases"""
wandb_project_name: str = "cleanRL"
"""the wandb's project name"""
wandb_entity: str = None
"""the entity (team) of wandb's project"""
capture_video: bool = False
"""whether to capture videos of the agent performances (check out `videos` folder)"""

# Algorithm specific arguments
env_id: str = "CartPole-v1"
"""the id of the environment"""
total_timesteps: int = 500000
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
num_envs: int = 4
"""the number of parallel game environments"""
num_steps: int = 128
"""the number of steps to run for each environment per update"""
num_minibatches: int = 4
"""the number of mini-batches"""
update_epochs: int = 4
"""the K epochs to update the policy"""
anneal_lr: bool = True
"""Toggle learning rate annealing"""
gamma: float = 0.99
"""the discount factor gamma"""
start_e: float = 1
"""the starting epsilon for exploration"""
end_e: float = 0.05
"""the ending epsilon for exploration"""
exploration_fraction: float = 0.5
"""the fraction of `total_timesteps` it takes from start_e to end_e"""
max_grad_norm: float = 10.0
"""the maximum norm for the gradient clipping"""
q_lambda: float = 0.65
"""the lambda for Q(lambda)"""


def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
if capture_video and idx == 0:
env = gym.make(env_id, render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
env.action_space.seed(seed)

return env

return thunk


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
torch.nn.init.orthogonal_(layer.weight, std)
torch.nn.init.constant_(layer.bias, bias_const)
return layer


# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
def __init__(self, env):
super().__init__()

self.network = nn.Sequential(
layer_init(nn.Linear(np.array(env.single_observation_space.shape).prod(), 120)),
nn.LayerNorm(120),
nn.ReLU(),
layer_init(nn.Linear(120, 84)),
nn.LayerNorm(84),
nn.ReLU(),
layer_init(nn.Linear(84, env.single_action_space.n)),
)

def forward(self, x):
return self.network(x)


def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
slope = (end_e - start_e) / duration
return max(slope * t + start_e, end_e)


if __name__ == "__main__":
args = tyro.cli(Args)
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
args.num_iterations = args.total_timesteps // args.batch_size
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
import wandb

wandb.init(
project=args.wandb_project_name,
entity=args.wandb_entity,
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)

# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic

device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

# env setup
envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
)
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

# agent setup
q_network = QNetwork(envs).to(device)
optimizer = optim.RAdam(q_network.parameters(), lr=args.learning_rate)

# storage setup
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
values = torch.zeros((args.num_steps, args.num_envs)).to(device)

# TRY NOT TO MODIFY: start the game
global_step = 0
start_time = time.time()
next_obs, _ = envs.reset(seed=args.seed)
next_obs = torch.Tensor(next_obs).to(device)
next_done = torch.zeros(args.num_envs).to(device)

for iteration in range(1, args.num_iterations + 1):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
frac = 1.0 - (iteration - 1.0) / args.num_iterations
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow

for step in range(0, args.num_steps):
global_step += args.num_envs
obs[step] = next_obs
dones[step] = next_done

epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
random_actions = torch.randint(0, envs.single_action_space.n, (args.num_envs,)).to(device)
with torch.no_grad():
q_values = q_network(next_obs)
max_actions = torch.argmax(q_values, dim=1)
values[step] = q_values[torch.arange(args.num_envs), max_actions].flatten()

explore = torch.rand((args.num_envs,)).to(device) < epsilon
action = torch.where(explore, random_actions, max_actions)
actions[step] = action

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
next_done = np.logical_or(terminations, truncations)
rewards[step] = torch.tensor(reward).to(device).view(-1)
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)

if "final_info" in infos:
for info in infos["final_info"]:
if info and "episode" in info:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

# Compute Q(lambda) targets
with torch.no_grad():
returns = torch.zeros_like(rewards).to(device)
for t in reversed(range(args.num_steps)):
if t == args.num_steps - 1:
next_value, _ = torch.max(q_network(next_obs), dim=-1)
nextnonterminal = 1.0 - next_done
returns[t] = rewards[t] + args.gamma * next_value * nextnonterminal
else:
nextnonterminal = 1.0 - dones[t + 1]
next_value = values[t + 1]
returns[t] = rewards[t] + args.gamma * (
args.q_lambda * returns[t + 1] + (1 - args.q_lambda) * next_value * nextnonterminal
)

# flatten the batch
b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
b_returns = returns.reshape(-1)

# Optimizing the Q-network
b_inds = np.arange(args.batch_size)
for epoch in range(args.update_epochs):
np.random.shuffle(b_inds)
for start in range(0, args.batch_size, args.minibatch_size):
end = start + args.minibatch_size
mb_inds = b_inds[start:end]

old_val = q_network(b_obs[mb_inds]).gather(1, b_actions[mb_inds].unsqueeze(-1).long()).squeeze()
loss = F.mse_loss(b_returns[mb_inds], old_val)

# optimize the model
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(q_network.parameters(), args.max_grad_norm)
optimizer.step()

writer.add_scalar("losses/td_loss", loss, global_step)
writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

envs.close()
writer.close()
Loading
Loading