-
Notifications
You must be signed in to change notification settings - Fork 48
/
async_drq_randomized.py
397 lines (323 loc) · 13.4 KB
/
async_drq_randomized.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
#!/usr/bin/env python3
import time
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
import tqdm
from absl import app, flags
from flax.training import checkpoints
import gym
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
from serl_launcher.agents.continuous.drq import DrQAgent
from serl_launcher.common.evaluation import evaluate
from serl_launcher.utils.timer_utils import Timer
from serl_launcher.wrappers.chunking import ChunkingWrapper
from serl_launcher.utils.train_utils import concat_batches
from agentlace.trainer import TrainerServer, TrainerClient
from agentlace.data.data_store import QueuedDataStore
from serl_launcher.utils.launcher import (
make_drq_agent,
make_trainer_config,
make_wandb_logger,
)
from serl_launcher.data.data_store import MemoryEfficientReplayBufferDataStore
from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper
from franka_env.envs.relative_env import RelativeFrame
from franka_env.envs.wrappers import (
GripperCloseEnv,
SpacemouseIntervention,
Quat2EulerWrapper,
)
import franka_env
FLAGS = flags.FLAGS

flags.DEFINE_string("env", "FrankaEnv-Vision-v0", "Name of environment.")
flags.DEFINE_string("agent", "drq", "Name of agent.")
flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.")
flags.DEFINE_integer("max_traj_length", 100, "Maximum length of trajectory.")
flags.DEFINE_integer("seed", 42, "Random seed.")
flags.DEFINE_bool("save_model", False, "Whether to save model.")
# Total learner batch size: `learner` draws half of it from online experience
# and half from demos; `main` requires it to be divisible by the local device
# count.
flags.DEFINE_integer("batch_size", 256, "Batch size.")
flags.DEFINE_integer("critic_actor_ratio", 4, "critic to actor update ratio.")
flags.DEFINE_integer("max_steps", 1000000, "Maximum number of training steps.")
flags.DEFINE_integer("replay_buffer_capacity", 200000, "Replay buffer capacity.")
flags.DEFINE_integer("random_steps", 300, "Sample random actions for this many steps.")
flags.DEFINE_integer("training_starts", 300, "Training starts after this step.")
flags.DEFINE_integer("steps_per_update", 30, "Number of steps per update the server.")
flags.DEFINE_integer("log_period", 10, "Logging period.")
flags.DEFINE_integer("eval_period", 2000, "Evaluation period.")

# flags selecting whether this process is the learner or an actor
flags.DEFINE_boolean("learner", False, "Whether this process runs the learner.")
flags.DEFINE_boolean("actor", False, "Whether this process runs the actor.")
flags.DEFINE_boolean("render", False, "Render the environment.")
flags.DEFINE_string("ip", "localhost", "IP address of the learner.")
# "small" is a 4 layer convnet, "resnet" and "mobilenet" are frozen with pretrained weights
flags.DEFINE_string("encoder_type", "resnet-pretrained", "Encoder type.")
flags.DEFINE_string("demo_path", None, "Path to the demo data.")
flags.DEFINE_integer("checkpoint_period", 0, "Period to save checkpoints.")
flags.DEFINE_string("checkpoint_path", None, "Path to save checkpoints.")
flags.DEFINE_integer(
    "eval_checkpoint_step", 0, "evaluate the policy from ckpt at this step"
)
flags.DEFINE_integer("eval_n_trajs", 5, "Number of trajectories for evaluation.")
flags.DEFINE_boolean(
    "debug", False, "Debug mode."
)  # debug mode will disable wandb logging

# Local JAX devices; the agent, RNG keys and training batches are replicated
# across them via `sharding.replicate()`.
devices = jax.local_devices()
num_devices = len(devices)
sharding = jax.sharding.PositionalSharding(devices)
def print_green(x):
    """Print ``x`` to stdout, prefixed by a space and colored bright green via ANSI escapes."""
    green, reset = "\033[92m", "\033[00m"
    print(f"{green} {x}{reset}")
