Problem Description
In, for example, dqn_atari.py, the replay buffer is instantiated with the optimize_memory_usage=True flag. With this flag the buffer keeps only a single stored observation array and derives next_obs = observations[i+1] at sampling time. However, cleanrl does its own truncation handling when adding transitions (if trunc: real_next_obs[idx] = infos["final_observation"][idx]), and with optimize_memory_usage that correction is not reflected in the stored/sampled data.
Current Behavior
When an episode is truncated, data.next_observations[i] is not the correct next observation; instead it is the first observation of the reset environment.
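The overwrite mechanism can be illustrated without SB3 at all. This is a minimal pure-NumPy sketch (not SB3's actual code) of the memory-optimized layout, assuming scalar observations: add() writes obs at pos and next_obs at (pos + 1) % size into one shared array, so the corrected final observation stored on a truncation is clobbered by the reset environment's first observation on the very next add().

```python
import numpy as np

# Sketch of the single shared observation array used when
# optimize_memory_usage=True (illustrative, not SB3 internals).
size = 4
observations = np.zeros(size)

def add(pos, obs, next_obs):
    # add() writes obs at pos and next_obs at (pos + 1) % size
    observations[pos] = obs
    observations[(pos + 1) % size] = next_obs

# Step t: episode truncated, so we store the real final observation (99.0)
add(0, obs=1.0, next_obs=99.0)
stored_next = observations[1]  # 99.0 -- looks correct so far

# Step t+1: the env has auto-reset; its first observation (5.0) is written
# at pos=1, overwriting the final observation we just stored.
add(1, obs=5.0, next_obs=6.0)

# What sample() would reconstruct as next_obs for index 0:
sampled_next = observations[(0 + 1) % size]
print(sampled_next)  # 5.0 -- the reset observation, not 99.0
```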
Expected Behavior
data.next_observations[i] should be the next observation that was passed to rb.add(), i.e. infos["final_observation"][idx] for truncated episodes.
Possible Solution
I'm guessing there's a way to make this work, but for now the easiest thing to do is set optimize_memory_usage to False.
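One direction that might preserve the memory savings (a hedged sketch in pure NumPy, not SB3's API) is to keep the shared observation array but record the real final observations of truncated transitions in a small side table keyed by buffer index, consulting it at sample time:

```python
import numpy as np

# Illustrative ring buffer: shared observation array plus a side table
# holding the true final observations of truncated transitions.
size = 4
observations = np.zeros(size)
final_obs = {}  # buffer index -> real final observation (truncations only)

def add(pos, obs, next_obs, truncated):
    observations[pos] = obs
    if truncated:
        final_obs[pos] = next_obs  # keep the real next obs on the side
    else:
        observations[(pos + 1) % size] = next_obs
        final_obs.pop(pos, None)   # slot reused by a normal transition

def sample_next(idx):
    # Prefer the side table; fall back to the shared array.
    return final_obs.get(idx, observations[(idx + 1) % size])

add(0, obs=1.0, next_obs=99.0, truncated=True)   # truncation: 99.0 is final
add(1, obs=5.0, next_obs=6.0, truncated=False)   # reset env's first step
print(sample_next(0))  # 99.0, not the reset observation
```

The side table only grows with the number of truncated transitions currently in the buffer, so the memory overhead stays small.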
Steps to Reproduce
Here's a minimal code example, where the important parts are directly cribbed from dqn_atari.py. Switching to optimize_memory_usage=False prevents the assertion error.
```python
import gymnasium as gym
import numpy as np
from stable_baselines3.common.buffers import ReplayBuffer


def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    return thunk


envs = gym.vector.SyncVectorEnv(
    [make_env("MountainCar-v0", i, i, False, "testing") for i in [0]]
)
obs, _ = envs.reset(seed=0)
rb = ReplayBuffer(
    1000,
    envs.single_observation_space,
    envs.single_action_space,
    "cpu",
    optimize_memory_usage=True,
    # optimize_memory_usage=False,
    handle_timeout_termination=False,
)
seen_obs_and_next = set()
for i in range(1000):
    actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
    next_obs, rewards, terminations, truncations, infos = envs.step(actions)

    # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
    real_next_obs = next_obs.copy()
    for idx, trunc in enumerate(truncations):
        if trunc:
            real_next_obs[idx] = infos["final_observation"][idx]
    rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

    for o, next_o in zip(obs, real_next_obs):  # because vectorized env
        seen_obs_and_next.add((tuple(o.tolist()), tuple(next_o.tolist())))
    obs = next_obs  # advance obs, as in dqn_atari.py

data = rb.sample(10000)
for i in range(10000):
    o = data.observations[i]
    no = data.next_observations[i]
    # Fails with optimize_memory_usage=True: sampled pairs include
    # (obs, reset_obs) combinations that were never added to the buffer.
    assert (tuple(o.tolist()), tuple(no.tolist())) in seen_obs_and_next
```
samlobel changed the title from "Another truncation bug" to "Truncation not handled correctly when optimize_memory_usage=True" on Apr 26, 2024.