-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
81 lines (57 loc) · 2.32 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import numpy as np
def train(env, agent, n_episodes=1000, eps_start=1.0, eps_end=0.01, eps_decay=0.995):
    """Train an agent on a Unity ML-Agents environment with frame-stacked states.

    Each state fed to the agent is a stack of the 4 most recent observation
    frames, flattened column-major ('F') so frames stay contiguous.

    Parameters
    ----------
    env : UnityEnvironment
        Environment exposing ``brain_names``, ``reset(train_mode=...)`` and
        ``step(action)`` returning per-brain info with ``vector_observations``,
        ``rewards`` and ``local_done``.
    agent : object
        Agent exposing ``act(state, eps)`` and
        ``step(state, action, reward, next_state, done)``.
    n_episodes : int
        Maximum number of training episodes.
    eps_start, eps_end, eps_decay : float
        Epsilon-greedy schedule: start value, floor, and per-episode
        multiplicative decay.

    Returns
    -------
    list of float
        Total (undiscounted) reward of each episode.
    """
    solved = False
    # single-brain environment: use the first (only) brain
    brain_name = env.brain_names[0]
    # per-episode total rewards
    scores = []
    eps = eps_start
    for i_episode in range(1, n_episodes + 1):
        env_info = env.reset(train_mode=True)[brain_name]
        # first observation frame
        frame_0 = env_info.vector_observations[0]
        # state is a (frame_size, 4) stack of the last 4 frames
        state = np.zeros(shape=(len(frame_0), 4))
        state[:, 0] = frame_0
        # take action 3 (move right) 3 times to fill the remaining frame slots
        # NOTE(review): rewards/done from these warm-up steps are discarded —
        # assumes the episode cannot terminate this early; confirm for the env
        for j in range(1, 4):
            next_env_info = env.step(3)[brain_name]
            next_frame = next_env_info.vector_observations[0]
            state[:, j] = next_frame
        score = 0
        while True:
            # epsilon-greedy action on the flattened frame stack
            action = agent.act(state.flatten('F'), eps)
            next_env_info = env.step(int(action))[brain_name]
            next_frame = next_env_info.vector_observations[0]
            # slide the frame window: drop oldest frame, append the new one
            next_state = np.c_[state[:, 1:4], next_frame.reshape((-1, 1))]
            reward = next_env_info.rewards[0]
            done = next_env_info.local_done[0]
            # store experience / learn
            agent.step(state.flatten('F'), action, reward, next_state.flatten('F'), done)
            score += reward
            if done:
                break
            state = next_state
        # decay epsilon once per episode, never below eps_end
        eps = max(eps_end, eps_decay * eps)
        scores.append(score)
        # status every 10 episodes; fixed off-by-one (episodes i-9..i were run)
        if i_episode % 10 == 0:
            print("Episode", i_episode - 9, "to", i_episode, "scores: ",
                  " ".join(["%02d" % s for s in scores[-10:]]))
        # solved = average of the LAST 100 episodes exceeds 13;
        # fixed: require at least 100 episodes before declaring solved
        if not solved and len(scores) >= 100 and np.mean(scores[-100:]) > 13:
            solved = True
            print("Solved on episode: " + str(i_episode))
    return scores