##############################################################################
def _evaluate_checkpoint(agent: DrQAgent, env):
    """Restore a checkpoint, roll it out greedily and print success statistics.

    Loads ``agent.state`` from ``FLAGS.checkpoint_path`` at step
    ``FLAGS.eval_checkpoint_step`` and runs ``FLAGS.eval_n_trajs`` episodes with
    ``argmax=True`` action selection. An episode whose final reward is truthy
    counts as a success, and its wall-clock duration is recorded.
    """
    success_counter = 0
    time_list = []
    ckpt = checkpoints.restore_checkpoint(
        FLAGS.checkpoint_path,
        agent.state,
        step=FLAGS.eval_checkpoint_step,
    )
    agent = agent.replace(state=ckpt)

    for episode in range(FLAGS.eval_n_trajs):
        obs, _ = env.reset()
        done = truncated = False
        start_time = time.time()
        # Stop on termination OR truncation; ignoring `truncated` would keep
        # stepping an env that reports its episode end only through truncation.
        while not (done or truncated):
            actions = agent.sample_actions(
                observations=jax.device_put(obs),
                argmax=True,
            )
            actions = np.asarray(jax.device_get(actions))
            obs, reward, done, truncated, info = env.step(actions)

        if reward:
            dt = time.time() - start_time
            time_list.append(dt)
            print(dt)
        success_counter += reward
        print(reward)
        print(f"{success_counter}/{episode + 1}")

    if FLAGS.eval_n_trajs > 0:
        print(f"success rate: {success_counter / FLAGS.eval_n_trajs}")
    if time_list:
        print(f"average time: {np.mean(time_list)}")
    else:
        # np.mean([]) would print "nan" and emit a RuntimeWarning.
        print("average time: n/a (no successful episodes)")


def actor(agent: DrQAgent, data_store, env, sampling_rng):
    """Run the actor loop (selected with ``--actor``).

    If ``--eval_checkpoint_step`` is non-zero, only evaluates that checkpoint
    and returns. Otherwise connects to the learner at ``--ip`` through a
    ``TrainerClient`` and steps ``env`` for ``--max_steps`` steps. Each
    transition is inserted into ``data_store``. The client is synced every
    ``--steps_per_update`` steps, and parameters published by the learner are
    swapped into ``agent``.

    Args:
        agent: DrQ agent used to sample actions.
        data_store: store that transitions are inserted into; it is handed to
            the ``TrainerClient``.
        env: the wrapped environment built in ``main``.
        sampling_rng: JAX PRNG key, split once per policy action.
    """
    if FLAGS.eval_checkpoint_step:
        _evaluate_checkpoint(agent, env)
        return  # after done eval, return and exit

    client = TrainerClient(
        "actor_env",
        FLAGS.ip,
        make_trainer_config(),
        data_store,
        wait_for_server=True,
    )

    def update_params(params):
        """Replace the agent's params with the ones received from the learner."""
        nonlocal agent
        agent = agent.replace(state=agent.state.replace(params=params))

    client.recv_network_callback(update_params)

    obs, _ = env.reset()

    # training loop
    timer = Timer()
    for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True):
        timer.tick("total")

        with timer.context("sample_actions"):
            if step < FLAGS.random_steps:
                # Warm-up: uniform random actions from the action space.
                actions = env.action_space.sample()
            else:
                sampling_rng, key = jax.random.split(sampling_rng)
                actions = agent.sample_actions(
                    observations=jax.device_put(obs),
                    seed=key,
                    deterministic=False,
                )
                actions = np.asarray(jax.device_get(actions))

        # Step environment
        with timer.context("step_env"):
            next_obs, reward, done, truncated, info = env.step(actions)

            # Record the action actually executed: an intervention reported in
            # `info` overrides the policy's action.
            if "intervene_action" in info:
                actions = info.pop("intervene_action")

            reward = np.asarray(reward, dtype=np.float32)
            info = np.asarray(info)
            transition = dict(
                observations=obs,
                actions=actions,
                next_observations=next_obs,
                rewards=reward,
                masks=1.0 - done,
                dones=done,
            )
            data_store.insert(transition)

            obs = next_obs
            if done or truncated:
                stats = {"train": info}  # send stats to the learner to log
                client.request("send-stats", stats)
                obs, _ = env.reset()

        if step % FLAGS.steps_per_update == 0:
            client.update()

        timer.tock("total")

        if step % FLAGS.log_period == 0:
            stats = {"timer": timer.get_average_times()}
            client.request("send-stats", stats)
