diff --git a/benchmark/pqn.sh b/benchmark/pqn.sh
new file mode 100644
index 00000000..1aed60bf
--- /dev/null
+++ b/benchmark/pqn.sh
@@ -0,0 +1,32 @@
+poetry install
+OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
+ --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \
+ --command "poetry run python cleanrl/pqn.py --no_cuda --track" \
+ --num-seeds 3 \
+ --workers 9 \
+ --slurm-gpus-per-task 1 \
+ --slurm-ntasks 1 \
+ --slurm-total-cpus 10 \
+ --slurm-template-path benchmark/cleanrl_1gpu.slurm_template
+
+poetry install -E envpool
+poetry run python -m cleanrl_utils.benchmark \
+ --env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 Pong-v5 MsPacman-v5 \
+ --command "poetry run python cleanrl/pqn_atari_envpool.py --track" \
+ --num-seeds 3 \
+ --workers 9 \
+ --slurm-gpus-per-task 1 \
+ --slurm-ntasks 1 \
+ --slurm-total-cpus 10 \
+ --slurm-template-path benchmark/cleanrl_1gpu.slurm_template
+
+poetry install -E envpool
+poetry run python -m cleanrl_utils.benchmark \
+ --env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 Pong-v5 MsPacman-v5 \
+ --command "poetry run python cleanrl/pqn_atari_envpool_lstm.py --track" \
+ --num-seeds 3 \
+ --workers 9 \
+ --slurm-gpus-per-task 1 \
+ --slurm-ntasks 1 \
+ --slurm-total-cpus 10 \
+ --slurm-template-path benchmark/cleanrl_1gpu.slurm_template
\ No newline at end of file
diff --git a/benchmark/pqn_plot.sh b/benchmark/pqn_plot.sh
new file mode 100644
index 00000000..1c9237d2
--- /dev/null
+++ b/benchmark/pqn_plot.sh
@@ -0,0 +1,50 @@
+
+python -m openrlbenchmark.rlops \
+ --filters '?we=rogercreus&wpn=cleanRL&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
+ 'pqn?tag=pr-472&cl=CleanRL PQN' \
+ --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \
+ --no-check-empty-runs \
+ --pc.ncols 3 \
+ --pc.ncols-legend 2 \
+ --output-filename benchmark/cleanrl/pqn \
+ --scan-history
+
+python -m openrlbenchmark.rlops \
+ --filters '?we=rogercreus&wpn=cleanRL&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
+ 'pqn_atari_envpool?tag=pr-472&cl=CleanRL PQN' \
+ --env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 Pong-v5 MsPacman-v5 \
+ --no-check-empty-runs \
+ --pc.ncols 3 \
+ --pc.ncols-legend 3 \
+ --rliable \
+ --rc.score_normalization_method maxmin \
+ --rc.normalized_score_threshold 1.0 \
+ --rc.sample_efficiency_plots \
+ --rc.sample_efficiency_and_walltime_efficiency_method Median \
+ --rc.performance_profile_plots \
+ --rc.aggregate_metrics_plots \
+ --rc.sample_efficiency_num_bootstrap_reps 10 \
+ --rc.performance_profile_num_bootstrap_reps 10 \
+ --rc.interval_estimates_num_bootstrap_reps 10 \
+ --output-filename static/0compare \
+ --scan-history
+
+python -m openrlbenchmark.rlops \
+ --filters '?we=rogercreus&wpn=cleanRL&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
+ 'pqn_atari_envpool_lstm?tag=pr-472&cl=CleanRL PQN' \
+ --env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 MsPacman-v5 \
+ --no-check-empty-runs \
+ --pc.ncols 3 \
+ --pc.ncols-legend 3 \
+ --rliable \
+ --rc.score_normalization_method maxmin \
+ --rc.normalized_score_threshold 1.0 \
+ --rc.sample_efficiency_plots \
+ --rc.sample_efficiency_and_walltime_efficiency_method Median \
+ --rc.performance_profile_plots \
+ --rc.aggregate_metrics_plots \
+ --rc.sample_efficiency_num_bootstrap_reps 10 \
+ --rc.performance_profile_num_bootstrap_reps 10 \
+ --rc.interval_estimates_num_bootstrap_reps 10 \
+ --output-filename static/0compare \
+ --scan-history
\ No newline at end of file
diff --git a/cleanrl/pqn.py b/cleanrl/pqn.py
new file mode 100644
index 00000000..6ed6e205
--- /dev/null
+++ b/cleanrl/pqn.py
@@ -0,0 +1,247 @@
+# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/pqn/#pqnpy
+import os
+import random
+import time
+from dataclasses import dataclass
+
+import gymnasium as gym
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import tyro
+from torch.utils.tensorboard import SummaryWriter
+
+
+@dataclass
+class Args:
+ exp_name: str = os.path.basename(__file__)[: -len(".py")]
+ """the name of this experiment"""
+ seed: int = 1
+ """seed of the experiment"""
+ torch_deterministic: bool = True
+ """if toggled, `torch.backends.cudnn.deterministic=False`"""
+ cuda: bool = True
+ """if toggled, cuda will be enabled by default"""
+ track: bool = False
+ """if toggled, this experiment will be tracked with Weights and Biases"""
+ wandb_project_name: str = "cleanRL"
+ """the wandb's project name"""
+ wandb_entity: str = None
+ """the entity (team) of wandb's project"""
+ capture_video: bool = False
+ """whether to capture videos of the agent performances (check out `videos` folder)"""
+
+ # Algorithm specific arguments
+ env_id: str = "CartPole-v1"
+ """the id of the environment"""
+ total_timesteps: int = 500000
+ """total timesteps of the experiments"""
+ learning_rate: float = 2.5e-4
+ """the learning rate of the optimizer"""
+ num_envs: int = 4
+ """the number of parallel game environments"""
+ num_steps: int = 128
+ """the number of steps to run for each environment per update"""
+ num_minibatches: int = 4
+ """the number of mini-batches"""
+ update_epochs: int = 4
+ """the K epochs to update the policy"""
+ anneal_lr: bool = True
+ """Toggle learning rate annealing"""
+ gamma: float = 0.99
+ """the discount factor gamma"""
+ start_e: float = 1
+ """the starting epsilon for exploration"""
+ end_e: float = 0.05
+ """the ending epsilon for exploration"""
+ exploration_fraction: float = 0.5
+ """the fraction of `total_timesteps` it takes from start_e to end_e"""
+ max_grad_norm: float = 10.0
+ """the maximum norm for the gradient clipping"""
+ q_lambda: float = 0.65
+ """the lambda for Q(lambda)"""
+
+
+def make_env(env_id, seed, idx, capture_video, run_name):
+ def thunk():
+ if capture_video and idx == 0:
+ env = gym.make(env_id, render_mode="rgb_array")
+ env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
+ else:
+ env = gym.make(env_id)
+ env = gym.wrappers.RecordEpisodeStatistics(env)
+ env.action_space.seed(seed)
+
+ return env
+
+ return thunk
+
+
+def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
+ torch.nn.init.orthogonal_(layer.weight, std)
+ torch.nn.init.constant_(layer.bias, bias_const)
+ return layer
+
+
+# ALGO LOGIC: initialize agent here:
+class QNetwork(nn.Module):
+ def __init__(self, env):
+ super().__init__()
+
+ self.network = nn.Sequential(
+ layer_init(nn.Linear(np.array(env.single_observation_space.shape).prod(), 120)),
+ nn.LayerNorm(120),
+ nn.ReLU(),
+ layer_init(nn.Linear(120, 84)),
+ nn.LayerNorm(84),
+ nn.ReLU(),
+ layer_init(nn.Linear(84, env.single_action_space.n)),
+ )
+
+ def forward(self, x):
+ return self.network(x)
+
+
+def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
+ slope = (end_e - start_e) / duration
+ return max(slope * t + start_e, end_e)
+
+
+if __name__ == "__main__":
+ args = tyro.cli(Args)
+ args.batch_size = int(args.num_envs * args.num_steps)
+ args.minibatch_size = int(args.batch_size // args.num_minibatches)
+ args.num_iterations = args.total_timesteps // args.batch_size
+ run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
+ if args.track:
+ import wandb
+
+ wandb.init(
+ project=args.wandb_project_name,
+ entity=args.wandb_entity,
+ sync_tensorboard=True,
+ config=vars(args),
+ name=run_name,
+ monitor_gym=True,
+ save_code=True,
+ )
+ writer = SummaryWriter(f"runs/{run_name}")
+ writer.add_text(
+ "hyperparameters",
+ "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+ )
+
+ # TRY NOT TO MODIFY: seeding
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.backends.cudnn.deterministic = args.torch_deterministic
+
+ device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
+
+ # env setup
+ envs = gym.vector.SyncVectorEnv(
+ [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
+ )
+ assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
+
+ # agent setup
+ q_network = QNetwork(envs).to(device)
+ optimizer = optim.RAdam(q_network.parameters(), lr=args.learning_rate)
+
+ # storage setup
+ obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
+ actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
+ rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
+ dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
+ values = torch.zeros((args.num_steps, args.num_envs)).to(device)
+
+ # TRY NOT TO MODIFY: start the game
+ global_step = 0
+ start_time = time.time()
+ next_obs, _ = envs.reset(seed=args.seed)
+ next_obs = torch.Tensor(next_obs).to(device)
+ next_done = torch.zeros(args.num_envs).to(device)
+
+ for iteration in range(1, args.num_iterations + 1):
+ # Annealing the rate if instructed to do so.
+ if args.anneal_lr:
+ frac = 1.0 - (iteration - 1.0) / args.num_iterations
+ lrnow = frac * args.learning_rate
+ optimizer.param_groups[0]["lr"] = lrnow
+
+ for step in range(0, args.num_steps):
+ global_step += args.num_envs
+ obs[step] = next_obs
+ dones[step] = next_done
+
+ epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
+ random_actions = torch.randint(0, envs.single_action_space.n, (args.num_envs,)).to(device)
+ with torch.no_grad():
+ q_values = q_network(next_obs)
+ max_actions = torch.argmax(q_values, dim=1)
+ values[step] = q_values[torch.arange(args.num_envs), max_actions].flatten()
+
+ explore = torch.rand((args.num_envs,)).to(device) < epsilon
+ action = torch.where(explore, random_actions, max_actions)
+ actions[step] = action
+
+ # TRY NOT TO MODIFY: execute the game and log data.
+ next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
+ next_done = np.logical_or(terminations, truncations)
+ rewards[step] = torch.tensor(reward).to(device).view(-1)
+ next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)
+
+ if "final_info" in infos:
+ for info in infos["final_info"]:
+ if info and "episode" in info:
+ print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
+ writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
+ writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
+
+ # Compute Q(lambda) targets
+ with torch.no_grad():
+ returns = torch.zeros_like(rewards).to(device)
+ for t in reversed(range(args.num_steps)):
+ if t == args.num_steps - 1:
+ next_value, _ = torch.max(q_network(next_obs), dim=-1)
+ nextnonterminal = 1.0 - next_done
+ returns[t] = rewards[t] + args.gamma * next_value * nextnonterminal
+ else:
+ nextnonterminal = 1.0 - dones[t + 1]
+ next_value = values[t + 1]
+ returns[t] = rewards[t] + args.gamma * (
+ args.q_lambda * returns[t + 1] + (1 - args.q_lambda) * next_value * nextnonterminal
+ )
+
+ # flatten the batch
+ b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
+ b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
+ b_returns = returns.reshape(-1)
+
+ # Optimizing the Q-network
+ b_inds = np.arange(args.batch_size)
+ for epoch in range(args.update_epochs):
+ np.random.shuffle(b_inds)
+ for start in range(0, args.batch_size, args.minibatch_size):
+ end = start + args.minibatch_size
+ mb_inds = b_inds[start:end]
+
+ old_val = q_network(b_obs[mb_inds]).gather(1, b_actions[mb_inds].unsqueeze(-1).long()).squeeze()
+ loss = F.mse_loss(b_returns[mb_inds], old_val)
+
+ # optimize the model
+ optimizer.zero_grad()
+ loss.backward()
+ nn.utils.clip_grad_norm_(q_network.parameters(), args.max_grad_norm)
+ optimizer.step()
+
+ writer.add_scalar("losses/td_loss", loss, global_step)
+ writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
+ print("SPS:", int(global_step / (time.time() - start_time)))
+ writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+
+ envs.close()
+ writer.close()
diff --git a/cleanrl/pqn_atari_envpool.py b/cleanrl/pqn_atari_envpool.py
new file mode 100644
index 00000000..45fd5a4c
--- /dev/null
+++ b/cleanrl/pqn_atari_envpool.py
@@ -0,0 +1,290 @@
+# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/pqn/#pqn_atari_envpoolpy
+import os
+import random
+import time
+from collections import deque
+from dataclasses import dataclass
+
+import envpool
+import gym
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import tyro
+from torch.utils.tensorboard import SummaryWriter
+
+
+@dataclass
+class Args:
+ exp_name: str = os.path.basename(__file__)[: -len(".py")]
+ """the name of this experiment"""
+ seed: int = 1
+ """seed of the experiment"""
+ torch_deterministic: bool = True
+ """if toggled, `torch.backends.cudnn.deterministic=False`"""
+ cuda: bool = True
+ """if toggled, cuda will be enabled by default"""
+ track: bool = False
+ """if toggled, this experiment will be tracked with Weights and Biases"""
+ wandb_project_name: str = "cleanRL"
+ """the wandb's project name"""
+ wandb_entity: str = None
+ """the entity (team) of wandb's project"""
+ capture_video: bool = False
+ """whether to capture videos of the agent performances (check out `videos` folder)"""
+
+ # Algorithm specific arguments
+ env_id: str = "Breakout-v5"
+ """the id of the environment"""
+ total_timesteps: int = 10000000
+ """total timesteps of the experiments"""
+ learning_rate: float = 2.5e-4
+ """the learning rate of the optimizer"""
+ num_envs: int = 8
+ """the number of parallel game environments"""
+ num_steps: int = 128
+ """the number of steps to run in each environment per policy rollout"""
+ anneal_lr: bool = True
+ """Toggle learning rate annealing for policy and value networks"""
+ gamma: float = 0.99
+ """the discount factor gamma"""
+ num_minibatches: int = 4
+ """the number of mini-batches"""
+ update_epochs: int = 4
+ """the K epochs to update the policy"""
+ max_grad_norm: float = 10.0
+ """the maximum norm for the gradient clipping"""
+ start_e: float = 1
+ """the starting epsilon for exploration"""
+ end_e: float = 0.01
+ """the ending epsilon for exploration"""
+ exploration_fraction: float = 0.10
+ """the fraction of `total_timesteps` it takes from start_e to end_e"""
+ q_lambda: float = 0.65
+ """the lambda for the Q-Learning algorithm"""
+
+ # to be filled in runtime
+ batch_size: int = 0
+ """the batch size (computed in runtime)"""
+ minibatch_size: int = 0
+ """the mini-batch size (computed in runtime)"""
+ num_iterations: int = 0
+ """the number of iterations (computed in runtime)"""
+
+
+class RecordEpisodeStatistics(gym.Wrapper):
+ def __init__(self, env, deque_size=100):
+ super().__init__(env)
+ self.num_envs = getattr(env, "num_envs", 1)
+ self.episode_returns = None
+ self.episode_lengths = None
+
+ def reset(self, **kwargs):
+ observations = super().reset(**kwargs)
+ self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
+ self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
+ self.lives = np.zeros(self.num_envs, dtype=np.int32)
+ self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
+ self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
+ return observations
+
+ def step(self, action):
+ observations, rewards, dones, infos = super().step(action)
+ self.episode_returns += infos["reward"]
+ self.episode_lengths += 1
+ self.returned_episode_returns[:] = self.episode_returns
+ self.returned_episode_lengths[:] = self.episode_lengths
+ self.episode_returns *= 1 - infos["terminated"]
+ self.episode_lengths *= 1 - infos["terminated"]
+ infos["r"] = self.returned_episode_returns
+ infos["l"] = self.returned_episode_lengths
+ return (
+ observations,
+ rewards,
+ dones,
+ infos,
+ )
+
+
+def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
+ torch.nn.init.orthogonal_(layer.weight, std)
+ torch.nn.init.constant_(layer.bias, bias_const)
+ return layer
+
+
+class QNetwork(nn.Module):
+ def __init__(self, env):
+ super().__init__()
+ self.network = nn.Sequential(
+ layer_init(nn.Conv2d(4, 32, 8, stride=4)),
+ nn.LayerNorm([32, 20, 20]),
+ nn.ReLU(),
+ layer_init(nn.Conv2d(32, 64, 4, stride=2)),
+ nn.LayerNorm([64, 9, 9]),
+ nn.ReLU(),
+ layer_init(nn.Conv2d(64, 64, 3, stride=1)),
+ nn.LayerNorm([64, 7, 7]),
+ nn.ReLU(),
+ nn.Flatten(),
+ layer_init(nn.Linear(3136, 512)),
+ nn.LayerNorm(512),
+ nn.ReLU(),
+ layer_init(nn.Linear(512, env.single_action_space.n)),
+ )
+
+ def forward(self, x):
+ return self.network(x / 255.0)
+
+
+def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
+ slope = (end_e - start_e) / duration
+ return max(slope * t + start_e, end_e)
+
+
+if __name__ == "__main__":
+ args = tyro.cli(Args)
+ args.batch_size = int(args.num_envs * args.num_steps)
+ args.minibatch_size = int(args.batch_size // args.num_minibatches)
+ args.num_iterations = args.total_timesteps // args.batch_size
+ run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
+ if args.track:
+ import wandb
+
+ wandb.init(
+ project=args.wandb_project_name,
+ entity=args.wandb_entity,
+ sync_tensorboard=True,
+ config=vars(args),
+ name=run_name,
+ monitor_gym=True,
+ save_code=True,
+ )
+ writer = SummaryWriter(f"runs/{run_name}")
+ writer.add_text(
+ "hyperparameters",
+ "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+ )
+
+ # TRY NOT TO MODIFY: seeding
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.backends.cudnn.deterministic = args.torch_deterministic
+
+ device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
+
+ # env setup
+ envs = envpool.make(
+ args.env_id,
+ env_type="gym",
+ num_envs=args.num_envs,
+ episodic_life=True,
+ reward_clip=True,
+ seed=args.seed,
+ )
+ envs.num_envs = args.num_envs
+ envs.single_action_space = envs.action_space
+ envs.single_observation_space = envs.observation_space
+ envs = RecordEpisodeStatistics(envs)
+ assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported"
+
+ q_network = QNetwork(envs).to(device)
+ optimizer = optim.RAdam(q_network.parameters(), lr=args.learning_rate)
+
+ # ALGO Logic: Storage setup
+ obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
+ actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
+ rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
+ dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
+ values = torch.zeros((args.num_steps, args.num_envs)).to(device)
+ avg_returns = deque(maxlen=20)
+
+ # TRY NOT TO MODIFY: start the game
+ global_step = 0
+ start_time = time.time()
+ next_obs = torch.Tensor(envs.reset()).to(device)
+ next_done = torch.zeros(args.num_envs).to(device)
+
+ for iteration in range(1, args.num_iterations + 1):
+ # Annealing the rate if instructed to do so.
+ if args.anneal_lr:
+ frac = 1.0 - (iteration - 1.0) / args.num_iterations
+ lrnow = frac * args.learning_rate
+ optimizer.param_groups[0]["lr"] = lrnow
+
+ for step in range(0, args.num_steps):
+ global_step += args.num_envs
+ obs[step] = next_obs
+ dones[step] = next_done
+
+ epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
+
+ random_actions = torch.randint(0, envs.single_action_space.n, (args.num_envs,)).to(device)
+ with torch.no_grad():
+ q_values = q_network(next_obs)
+ max_actions = torch.argmax(q_values, dim=1)
+ values[step] = q_values[torch.arange(args.num_envs), max_actions].flatten()
+
+ explore = torch.rand((args.num_envs,)).to(device) < epsilon
+ action = torch.where(explore, random_actions, max_actions)
+ actions[step] = action
+
+ # TRY NOT TO MODIFY: execute the game and log data.
+ next_obs, reward, next_done, info = envs.step(action.cpu().numpy())
+ rewards[step] = torch.tensor(reward).to(device).view(-1)
+ next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)
+
+ for idx, d in enumerate(next_done):
+ if d and info["lives"][idx] == 0:
+ print(f"global_step={global_step}, episodic_return={info['r'][idx]}")
+ avg_returns.append(info["r"][idx])
+ writer.add_scalar("charts/avg_episodic_return", np.average(avg_returns), global_step)
+ writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
+ writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
+
+ # Compute Q(lambda) targets
+ with torch.no_grad():
+ returns = torch.zeros_like(rewards).to(device)
+ for t in reversed(range(args.num_steps)):
+ if t == args.num_steps - 1:
+ next_value, _ = torch.max(q_network(next_obs), dim=-1)
+ nextnonterminal = 1.0 - next_done
+ returns[t] = rewards[t] + args.gamma * next_value * nextnonterminal
+ else:
+ nextnonterminal = 1.0 - dones[t + 1]
+ next_value = values[t + 1]
+ returns[t] = rewards[t] + args.gamma * (
+ args.q_lambda * returns[t + 1] + (1 - args.q_lambda) * next_value * nextnonterminal
+ )
+
+ # flatten the batch
+ b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
+ b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
+ b_returns = returns.reshape(-1)
+
+ # Optimizing the Q-network
+ b_inds = np.arange(args.batch_size)
+ for epoch in range(args.update_epochs):
+ np.random.shuffle(b_inds)
+ for start in range(0, args.batch_size, args.minibatch_size):
+ end = start + args.minibatch_size
+ mb_inds = b_inds[start:end]
+
+ old_val = q_network(b_obs[mb_inds]).gather(1, b_actions[mb_inds].unsqueeze(-1).long()).squeeze()
+ loss = F.mse_loss(b_returns[mb_inds], old_val)
+
+ # optimize the model
+ optimizer.zero_grad()
+ loss.backward()
+ nn.utils.clip_grad_norm_(q_network.parameters(), args.max_grad_norm)
+ optimizer.step()
+
+ writer.add_scalar("losses/td_loss", loss, global_step)
+ writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
+ print("SPS:", int(global_step / (time.time() - start_time)))
+ writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+
+ envs.close()
+ writer.close()
diff --git a/cleanrl/pqn_atari_envpool_lstm.py b/cleanrl/pqn_atari_envpool_lstm.py
new file mode 100644
index 00000000..6b348b0a
--- /dev/null
+++ b/cleanrl/pqn_atari_envpool_lstm.py
@@ -0,0 +1,338 @@
+# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/pqn/#pqn_atari_envpool_lstmpy
+import os
+import random
+import time
+from collections import deque
+from dataclasses import dataclass
+
+import envpool
+import gym
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import tyro
+from torch.utils.tensorboard import SummaryWriter
+
+
+@dataclass
+class Args:
+ exp_name: str = os.path.basename(__file__)[: -len(".py")]
+ """the name of this experiment"""
+ seed: int = 1
+ """seed of the experiment"""
+ torch_deterministic: bool = True
+ """if toggled, `torch.backends.cudnn.deterministic=False`"""
+ cuda: bool = True
+ """if toggled, cuda will be enabled by default"""
+ track: bool = False
+ """if toggled, this experiment will be tracked with Weights and Biases"""
+ wandb_project_name: str = "cleanRL"
+ """the wandb's project name"""
+ wandb_entity: str = None
+ """the entity (team) of wandb's project"""
+ capture_video: bool = False
+ """whether to capture videos of the agent performances (check out `videos` folder)"""
+
+ # Algorithm specific arguments
+ env_id: str = "Breakout-v5"
+ """the id of the environment"""
+ total_timesteps: int = 10000000
+ """total timesteps of the experiments"""
+ learning_rate: float = 2.5e-4
+ """the learning rate of the optimizer"""
+ num_envs: int = 8
+ """the number of parallel game environments"""
+ num_steps: int = 128
+ """the number of steps to run in each environment per policy rollout"""
+ anneal_lr: bool = True
+ """Toggle learning rate annealing for policy and value networks"""
+ gamma: float = 0.99
+ """the discount factor gamma"""
+ num_minibatches: int = 4
+ """the number of mini-batches"""
+ update_epochs: int = 4
+ """the K epochs to update the policy"""
+ max_grad_norm: float = 0.5
+ """the maximum norm for the gradient clipping"""
+ start_e: float = 1
+ """the starting epsilon for exploration"""
+ end_e: float = 0.01
+ """the ending epsilon for exploration"""
+ exploration_fraction: float = 0.10
+ """the fraction of `total_timesteps` it takes from start_e to end_e"""
+ q_lambda: float = 0.65
+ """the lambda for the Q-Learning algorithm"""
+
+ # to be filled in runtime
+ batch_size: int = 0
+ """the batch size (computed in runtime)"""
+ minibatch_size: int = 0
+ """the mini-batch size (computed in runtime)"""
+ num_iterations: int = 0
+ """the number of iterations (computed in runtime)"""
+
+
+class RecordEpisodeStatistics(gym.Wrapper):
+ def __init__(self, env, deque_size=100):
+ super().__init__(env)
+ self.num_envs = getattr(env, "num_envs", 1)
+ self.episode_returns = None
+ self.episode_lengths = None
+
+ def reset(self, **kwargs):
+ observations = super().reset(**kwargs)
+ self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
+ self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
+ self.lives = np.zeros(self.num_envs, dtype=np.int32)
+ self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
+ self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
+ return observations
+
+ def step(self, action):
+ observations, rewards, dones, infos = super().step(action)
+ self.episode_returns += infos["reward"]
+ self.episode_lengths += 1
+ self.returned_episode_returns[:] = self.episode_returns
+ self.returned_episode_lengths[:] = self.episode_lengths
+ self.episode_returns *= 1 - infos["terminated"]
+ self.episode_lengths *= 1 - infos["terminated"]
+ infos["r"] = self.returned_episode_returns
+ infos["l"] = self.returned_episode_lengths
+ return (
+ observations,
+ rewards,
+ dones,
+ infos,
+ )
+
+
+def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
+ torch.nn.init.orthogonal_(layer.weight, std)
+ torch.nn.init.constant_(layer.bias, bias_const)
+ return layer
+
+
+class QNetwork(nn.Module):
+ def __init__(self, env):
+ super().__init__()
+ self.network = nn.Sequential(
+ layer_init(nn.Conv2d(1, 32, 8, stride=4)),
+ nn.LayerNorm([32, 20, 20]),
+ nn.ReLU(),
+ layer_init(nn.Conv2d(32, 64, 4, stride=2)),
+ nn.LayerNorm([64, 9, 9]),
+ nn.ReLU(),
+ layer_init(nn.Conv2d(64, 64, 3, stride=1)),
+ nn.LayerNorm([64, 7, 7]),
+ nn.ReLU(),
+ nn.Flatten(),
+ layer_init(nn.Linear(3136, 512)),
+ nn.LayerNorm(512),
+ nn.ReLU(),
+ )
+ self.lstm = nn.LSTM(512, 128)
+ for name, param in self.lstm.named_parameters():
+ if "bias" in name:
+ nn.init.constant_(param, 0)
+ elif "weight" in name:
+ nn.init.orthogonal_(param, 1.0)
+ self.q_func = layer_init(nn.Linear(128, env.single_action_space.n))
+
+ def get_states(self, x, lstm_state, done):
+ hidden = self.network(x / 255.0)
+
+ # LSTM logic
+ batch_size = lstm_state[0].shape[1]
+ hidden = hidden.reshape((-1, batch_size, self.lstm.input_size))
+ done = done.reshape((-1, batch_size))
+ new_hidden = []
+ for h, d in zip(hidden, done):
+ h, lstm_state = self.lstm(
+ h.unsqueeze(0),
+ (
+ (1.0 - d).view(1, -1, 1) * lstm_state[0],
+ (1.0 - d).view(1, -1, 1) * lstm_state[1],
+ ),
+ )
+ new_hidden += [h]
+ new_hidden = torch.flatten(torch.cat(new_hidden), 0, 1)
+ return new_hidden, lstm_state
+
+ def forward(self, x, lstm_state, done):
+ hidden, lstm_state = self.get_states(x, lstm_state, done)
+ return self.q_func(hidden), lstm_state
+
+
+def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
+ slope = (end_e - start_e) / duration
+ return max(slope * t + start_e, end_e)
+
+
+if __name__ == "__main__":
+ args = tyro.cli(Args)
+ args.batch_size = int(args.num_envs * args.num_steps)
+ args.minibatch_size = int(args.batch_size // args.num_minibatches)
+ args.num_iterations = args.total_timesteps // args.batch_size
+ run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
+ if args.track:
+ import wandb
+
+ wandb.init(
+ project=args.wandb_project_name,
+ entity=args.wandb_entity,
+ sync_tensorboard=True,
+ config=vars(args),
+ name=run_name,
+ monitor_gym=True,
+ save_code=True,
+ )
+ writer = SummaryWriter(f"runs/{run_name}")
+ writer.add_text(
+ "hyperparameters",
+ "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+ )
+
+ # TRY NOT TO MODIFY: seeding
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.backends.cudnn.deterministic = args.torch_deterministic
+
+ device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
+
+ # env setup
+ envs = envpool.make(
+ args.env_id,
+ env_type="gym",
+ num_envs=args.num_envs,
+ episodic_life=True,
+ reward_clip=True,
+ seed=args.seed,
+ stack_num=1,
+ )
+ envs.num_envs = args.num_envs
+ envs.single_action_space = envs.action_space
+ envs.single_observation_space = envs.observation_space
+ envs = RecordEpisodeStatistics(envs)
+ assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported"
+
+ q_network = QNetwork(envs).to(device)
+ optimizer = optim.RAdam(q_network.parameters(), lr=args.learning_rate)
+
+ # ALGO Logic: Storage setup
+ obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
+ actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
+ rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
+ dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
+ values = torch.zeros((args.num_steps, args.num_envs)).to(device)
+ avg_returns = deque(maxlen=20)
+
+ # TRY NOT TO MODIFY: start the game
+ global_step = 0
+ start_time = time.time()
+ next_obs = torch.Tensor(envs.reset()).to(device)
+ next_done = torch.zeros(args.num_envs).to(device)
+
+ next_lstm_state = (
+ torch.zeros(q_network.lstm.num_layers, args.num_envs, q_network.lstm.hidden_size).to(device),
+ torch.zeros(q_network.lstm.num_layers, args.num_envs, q_network.lstm.hidden_size).to(device),
+ ) # hidden and cell states (see https://youtu.be/8HyCNIVRbSU)
+
+ for iteration in range(1, args.num_iterations + 1):
+ initial_lstm_state = (next_lstm_state[0].clone(), next_lstm_state[1].clone())
+
+ # Annealing the rate if instructed to do so.
+ if args.anneal_lr:
+ frac = 1.0 - (iteration - 1.0) / args.num_iterations
+ lrnow = frac * args.learning_rate
+ optimizer.param_groups[0]["lr"] = lrnow
+
+ for step in range(0, args.num_steps):
+ global_step += args.num_envs
+ obs[step] = next_obs
+ dones[step] = next_done
+
+ epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
+
+ random_actions = torch.randint(0, envs.single_action_space.n, (args.num_envs,)).to(device)
+ with torch.no_grad():
+ q_values, next_lstm_state = q_network(next_obs, next_lstm_state, next_done)
+ max_actions = torch.argmax(q_values, dim=1)
+ values[step] = q_values[torch.arange(args.num_envs), max_actions].flatten()
+
+ explore = torch.rand((args.num_envs,)).to(device) < epsilon
+ action = torch.where(explore, random_actions, max_actions)
+ actions[step] = action
+
+ # TRY NOT TO MODIFY: execute the game and log data.
+ next_obs, reward, next_done, info = envs.step(action.cpu().numpy())
+ rewards[step] = torch.tensor(reward).to(device).view(-1)
+ next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)
+
+ for idx, d in enumerate(next_done):
+ if d and info["lives"][idx] == 0:
+ print(f"global_step={global_step}, episodic_return={info['r'][idx]}")
+ avg_returns.append(info["r"][idx])
+ writer.add_scalar("charts/avg_episodic_return", np.average(avg_returns), global_step)
+ writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
+ writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
+
+ # Compute Q(lambda) targets
+ with torch.no_grad():
+ returns = torch.zeros_like(rewards).to(device)
+ for t in reversed(range(args.num_steps)):
+ if t == args.num_steps - 1:
+ next_value, _ = torch.max(q_network(next_obs, next_lstm_state, next_done)[0], dim=-1)
+ nextnonterminal = 1.0 - next_done
+ returns[t] = rewards[t] + args.gamma * next_value * nextnonterminal
+ else:
+ nextnonterminal = 1.0 - dones[t + 1]
+ next_value = values[t + 1]
+ returns[t] = rewards[t] + args.gamma * (
+ args.q_lambda * returns[t + 1] + (1 - args.q_lambda) * next_value * nextnonterminal
+ )
+
+ # flatten the batch
+ b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
+ b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
+ b_returns = returns.reshape(-1)
+ b_dones = dones.reshape(-1)
+
+ assert args.num_envs % args.num_minibatches == 0
+ envsperbatch = args.num_envs // args.num_minibatches
+ envinds = np.arange(args.num_envs)
+ flatinds = np.arange(args.batch_size).reshape(args.num_steps, args.num_envs)
+
+ # Optimizing the Q-network
+ b_inds = np.arange(args.batch_size)
+ for epoch in range(args.update_epochs):
+ np.random.shuffle(envinds)
+ for start in range(0, args.num_envs, envsperbatch):
+ end = start + envsperbatch
+ mbenvinds = envinds[start:end]
+ mb_inds = flatinds[:, mbenvinds].ravel() # be really careful about the index
+
+ old_val, _ = q_network(
+ b_obs[mb_inds],
+ (initial_lstm_state[0][:, mbenvinds], initial_lstm_state[1][:, mbenvinds]),
+ b_dones[mb_inds],
+ )
+ old_val = old_val.gather(1, b_actions[mb_inds].unsqueeze(-1).long()).squeeze()
+
+ loss = F.mse_loss(b_returns[mb_inds], old_val)
+
+ # optimize the model
+ optimizer.zero_grad()
+ loss.backward()
+ nn.utils.clip_grad_norm_(q_network.parameters(), args.max_grad_norm)
+ optimizer.step()
+
+ writer.add_scalar("losses/td_loss", loss, global_step)
+ writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
+ print("SPS:", int(global_step / (time.time() - start_time)))
+ writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+
+ envs.close()
+ writer.close()
diff --git a/docs/rl-algorithms/pqn.md b/docs/rl-algorithms/pqn.md
new file mode 100644
index 00000000..c6c86e76
--- /dev/null
+++ b/docs/rl-algorithms/pqn.md
@@ -0,0 +1,297 @@
+# Parallel Q Network (PQN)
+
+
+## Overview
+
+PQN is a parallelized version of the Deep Q-learning algorithm. It is designed to be more efficient than DQN by using multiple agents to interact with the environment in parallel. PQN can be thought of as DQN (1) without replay buffer and target networks, and (2) with layer normalizations and parallel environments.
+
+Original paper:
+
+* [Simplifying Deep Temporal Difference Learning](https://arxiv.org/html/2407.04811v2)
+
+Reference resources:
+
+* :material-github: [purejaxql](https://github.com/mttga/purejaxql)
+
+## Implemented Variants
+
+
+| Variants Implemented | Description |
+| ----------- | ----------- |
+| :material-github: [`pqn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/pqn.py), :material-file-document: [docs](/rl-algorithms/pqn/#pqnpy) | For classic control tasks like `CartPole-v1`. |
+| :material-github: [`pqn_atari_envpool.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/pqn_atari_envpool.py), :material-file-document: [docs](/rl-algorithms/pqn/#pqn_atari_envpoolpy) | For Atari games. Uses the blazing fast Envpool Atari vectorized environment. It uses convolutional layers and common atari-based pre-processing techniques. |
+| :material-github: [`pqn_atari_envpool_lstm.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/pqn_atari_envpool_lstm.py), :material-file-document: [docs](/rl-algorithms/pqn/#pqn_atari_envpool_lstmpy) | For Atari games. Uses the blazing fast Envpool Atari vectorized environment. Using LSTM without stacked frames. |
+
+Below are our single-file implementations of PQN:
+
+## `pqn.py`
+
+The [pqn.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/pqn.py) has the following features:
+
+* Works with the `Box` observation space of low-level features
+* Works with the `Discrete` action space
+* Works with envs like `CartPole-v1`
+
+### Usage
+
+=== "poetry"
+
+ ```bash
+ poetry install
+ poetry run python cleanrl/pqn.py --help
+ poetry run python cleanrl/pqn.py --env-id CartPole-v1
+ ```
+
+=== "pip"
+
+ ```bash
+ python cleanrl/pqn.py --help
+ python cleanrl/pqn.py --env-id CartPole-v1
+ ```
+
+### Explanation of the logged metrics
+
+Running `python cleanrl/pqn.py` will automatically record various metrics such as actor or value losses in Tensorboard. Below is the documentation for these metrics:
+
+* `charts/episodic_return`: episodic return of the game
+* `charts/episodic_length`: episodic length of the game
+* `charts/SPS`: number of steps per second
+* `charts/learning_rate`: the current learning rate
+* `losses/td_loss`: the mean squared error (MSE) between the Q values at timestep $t$ and the Bellman update target estimated using the $Q(\lambda)$ returns.
+* `losses/q_values`: it is the average Q values of the sampled data in the replay buffer; useful when gauging if under or over estimation happens.
+
+### Implementation details
+
+1. Vectorized architecture (:material-github: [common/cmd_util.py#L22](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/cmd_util.py#L22))
+2. Orthogonal Initialization of Weights and Constant Initialization of biases (:material-github: [a2c/utils.py#L58)](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/a2c/utils.py#L58))
+3. Normalized Q Network (:material-github: [purejaxql/pqn_atari.py#L200](https://github.com/mttga/purejaxql/blob/2205ae5308134d2cedccd749074bff2871832dc8/purejaxql/pqn_atari.py#L200))
+4. Uses the RAdam Optimizer with the default epsilon parameter(:material-github: [purejaxql/pqn_atari.py#L362](https://github.com/mttga/purejaxql/blob/2205ae5308134d2cedccd749074bff2871832dc8/purejaxql/pqn_atari.py#L362))
+5. Adam Learning Rate Annealing (:material-github: [pqn2/pqn2.py#L133-L135](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/pqn2/pqn2.py#L133-L135))
+6. Q Lambda Returns (:material-github: [purejaxql/pqn_atari.py#L446](https://github.com/mttga/purejaxql/blob/2205ae5308134d2cedccd749074bff2871832dc8/purejaxql/pqn_atari.py#L446))
+7. Mini-batch Updates (:material-github: [pqn2/pqn2.py#L157-L166](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/pqn2/pqn2.py#L157-L166))
+8. Global Gradient Clipping (:material-github: [purejaxql/pqn_atari.py#L360](https://github.com/mttga/purejaxql/blob/2205ae5308134d2cedccd749074bff2871832dc8/purejaxql/pqn_atari.py#L360))
+
+### Experiment results
+
+To run benchmark experiments, see :material-github: [benchmark/pqn.sh](https://github.com/vwxyzjn/cleanrl/blob/master/benchmark/pqn.sh). Specifically, execute the following command:
+
+``` title="benchmark/pqn.sh" linenums="1"
+--8<-- "benchmark/pqn.sh:0:6"
+```
+
+Episode Rewards:
+
+| Environment | CleanRL PQN |
+|------------------|-------------------|
+| CartPole-v1 | 408.14 ± 128.42 |
+| Acrobot-v1 | -93.71 ± 2.94 |
+| MountainCar-v0 | -200.00 ± 0.00 |
+
+Runtime:
+
+| Environment | CleanRL PQN |
+|------------------|----------------------|
+| CartPole-v1 | 3.619667511995135 |
+| Acrobot-v1 | 4.264845468334595 |
+| MountainCar-v0 | 3.99800178870078 |
+
+Learning curves:
+
+``` title="benchmark/pqn_plot.sh" linenums="1"
+--8<-- "benchmark/pqn_plot.sh:1:9"
+```
+
+
+
+Tracked experiments:
+
+
+
+## `pqn_atari_envpool.py`
+
+The [pqn_atari_envpool.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/pqn_atari_envpool.py) has the following features:
+
+* Uses the blazing fast [Envpool](https://github.com/sail-sg/envpool) vectorized environment.
+* For Atari games. It uses convolutional layers and common atari-based pre-processing techniques.
+* Works with the Atari's pixel `Box` observation space of shape `(210, 160, 3)`
+* Works with the `Discrete` action space
+
+???+ warning
+
+ Note that `pqn_atari_envpool.py` does not work in Windows :fontawesome-brands-windows: and MacOs :fontawesome-brands-apple:. See envpool's built wheels here: [https://pypi.org/project/envpool/#files](https://pypi.org/project/envpool/#files)
+
+???+ bug
+
+ EnvPool's vectorized environment **does not behave the same** as gym's vectorized environment, which causes a compatibility bug in our PQN implementation. When an action $a$ results in an episode termination or truncation, the environment generates $s_{last}$ as the terminated or truncated state; we then use $s_{new}$ to denote the initial state of the new episodes. Here is how the bahviors differ:
+
+ * Under the vectorized environment of `envpool<=0.6.4`, the `obs` in `obs, reward, done, info = env.step(action)` is the truncated state $s_{last}$
+ * Under the vectorized environment of `gym==0.23.1`, the `obs` in `obs, reward, done, info = env.step(action)` is the initial state $s_{new}$.
+
+ This causes the $s_{last}$ to be off by one.
+ See [:material-github: sail-sg/envpool#194](https://github.com/sail-sg/envpool/issues/194) for more detail. However, it does not seem to impact performance, so we take a note here and await for the upstream fix.
+
+
+### Usage
+
+=== "poetry"
+
+ ```bash
+ poetry install -E envpool
+ poetry run python cleanrl/pqn_atari_envpool.py --help
+ poetry run python cleanrl/pqn_atari_envpool.py --env-id Breakout-v5
+ ```
+
+=== "pip"
+
+ ```bash
+ pip install -r requirements/requirements-envpool.txt
+ python cleanrl/pqn_atari_envpool.py --help
+ python cleanrl/pqn_atari_envpool.py --env-id Breakout-v5
+ ```
+
+### Explanation of the logged metrics
+
+See [related docs](/rl-algorithms/pqn/#explanation-of-the-logged-metrics) for `pqn.py`.
+
+### Implementation details
+
+[pqn_atari_envpool.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/pqn_atari_envpool.py) uses a customized `RecordEpisodeStatistics` to work with envpool but has the same other implementation details as `ppo_atari.py`.
+
+### Experiment results
+
+To run benchmark experiments, see :material-github: [benchmark/pqn.sh](https://github.com/vwxyzjn/cleanrl/blob/master/benchmark/pqn.sh). Specifically, execute the following command:
+
+``` title="benchmark/pqn.sh" linenums="1"
+--8<-- "benchmark/pqn.sh:12:17"
+```
+
+Episode Rewards:
+
+| Environment | CleanRL PQN |
+|-------------------|--------------------|
+| Breakout-v5 | 356.93 ± 7.48 |
+| SpaceInvaders-v5 | 900.07 ± 107.95 |
+| BeamRider-v5 | 1987.97 ± 24.47 |
+| Pong-v5 | 20.44 ± 0.11 |
+| MsPacman-v5 | 2437.57 ± 215.01 |
+
+Runtime:
+
+| Environment | CleanRL PQN |
+|-------------------|-----------------------|
+| Breakout-v5 | 41.27235000576079 |
+| SpaceInvaders-v5 | 42.191246278536035 |
+| BeamRider-v5 | 42.66799268151052 |
+| Pong-v5 | 39.35770012905844 |
+| MsPacman-v5 | 43.22808379473344 |
+
+
+Learning curves:
+
+``` title="benchmark/pqn_plot.sh" linenums="1"
+--8<-- "benchmark/pqn_plot.sh:11:29"
+```
+
+
+
+Tracked experiments:
+
+
+
+## `pqn_atari_envpool_lstm.py`
+
+The [pqn_atari_envpool_lstm.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/pqn_atari_envpool_lstm.py) has the following features:
+
+* Uses the blazing fast [Envpool](https://github.com/sail-sg/envpool) vectorized environment.
+* For Atari games using LSTM without stacked frames. It uses convolutional layers and common atari-based pre-processing techniques.
+* Works with the Atari's pixel `Box` observation space of shape `(210, 160, 3)`
+* Works with the `Discrete` action space
+
+???+ warning
+
+ Note that `pqn_atari_envpool.py` does not work in Windows :fontawesome-brands-windows: and MacOs :fontawesome-brands-apple:. See envpool's built wheels here: [https://pypi.org/project/envpool/#files](https://pypi.org/project/envpool/#files)
+
+???+ bug
+
+ EnvPool's vectorized environment **does not behave the same** as gym's vectorized environment, which causes a compatibility bug in our PQN implementation. When an action $a$ results in an episode termination or truncation, the environment generates $s_{last}$ as the terminated or truncated state; we then use $s_{new}$ to denote the initial state of the new episodes. Here is how the bahviors differ:
+
+ * Under the vectorized environment of `envpool<=0.6.4`, the `obs` in `obs, reward, done, info = env.step(action)` is the truncated state $s_{last}$
+ * Under the vectorized environment of `gym==0.23.1`, the `obs` in `obs, reward, done, info = env.step(action)` is the initial state $s_{new}$.
+
+ This causes the $s_{last}$ to be off by one.
+ See [:material-github: sail-sg/envpool#194](https://github.com/sail-sg/envpool/issues/194) for more detail. However, it does not seem to impact performance, so we take a note here and await for the upstream fix.
+
+### Usage
+
+
+=== "poetry"
+
+ ```bash
+ poetry install -E atari
+ poetry run python cleanrl/pqn_atari_envpool_lstm.py --help
+ poetry run python cleanrl/pqn_atari_envpool_lstm.py --env-id Breakout-v5
+ ```
+
+=== "pip"
+
+ ```bash
+ pip install -r requirements/requirements-atari.txt
+ python cleanrl/pqn_atari_envpool_lstm.py --help
+ python cleanrl/pqn_atari_envpool_lstm.py --env-id Breakout-v5
+ ```
+
+
+### Explanation of the logged metrics
+
+See [related docs](/rl-algorithms/pqn/#explanation-of-the-logged-metrics) for `pqn.py`.
+
+### Implementation details
+
+[pqn_atari_envpool_lstm.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/pqn_atari_envpool_lstm.py) is based on the "5 LSTM implementation details" in [The 37 Implementation Details of Proximal Policy Optimization](https://iclr-blog-track.github.io/2022/03/25/pqn-implementation-details/), which are as follows:
+
+1. Layer initialization for LSTM layers (:material-github: [a2c/utils.py#L84-L86](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/a2c/utils.py#L84-L86))
+2. Initialize the LSTM states to be zeros (:material-github: [common/models.py#L179](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/models.py#L179))
+3. Reset LSTM states at the end of the episode (:material-github: [common/models.py#L141](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/models.py#L141))
+4. Prepare sequential rollouts in mini-batches (:material-github: [a2c/utils.py#L81](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/a2c/utils.py#L81))
+5. Reconstruct LSTM states during training (:material-github: [a2c/utils.py#L81](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/a2c/utils.py#L81))
+
+To help test out the memory, we remove the 4 stacked frames from the observation (i.e., using `env = gym.wrappers.FrameStack(env, 1)` instead of `env = gym.wrappers.FrameStack(env, 4)` like in `ppo_atari.py` )
+
+### Experiment results
+
+To run benchmark experiments, see :material-github: [benchmark/pqn.sh](https://github.com/vwxyzjn/cleanrl/blob/master/benchmark/pqn.sh). Specifically, execute the following command:
+
+``` title="benchmark/pqn.sh" linenums="1"
+--8<-- "benchmark/pqn.sh:23:28"
+```
+
+
+Episode Rewards:
+
+| Environment | CleanRL PQN |
+|-------------------|--------------------|
+| Breakout-v5 | 366.47 ± 2.72 |
+| SpaceInvaders-v5 | 681.92 ± 40.15 |
+| BeamRider-v5 | 2050.85 ± 38.58 |
+| MsPacman-v5 | 1815.20 ± 183.03 |
+
+Runtime:
+
+| Environment | CleanRL PQN |
+|-------------------|-----------------------|
+| Breakout-v5 | 170.30230232607076 |
+| SpaceInvaders-v5 | 168.45747969698144 |
+| BeamRider-v5 | 172.11561139317593 |
+| MsPacman-v5 | 171.66131707108408 |
+
+Learning curves:
+
+``` title="benchmark/pqn_plot.sh" linenums="1"
+--8<-- "benchmark/pqn_plot.sh:32:50"
+```
+
+
+
+Tracked experiments:
+
+
\ No newline at end of file
diff --git a/docs/rl-algorithms/pqn/pqn.png b/docs/rl-algorithms/pqn/pqn.png
new file mode 100644
index 00000000..df72dbdb
Binary files /dev/null and b/docs/rl-algorithms/pqn/pqn.png differ
diff --git a/docs/rl-algorithms/pqn/pqn_lstm.png b/docs/rl-algorithms/pqn/pqn_lstm.png
new file mode 100644
index 00000000..b97946e3
Binary files /dev/null and b/docs/rl-algorithms/pqn/pqn_lstm.png differ
diff --git a/docs/rl-algorithms/pqn/pqn_state.png b/docs/rl-algorithms/pqn/pqn_state.png
new file mode 100644
index 00000000..743d0879
Binary files /dev/null and b/docs/rl-algorithms/pqn/pqn_state.png differ
diff --git a/mkdocs.yml b/mkdocs.yml
index 9c8c9769..6f0bd225 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -50,6 +50,7 @@ nav:
- rl-algorithms/rpo.md
- rl-algorithms/qdagger.md
- rl-algorithms/ppo-trxl.md
+ - rl-algorithms/pqn.md
- Advanced:
- advanced/hyperparameter-tuning.md
- advanced/resume-training.md