run_platform_qpamdp.py (forked from cycraig/MP-DQN)
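"""Train and evaluate a Q-PAMDP agent on the Platform-v0 domain.

A SARSA(lambda) agent selects the discrete action while the QPAMDP agent
alternates between updating the discrete-action values and the continuous
action-parameter policy. Per-episode returns are saved to disk, metrics can
optionally be logged to Weights & Biases, and a desktop notification is sent
when each run finishes.
"""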
import logging
import click
import time
import numpy as np
import os
import gym
import gym_platform
from agents.qpamdp import QPAMDPAgent
from agents.sarsa_lambda import SarsaLambdaAgent
from common import get_calling_function_parameters
from common.wrappers import QPAMDPScaledParameterisedActionWrapper
from gym.wrappers import Monitor
from common.wrappers import ScaledStateWrapper
import wandb # Saving metrics
# Notifications
from pynotifier import NotificationClient, Notification
from pynotifier.backends import platform


def evaluate(env, agent, episodes=1000):
    """Run the agent for the given number of episodes and return the per-episode returns."""
    returns = []
    timesteps = []
    for _ in range(episodes):
        state, _ = env.reset()
        terminal = False
        t = 0
        total_reward = 0.
        while not terminal:
            t += 1
            state = np.array(state, dtype=np.float32, copy=False)
            action = agent.act(state)
            (state, _), reward, terminal, _ = env.step(action)
            total_reward += reward
        timesteps.append(t)
        returns.append(total_reward)
    return np.array(returns)


@click.command()
@click.option('--seed', default=7, help='Random seed.', type=int)
@click.option('--random-seed', default=True, help='Automatically set a random seed.', type=bool)
@click.option('--episodes', default=20000, help='Number of episodes.', type=int)
@click.option('--evaluation-episodes', default=100, help='Episodes over which to evaluate after training.', type=int)
@click.option('--parameter-rollouts', default=25, help='Number of rollouts per parameter update.', type=int)  # default was 50; 25 performed best
@click.option('--scale', default=False, help='Scale inputs and actions.', type=bool)
@click.option('--initialise-params', default=True, help='Initialise action parameters.', type=bool)
@click.option('--save-dir', default="results/platform", help='Output directory.', type=str)
@click.option('--title', default="QPAMDP", help='Prefix of output files.', type=str)
@click.option('--use-wandb', default=False, help='Use Weights & Biases for tracking metrics.', type=bool)
@click.option('--runs', default=1, help='How many times to run this configuration.', type=int)
@click.option('--initial-action-learning-episodes', default=10000, help='Episodes of initial discrete-action learning before alternating updates.', type=int)  # 10,000 may be higher than necessary
def run(seed, random_seed, episodes, evaluation_episodes, parameter_rollouts, scale, initialise_params, save_dir, title, use_wandb, runs, initial_action_learning_episodes):
    for run_index in range(runs):
        if random_seed:
            seed = np.random.randint(low=1, high=1000000)

        if use_wandb:
            run_config = get_calling_function_parameters()
            run_config["algorithm"] = "QPAMDP"
            run_config["environment"] = "Platform"
            wandb.init(
                project="bester-scripts",
                tags=["QPAMDP", "Platform"],
                config=run_config
            )

        logger = logging.getLogger()
        logger.setLevel(logging.INFO)

        alpha_param = 1.0
        variances = [0.1, 0.1, 0.01]
        initial_params = [3., 10., 400.]

        env = gym.make('Platform-v0')
        dir = os.path.join(save_dir, title)
        if scale:
            env = ScaledStateWrapper(env)
            variances = [0.0001, 0.0001, 0.0001]
            # Rescale the initial action parameters to the [-1, 1] range used by the wrapper.
            for a in range(env.action_space.spaces[0].n):
                initial_params[a] = 2. * (initial_params[a] - env.action_space.spaces[1].spaces[a].low) / (
                        env.action_space.spaces[1].spaces[a].high - env.action_space.spaces[1].spaces[a].low) - 1.
            env = QPAMDPScaledParameterisedActionWrapper(env)
            alpha_param = 0.1

        env = Monitor(env, directory=os.path.join(dir, str(seed)), video_callable=False, write_upon_reset=False, force=True)
        env.seed(seed)
        np.random.seed(seed)
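
        # The discrete-action policy is learned with SARSA(lambda) (order-6 basis features,
        # softmax action selection); the QPAMDP agent alternates these discrete-action updates
        # with rollout-based updates of the continuous action-parameter policy.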
        act_obs_index = [0, 1, 2, 3]
        param_obs_index = None
        discrete_agent = SarsaLambdaAgent(env.observation_space.spaces[0], env.action_space.spaces[0], alpha=1.0,
                                          gamma=0.999, temperature=1.0, cooling=0.995, lmbda=0.5, order=6,
                                          scale_alpha=True, use_softmax=True, seed=seed,
                                          observation_index=act_obs_index, gamma_step_adjust=True)
        agent = QPAMDPAgent(env.observation_space.spaces[0], env.action_space, alpha=alpha_param,
                            initial_action_learning_episodes=initial_action_learning_episodes, seed=seed,
                            action_obs_index=act_obs_index, parameter_obs_index=param_obs_index,
                            action_relearn_episodes=1000, variances=variances,
                            parameter_updates=180, parameter_rollouts=parameter_rollouts, norm_grad=False,
                            discrete_agent=discrete_agent, print_freq=100, use_wandb=use_wandb)
        agent.discrete_agent.gamma_step_adjust = True

        if initialise_params:
            for a in range(env.action_space.spaces[0].n):
                agent.parameter_weights[a][0, 0] = initial_params[a]

        max_steps = 250  # maximum steps per episode (201?)
        start_time = time.time()
        agent.learn(env, episodes, max_steps)
        end_time = time.time()
        print("Training took %.2f seconds" % (end_time - start_time))
        env.close()

        returns = env.get_episode_rewards()
        print("Ave. return =", sum(returns) / len(returns))
        print("Ave. last 100 episode return =", sum(returns[-100:]) / 100.)
        np.save(os.path.join(dir, title + "{}".format(str(seed))), returns)
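
        # Evaluation switches exploration off: zero parameter variance and zero
        # epsilon/softmax temperature for the discrete agent.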
        if evaluation_episodes > 0:
            print("Evaluating agent over {} episodes".format(evaluation_episodes))
            agent.variances = 0
            agent.discrete_agent.epsilon = 0.
            agent.discrete_agent.temperature = 0.
            evaluation_returns = evaluate(env, agent, evaluation_episodes)
            print("Ave. evaluation return =", sum(evaluation_returns) / len(evaluation_returns))
            # Fraction of evaluation episodes whose return equals 50.
            print("Ave. evaluation prob. =", sum(evaluation_returns == 50.) / len(evaluation_returns))
            np.save(os.path.join(dir, title + "{}e".format(str(seed))), evaluation_returns)

        if use_wandb:
            wandb.finish()

        # Send a desktop notification when this run finishes.
        c = NotificationClient()
        c.register_backend(platform.Backend())
        notification = Notification(
            title='Run {} complete ({}).'.format(run_index, os.path.basename(__file__)),
            duration=5,
        )
        c.notify_all(notification)


if __name__ == '__main__':
    run()
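
# Example invocation (illustrative values; the flags correspond to the click options above):
#   python run_platform_qpamdp.py --episodes 20000 --evaluation-episodes 100 \
#       --scale True --runs 3 --save-dir results/platform --title QPAMDP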