##############################################################################
def learner(rng, agent: DrQAgent, replay_buffer, demo_buffer):
    """Run the learner loop (selected with ``--learner``).

    Starts a ``TrainerServer`` that collects actor transitions into
    ``replay_buffer`` and logs actor stats to wandb. It then waits until
    ``--training_starts`` transitions have arrived and trains for
    ``--max_steps`` steps. Each training batch is half online data and half
    demo data. Updated params are published to the actor every
    ``--steps_per_update`` steps.

    Args:
        rng: PRNG key; not used by this loop.
        agent: DrQ agent to train.
        replay_buffer: online replay buffer, registered as data store "actor_env".
        demo_buffer: buffer pre-filled with demonstration transitions.

    Raises:
        ValueError: if ``--checkpoint_period`` is non-zero but
            ``--checkpoint_path`` is unset.
    """
    # Validate before starting the server. An `assert` inside the loop is
    # stripped under `python -O`, and it would only fire after the buffer filled.
    if FLAGS.checkpoint_period and FLAGS.checkpoint_path is None:
        raise ValueError("--checkpoint_path must be set when --checkpoint_period is non-zero")

    # set up wandb and logging
    wandb_logger = make_wandb_logger(
        project="serl_dev",
        description=FLAGS.exp_name or FLAGS.env,
        debug=FLAGS.debug,
    )

    # Number of completed training steps. `stats_callback` also reads it to tag
    # actor stats; presumably that happens on the server thread (threaded=True).
    update_steps = 0

    def stats_callback(type: str, payload: dict) -> dict:
        """Callback for when server receives stats request."""
        assert type == "send-stats", f"Invalid request type: {type}"
        if wandb_logger is not None:
            wandb_logger.log(payload, step=update_steps)
        return {}  # not expecting a response

    server = TrainerServer(make_trainer_config(), request_callback=stats_callback)
    server.register_data_store("actor_env", replay_buffer)
    server.start(threaded=True)

    # Block until the actor has sent enough transitions to start training.
    pbar = tqdm.tqdm(
        total=FLAGS.training_starts,
        initial=len(replay_buffer),
        desc="Filling up replay buffer",
        position=0,
        leave=True,
    )
    while len(replay_buffer) < FLAGS.training_starts:
        pbar.update(len(replay_buffer) - pbar.n)  # Update progress bar
        time.sleep(1)
    pbar.update(len(replay_buffer) - pbar.n)  # Update progress bar
    pbar.close()

    # send the initial network to the actor
    server.publish_network(agent.state.params)
    print_green("sent initial network to actor")

    # 50/50 sampling from RLPD, half from demo and half from online experience
    replay_iterator = replay_buffer.get_iterator(
        sample_args={
            "batch_size": FLAGS.batch_size // 2,
            "pack_obs_and_next_obs": True,
        },
        device=sharding.replicate(),
    )
    demo_iterator = demo_buffer.get_iterator(
        sample_args={
            "batch_size": FLAGS.batch_size // 2,
            "pack_obs_and_next_obs": True,
        },
        device=sharding.replicate(),
    )

    def sample_mixed_batch():
        """Draw one online batch and one demo batch and concatenate them along axis 0."""
        batch = next(replay_iterator)
        demo_batch = next(demo_iterator)
        return concat_batches(batch, demo_batch, axis=0)

    timer = Timer()
    for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True, desc="learner"):
        # Run n-1 critic-only updates, then 1 critic + actor update.
        # This makes training on GPU faster by reducing the large batch
        # transfer time from CPU to GPU.
        for _ in range(FLAGS.critic_actor_ratio - 1):
            with timer.context("sample_replay_buffer"):
                batch = sample_mixed_batch()
            with timer.context("train_critics"):
                agent, _ = agent.update_critics(batch)

        with timer.context("train"):
            batch = sample_mixed_batch()
            agent, update_info = agent.update_high_utd(batch, utd_ratio=1)

        # publish the updated network
        if step > 0 and step % FLAGS.steps_per_update == 0:
            agent = jax.block_until_ready(agent)
            server.publish_network(agent.state.params)

        if update_steps % FLAGS.log_period == 0 and wandb_logger:
            wandb_logger.log(update_info, step=update_steps)
            wandb_logger.log({"timer": timer.get_average_times()}, step=update_steps)

        if FLAGS.checkpoint_period and update_steps % FLAGS.checkpoint_period == 0:
            checkpoints.save_checkpoint(
                FLAGS.checkpoint_path, agent.state, step=update_steps, keep=100
            )

        update_steps += 1
