-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
74 lines (52 loc) · 1.69 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import numpy as np
import tensorflow as tf
import time
def make_session(num_cpu=4):
    """Create a TensorFlow session with a bounded inter-op thread pool.

    Args:
        num_cpu: Value passed to ``inter_op_parallelism_threads``.

    Returns:
        A new ``tf.Session`` whose config sets the inter-op thread count to
        ``num_cpu`` and enables ``allow_growth`` in the GPU options.
        Intra-op parallelism is not configured here.
    """
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=num_cpu,
        gpu_options=tf.GPUOptions(allow_growth=True),
    )
    return tf.Session(config=session_config)
def initialize_all_vars():
    """Run the global-variables initializer in the default TensorFlow session."""
    # NOTE(review): assumes a default session is active when this is called;
    # confirm callers always invoke it inside a session context.
    session = tf.get_default_session()
    session.run(tf.global_variables_initializer())
def set_seed(seed):
    """Seed TensorFlow and NumPy from a single base seed.

    TensorFlow receives ``seed + 1`` and NumPy's global generator receives
    ``seed + 2``, so the two libraries are seeded with different values.

    Args:
        seed: Base seed; must support ``+`` with an int.
    """
    tf_seed, np_seed = seed + 1, seed + 2
    tf.set_random_seed(tf_seed)
    np.random.seed(np_seed)
def eval_policy_cartpole(env, alg, ep_num=10, gamma=None, prt=False, save_data=False, POMDP=False):
    """Roll out ``alg`` in ``env`` for ``ep_num`` episodes and report returns.

    The mean undiscounted return over all episodes is always printed once
    the rollouts finish.

    Args:
        env: Environment exposing ``reset()`` returning an observation and
            ``step(act)`` returning ``(obs, reward, done, info)``.
        alg: Policy exposing ``sample_action``; it is called with a
            one-element list holding the current observation and the result
            is passed through ``np.squeeze``.
        ep_num: Number of episodes to run.
        gamma: Per-step discount factor. ``None`` (the default) means no
            discounting, i.e. it is treated as ``1.0``.
        prt: When true, print the previous episode's discounted return at
            the start of every episode after the first.
        save_data: Selects the return shape (see Returns).
        POMDP: Unused; kept so existing callers keep working.

    Returns:
        If ``save_data`` is true:
        ``(mean_discounted_return, observations, actions)`` where
        ``observations`` is ``np.array`` of every visited observation and
        ``actions`` is reshaped to a single column.
        Otherwise: ``(mean_discounted_return, per_episode_returns)`` where
        the second item is a plain list.

    Raises:
        AssertionError: If ``alg`` has no ``sample_action`` attribute.
    """
    assert hasattr(alg, 'sample_action')

    # The default gamma=None previously crashed on `factor *= gamma`;
    # interpret it as "no discounting" instead.
    discount = 1.0 if gamma is None else gamma

    obs_list = []
    act_list = []
    ep_rews = []
    undiscounted = []

    for i in range(ep_num):
        # Reports the return of the episode that just finished, so the final
        # episode's return is never printed by this branch.
        if prt and i > 0:
            print('Traj {} {}'.format(i, ep_rews[-1]))

        obs = env.reset()
        done = False
        factor = 1.0
        ep_rew = 0.0
        ep_undiscounted = 0.0
        while not done:
            act = np.squeeze(alg.sample_action([obs]))
            # Record the observation the action was chosen from, not the
            # one produced by the step.
            obs_list.append(obs)
            act_list.append(act)
            obs, rew, done, _ = env.step(act)
            ep_undiscounted += rew
            ep_rew += rew * factor
            factor *= discount
        ep_rews.append(ep_rew)
        undiscounted.append(ep_undiscounted)

    print(np.mean(undiscounted))

    if save_data:
        return np.mean(ep_rews), np.array(obs_list), np.array(act_list).reshape([-1, 1])
    return np.mean(ep_rews), ep_rews
def get_percentile(data):
    """Print and return the 5th, 15th, ..., 95th percentiles of ``data``.

    Args:
        data: Array-like accepted by ``np.percentile``.

    Returns:
        A list of ten percentile values, in increasing percentile order.
        The same list is printed to stdout.
    """
    # Midpoints of the ten deciles: 5, 15, ..., 95.
    ptr = [np.percentile(data, q) for q in range(5, 100, 10)]
    print(ptr)
    return ptr