-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathexample_discrete.html
113 lines (112 loc) · 3.98 KB
/
example_discrete.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>ppo-tfjs (discrete)</title>
</head>
<body>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
<script src="ppo.js"></script>
<script>
// Visualization surface: a 500×500 canvas appended to the page,
// on which episodes of the toy environment are drawn.
const canvas = document.createElement('canvas')
Object.assign(canvas, { width: 500, height: 500 })
document.body.appendChild(canvas)
const ctx = canvas.getContext('2d')
// Draw every Nth episode; 1 = visualize every episode (raise to speed up training).
const updateFrequency = 1
class Env {
    constructor() {
        // Four discrete moves: 0 = -x, 1 = +x, 2 = -y, 3 = +y.
        this.actionSpace = {
            'class': 'Discrete',
            'n': 4,
        }
        // Observation is [agentX, agentY, goalX, goalY].
        this.observationSpace = {
            'class': 'Box',
            'shape': [4],
            'dtype': 'float32'
        }
        this.resets = 0
    }
    /**
     * Advance the environment by one timestep.
     * @param {number|number[]} action - Discrete action index; a one-element array is unwrapped.
     * @returns {Promise<[number[], number, boolean]>} [observation, reward, done].
     */
    async step(action) {
        const chosen = Array.isArray(action) ? action[0] : action
        const previous = this.agent.slice(0)
        // Step size of 0.02 along the axis selected by the action;
        // any out-of-range action leaves the agent in place.
        if (chosen === 0) {
            this.agent[0] -= 0.02
        } else if (chosen === 1) {
            this.agent[0] += 0.02
        } else if (chosen === 2) {
            this.agent[1] -= 0.02
        } else if (chosen === 3) {
            this.agent[1] += 0.02
        }
        this.i += 1
        // Reward is the negated Euclidean distance to the goal.
        const dx = this.agent[0] - this.goal[0]
        const dy = this.agent[1] - this.goal[1]
        const reward = -Math.sqrt(dx * dx + dy * dy)
        // Episode ends after 64 steps or once the agent is within 0.01 of the goal.
        const done = this.i > 64 || reward > -0.01
        if (reward > -0.01) {
            console.log('Goal reached:', reward)
        }
        // Visualize the move as a short line segment (only on sampled episodes).
        if (this.resets % updateFrequency === 0) {
            ctx.fillStyle = 'blue'
            ctx.fillRect(this.agent[0] * 500, this.agent[1] * 500, 2, 2)
            ctx.beginPath()
            ctx.moveTo(previous[0] * 500, previous[1] * 500)
            ctx.lineTo(this.agent[0] * 500, this.agent[1] * 500)
            ctx.stroke()
            // Brief pause so the browser can paint between steps.
            await new Promise(resolve => setTimeout(resolve, 10))
        }
        const observation = [this.agent[0], this.agent[1], this.goal[0], this.goal[1]]
        return [observation, reward, done]
    }
    /**
     * Start a new episode with random agent and goal positions in [0, 1).
     * @returns {number[]} Initial observation [agentX, agentY, goalX, goalY].
     */
    reset() {
        this.agent = [Math.random(), Math.random()]
        this.goal = [Math.random(), Math.random()]
        this.i = 0
        this.resets += 1
        // Redraw the board on sampled episodes: goal in red, agent in blue.
        if (this.resets % updateFrequency === 0) {
            ctx.clearRect(0, 0, 500, 500)
            ctx.fillStyle = 'red'
            ctx.fillRect(this.goal[0] * 500, this.goal[1] * 500, 10, 10)
            ctx.fillStyle = 'blue'
            ctx.fillRect(this.agent[0] * 500, this.agent[1] * 500, 10, 10)
        }
        return [this.agent[0], this.agent[1], this.goal[0], this.goal[1]]
    }
}
// Build the toy environment and a PPO trainer with 1024-step rollouts
// and 50 optimization epochs per update, then start training.
const env = new Env()
const ppo = new PPO(env, {'nSteps': 1024, 'nEpochs': 50})
console.log(ppo.env)
;(async () => {
    await ppo.learn({
        'totalTimesteps': 100000,
        'callback': {
            // Report the live tfjs tensor count after each rollout so
            // tensor leaks are visible during training.
            'onRolloutEnd': async (_) => {
                console.log('N tensors', tf.memory().numTensors)
            }
        }
    })
})().catch((err) => {
    // Fix: the original floating async IIFE had no rejection handler, so any
    // failure inside ppo.learn surfaced only as an unhandled promise rejection.
    console.error('Training failed:', err)
})
</script>
</body>
</html>