##############################################################################
def main(_):
    """Build the env and DrQ agent, then run the learner or the actor loop.

    Raises:
        ValueError: if ``--batch_size`` is not divisible by the number of local
            devices, or ``--learner`` is set without ``--demo_path``.
        NotImplementedError: if neither ``--learner`` nor ``--actor`` is set.
    """
    if FLAGS.batch_size % num_devices != 0:
        raise ValueError(
            f"--batch_size ({FLAGS.batch_size}) must be divisible by the "
            f"number of local devices ({num_devices})"
        )
    # Fail before building the env/agent rather than on open(None) later.
    if FLAGS.learner and FLAGS.demo_path is None:
        raise ValueError("--demo_path is required when running with --learner")

    # seed
    rng = jax.random.PRNGKey(FLAGS.seed)

    # create env and load dataset
    env = gym.make(
        FLAGS.env,
        fake_env=FLAGS.learner,
        save_video=FLAGS.eval_checkpoint_step,
    )
    env = GripperCloseEnv(env)
    if FLAGS.actor:
        # SpaceMouse interventions are only wired in on the actor process.
        env = SpacemouseIntervention(env)
    env = RelativeFrame(env)
    env = Quat2EulerWrapper(env)
    env = SERLObsWrapper(env)
    env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None)
    env = RecordEpisodeStatistics(env)

    image_keys = [key for key in env.observation_space.keys() if key != "state"]

    rng, sampling_rng = jax.random.split(rng)
    agent: DrQAgent = make_drq_agent(
        seed=FLAGS.seed,
        sample_obs=env.observation_space.sample(),
        sample_action=env.action_space.sample(),
        image_keys=image_keys,
        encoder_type=FLAGS.encoder_type,
    )

    # replicate agent across devices
    # need the jnp.array to avoid a bug where device_put doesn't recognize primitives
    # (jax.tree_util.tree_map: the jax.tree_map alias is deprecated in newer JAX)
    agent: DrQAgent = jax.device_put(
        jax.tree_util.tree_map(jnp.array, agent), sharding.replicate()
    )

    if FLAGS.learner:
        sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate())
        replay_buffer = MemoryEfficientReplayBufferDataStore(
            env.observation_space,
            env.action_space,
            capacity=FLAGS.replay_buffer_capacity,
            image_keys=image_keys,
        )
        demo_buffer = MemoryEfficientReplayBufferDataStore(
            env.observation_space,
            env.action_space,
            capacity=10000,
            image_keys=image_keys,
        )

        import pickle as pkl

        # NOTE(review): pickle.load can execute arbitrary code; only point
        # --demo_path at demo files you produced or trust.
        with open(FLAGS.demo_path, "rb") as f:
            trajs = pkl.load(f)
        for traj in trajs:
            demo_buffer.insert(traj)
        print(f"demo buffer size: {len(demo_buffer)}")

        # learner loop
        print_green("starting learner loop")
        learner(
            sampling_rng,
            agent,
            replay_buffer,
            demo_buffer=demo_buffer,
        )

    elif FLAGS.actor:
        sampling_rng = jax.device_put(sampling_rng, sharding.replicate())
        data_store = QueuedDataStore(2000)  # the queue size on the actor

        # actor loop
        print_green("starting actor loop")
        actor(agent, data_store, env, sampling_rng)

    else:
        raise NotImplementedError("Must be either a learner or an actor")


if __name__ == "__main__":
    app.run(main)