make deterministic queue
vwxyzjn committed Aug 17, 2023
1 parent 6367b2c commit 8e20679
Showing 2 changed files with 40 additions and 38 deletions.
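The change replaces the single rollout queue shared by every actor thread with one rollout queue per actor thread, so the learner drains them in a fixed (device, thread) order instead of whatever order the actors happen to finish in. Below is a minimal, self-contained sketch of the two patterns; it is illustrative only and not code from this commit, and num_threads and produce are hypothetical stand-ins for the actor threads.

import queue
import random
import threading
import time

num_threads = 4  # hypothetical count, standing in for args.num_actor_threads

def produce(thread_id, out_queue):
    # Simulate an actor thread that takes a variable amount of time per rollout.
    time.sleep(random.random() * 0.01)
    out_queue.put(f"rollout from thread {thread_id}")

# Before: one shared queue. Items arrive in whichever order threads finish,
# so the order the learner sees can differ from run to run.
shared_queue = queue.Queue(maxsize=num_threads)
for i in range(num_threads):
    threading.Thread(target=produce, args=(i, shared_queue)).start()
nondeterministic_batch = [shared_queue.get() for _ in range(num_threads)]

# After: one queue per thread, drained in a fixed order. The learner may wait
# on a slow thread, but the batch is always assembled in the same order.
per_thread_queues = [queue.Queue(maxsize=1) for _ in range(num_threads)]
for i in range(num_threads):
    threading.Thread(target=produce, args=(i, per_thread_queues[i])).start()
deterministic_batch = [per_thread_queues[i].get() for i in range(num_threads)]

print(nondeterministic_batch)  # order varies across runs
print(deterministic_batch)     # always thread 0, 1, 2, 3

With the shared queue, nondeterministic_batch can come out in a different order on every run; with per-thread queues, deterministic_batch is always assembled thread by thread, at the cost of sometimes waiting on the slowest actor.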
cleanba/cleanba_impala.py: 20 additions & 19 deletions
@@ -644,31 +644,31 @@ def update_minibatch(agent_state, minibatch):
         devices=global_learner_decices,
     )

-    rollout_queue = queue.Queue(maxsize=args.num_actor_threads)
     params_queues = []
+    rollout_queues = []
     dummy_writer = SimpleNamespace()
     dummy_writer.add_scalar = lambda x, y, z: None

     unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
     for d_idx, d_id in enumerate(args.actor_device_ids):
         device_params = jax.device_put(unreplicated_params, local_devices[d_id])
         for thread_id in range(args.num_actor_threads):
-            params_queue = queue.Queue(maxsize=1)
-            params_queue.put(device_params)
+            params_queues.append(queue.Queue(maxsize=1))
+            rollout_queues.append(queue.Queue(maxsize=1))
+            params_queues[-1].put(device_params)
             threading.Thread(
                 target=rollout,
                 args=(
                     jax.device_put(key, local_devices[d_id]),
                     args,
-                    rollout_queue,
-                    params_queue,
+                    rollout_queues[-1],
+                    params_queues[-1],
                     writer if d_idx == 0 and thread_id == 0 else dummy_writer,
                     learner_devices,
                     d_idx * args.num_actor_threads + thread_id,
                     local_devices[d_id],
                 ),
             ).start()
-            params_queues.append(params_queue)

     rollout_queue_get_time = deque(maxlen=10)
     data_transfer_time = deque(maxlen=10)
@@ -677,16 +677,17 @@ def update_minibatch(agent_state, minibatch):
         learner_policy_version += 1
         rollout_queue_get_time_start = time.time()
         sharded_storages = []
-        for _ in range(args.num_actor_threads * len(args.actor_device_ids)):
-            (
-                global_step,
-                actor_policy_version,
-                update,
-                sharded_storage,
-                avg_params_queue_get_time,
-                device_thread_id,
-            ) = rollout_queue.get()
-            sharded_storages.append(sharded_storage)
+        for d_idx, d_id in enumerate(args.actor_device_ids):
+            for thread_id in range(args.num_actor_threads):
+                (
+                    global_step,
+                    actor_policy_version,
+                    update,
+                    sharded_storage,
+                    avg_params_queue_get_time,
+                    device_thread_id,
+                ) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
+                sharded_storages.append(sharded_storage)
         rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
         training_time_start = time.time()
         (agent_state, loss, pg_loss, v_loss, entropy_loss, learner_keys) = multi_device_update(
@@ -709,8 +710,8 @@ def update_minibatch(agent_state, minibatch):
             global_step,
         )
         writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
-        writer.add_scalar("stats/rollout_queue_size", rollout_queue.qsize(), global_step)
-        writer.add_scalar("stats/params_queue_size", params_queue.qsize(), global_step)
+        writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
+        writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
         print(
             global_step,
             f"actor_policy_version={actor_policy_version}, actor_update={update}, learner_policy_version={learner_policy_version}, training time: {time.time() - training_time_start}s",
@@ -722,7 +723,7 @@ def update_minibatch(agent_state, minibatch):
         writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
         writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
         writer.add_scalar("losses/loss", loss.item(), global_step)
-        if update >= args.num_updates:
+        if learner_policy_version >= args.num_updates:
             break

     if args.save_model and args.local_rank == 0:
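In the loop above, the queue serving a given actor thread sits at flat index d_idx * args.num_actor_threads + thread_id, the same expression the setup code passes to rollout, and the learner's nested loops drain the queues in exactly that order. A small sketch of the mapping, with hypothetical values standing in for args.actor_device_ids and args.num_actor_threads:

# Hypothetical values: 2 actor devices with 2 actor threads each.
actor_device_ids = [0, 1]
num_actor_threads = 2

flat_order = []
for d_idx, d_id in enumerate(actor_device_ids):
    for thread_id in range(num_actor_threads):
        flat_index = d_idx * num_actor_threads + thread_id
        flat_order.append((flat_index, d_id, thread_id))

# The learner always pulls rollouts in this fixed order, no matter which
# actor thread finishes first:
# [(0, 0, 0), (1, 0, 1), (2, 1, 0), (3, 1, 1)]  -> (flat_index, device_id, thread_id)
print(flat_order)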
cleanba/cleanba_ppo.py: 20 additions & 19 deletions
@@ -628,31 +628,31 @@ def update_minibatch(agent_state, minibatch):
         devices=global_learner_decices,
     )

-    rollout_queue = queue.Queue(maxsize=args.num_actor_threads)
     params_queues = []
+    rollout_queues = []
     dummy_writer = SimpleNamespace()
     dummy_writer.add_scalar = lambda x, y, z: None

     unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
     for d_idx, d_id in enumerate(args.actor_device_ids):
         device_params = jax.device_put(unreplicated_params, local_devices[d_id])
         for thread_id in range(args.num_actor_threads):
-            params_queue = queue.Queue(maxsize=1)
-            params_queue.put(device_params)
+            params_queues.append(queue.Queue(maxsize=1))
+            rollout_queues.append(queue.Queue(maxsize=1))
+            params_queues[-1].put(device_params)
             threading.Thread(
                 target=rollout,
                 args=(
                     jax.device_put(key, local_devices[d_id]),
                     args,
-                    rollout_queue,
-                    params_queue,
+                    rollout_queues[-1],
+                    params_queues[-1],
                     writer if d_idx == 0 and thread_id == 0 else dummy_writer,
                     learner_devices,
                     d_idx * args.num_actor_threads + thread_id,
                     local_devices[d_id],
                 ),
             ).start()
-            params_queues.append(params_queue)

     rollout_queue_get_time = deque(maxlen=10)
     data_transfer_time = deque(maxlen=10)
@@ -661,16 +661,17 @@ def update_minibatch(agent_state, minibatch):
         learner_policy_version += 1
         rollout_queue_get_time_start = time.time()
         sharded_storages = []
-        for _ in range(args.num_actor_threads * len(args.actor_device_ids)):
-            (
-                global_step,
-                actor_policy_version,
-                update,
-                sharded_storage,
-                avg_params_queue_get_time,
-                device_thread_id,
-            ) = rollout_queue.get()
-            sharded_storages.append(sharded_storage)
+        for d_idx, d_id in enumerate(args.actor_device_ids):
+            for thread_id in range(args.num_actor_threads):
+                (
+                    global_step,
+                    actor_policy_version,
+                    update,
+                    sharded_storage,
+                    avg_params_queue_get_time,
+                    device_thread_id,
+                ) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
+                sharded_storages.append(sharded_storage)
         rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
         training_time_start = time.time()
         (agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
@@ -693,8 +694,8 @@ def update_minibatch(agent_state, minibatch):
             global_step,
         )
         writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
-        writer.add_scalar("stats/rollout_queue_size", rollout_queue.qsize(), global_step)
-        writer.add_scalar("stats/params_queue_size", params_queue.qsize(), global_step)
+        writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
+        writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
         print(
             global_step,
             f"actor_policy_version={actor_policy_version}, actor_update={update}, learner_policy_version={learner_policy_version}, training time: {time.time() - training_time_start}s",
@@ -707,7 +708,7 @@ def update_minibatch(agent_state, minibatch):
         writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
         writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
         writer.add_scalar("losses/loss", loss.item(), global_step)
-        if update >= args.num_updates:
+        if learner_policy_version >= args.num_updates:
             break

     if args.save_model and args.local_rank == 0:
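Both files create the per-thread rollout queues with maxsize=1, which doubles as backpressure: an actor thread blocks on put() as soon as it tries to hand over a second rollout before the learner has taken the first, so no actor races ahead of the learner. A minimal illustration of that blocking behaviour, not taken from the commit:

import queue
import threading
import time

rollout_queue = queue.Queue(maxsize=1)  # same maxsize as the new per-thread queues

def actor():
    for step in range(3):
        rollout_queue.put(f"rollout {step}")  # blocks until the learner takes the previous item
        print(f"actor: queued rollout {step}")

threading.Thread(target=actor).start()

for _ in range(3):
    time.sleep(0.2)                # pretend the learner is busy training
    item = rollout_queue.get()     # lets the actor move on; it never queues more than one rollout ahead
    print(f"learner: consumed {item}")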
