-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmeta_train.py
38 lines (27 loc) · 979 Bytes
/
meta_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import gymnasium as gym
from wrappers import FrankaObservationWrapper
from meta_agent import MetaAgent
if __name__ == '__main__':
    # Script entry point: create the FrankaKitchen environment, wrap its
    # observations, train a MetaAgent on the listed tasks, and save the
    # resulting models.

    # --- Environment configuration ---
    env_name = "FrankaKitchen-v1"
    max_episode_steps = 1500
    tasks = ['top burner', 'microwave', 'hinge cabinet']

    # --- Hyperparameters ---
    # NOTE(review): apart from `episodes`, none of the values in this group
    # are passed to MetaAgent anywhere in this script, so changing them here
    # has no effect on training. Confirm whether they should be forwarded to
    # MetaAgent or removed.
    replay_buffer_size = 1000000
    gamma = 0.99
    tau = 0.005
    alpha = 0.1
    target_update_interval = 1
    updates_per_step = 4
    hidden_size = 512
    learning_rate = 0.00005
    batch_size = 64
    episodes = 3000

    env = gym.make(
        env_name,
        max_episode_steps=max_episode_steps,
        tasks_to_complete=tasks,
    )
    # Everything after environment creation runs under try/finally so the
    # environment is closed even when training or saving raises.
    try:
        env = FrankaObservationWrapper(env)

        obs, info = env.reset()
        # NOTE(review): obs_size is computed but not used below.
        obs_size = obs.shape[0]

        meta = MetaAgent(env, tasks, max_episode_steps=max_episode_steps)
        meta.initialize_memory(
            augment_data=True,
            augment_rewards=True,
            augment_noise_ratio=0.1,
        )
        meta.initialize_agents()
        meta.train(episodes=episodes, summary_writer_name='meta_agent')
        meta.save_models()
    finally:
        env.close()