-
Notifications
You must be signed in to change notification settings - Fork 0
/
play.py
88 lines (71 loc) · 2.32 KB
/
play.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python
import time
from utils import resize_image, XboxController, get_frame
from termcolor import cprint
import serial
import sys
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
# Select the model architecture from the second command-line argument:
# "1" -> train3.create_model_1, "2" -> train3.create_model_2.
# Any other (or missing) value exits immediately with a usage message instead
# of leaving `create_model` undefined until Actor.__init__ tries to call it.
_model_variant = sys.argv[2] if len(sys.argv) > 2 else None
if _model_variant == "1":
    from train3 import create_model_1 as create_model
elif _model_variant == "2":
    from train3 import create_model_2 as create_model
else:
    sys.exit("usage: %s <serial-port> <1|2> <weights-file>" % sys.argv[0])
# Play
class Actor(object):
    """Turn camera frames into steering commands using a trained model.

    The network comes from ``create_model`` (imported from ``train3``, with the
    variant chosen by ``sys.argv[2]``). Its trained weights are loaded from the
    path passed to the constructor.
    """

    def __init__(self, model, real_controller=None):
        """Build the network and load the trained weights.

        Args:
            model: path to the weights file, passed to ``load_weights``.
            real_controller: optional manual-override source. It must have a
                ``read()`` method that returns a mutable joystick sequence
                (presumably ``utils.XboxController`` -- TODO confirm). When it
                is given, ``get_action`` uses it instead of the model. The
                default ``None`` means the model always drives.
        """
        # Load in model from train.py and load in the trained weights
        self.model = create_model(keep_prob=1)  # keep_prob=1 -> no dropout at inference
        self.model.load_weights(model)
        self.real_controller = real_controller

    def get_action(self, obs):
        """Compute the 5-element command list for one frame.

        Args:
            obs: camera frame. It is passed through ``resize_image`` before
                prediction (its expected shape/dtype is defined in ``utils``,
                which is not visible here).

        Returns:
            ``[steering, 0, 1, 0, 0]``, where ``steering`` is
            ``int(joystick[0] * 70)``. The remaining four entries are constants.
        """
        # Manual override is active only when a real controller was supplied.
        manual_override = self.real_controller is not None
        if manual_override:
            joystick = self.real_controller.read()
            joystick[1] *= -1  # flip y (this is in the config when it runs normally)
        else:
            vec = resize_image(obs)
            vec = np.expand_dims(vec, axis=0)  # add a batch axis of size 1 for predict()
            joystick = self.model.predict(vec, batch_size=1)[0]

        # Only element 0 is derived from the joystick; elements 1-4 are fixed.
        output = [
            int(joystick[0] * 70),
            0,
            1,
            0,
            0,
        ]

        if manual_override:
            cprint("Manual: " + str(output), 'yellow')
        else:
            cprint("AI: " + str(joystick[0]), 'green')
        return output
def steering_to_byte(steering):
    """Map a steering value to the unsigned byte sent over the serial link.

    ``steering`` is element 0 of ``Actor.get_action``'s output. It is offset
    by 128, so 0 maps to 128. The result is clamped to 0..255 because
    ``bytes()`` raises ``ValueError`` for values outside that range, which
    would otherwise kill the control loop on an out-of-range prediction.
    NOTE(review): 255 is also the value written as the first byte of every
    command (b"\\xFF"). Confirm whether the receiver can tolerate a steering
    byte of 255.
    """
    return min(255, max(0, int(steering) + 128))


if __name__ == '__main__':
    # CLI usage, as read by this script:
    #   <script> <serial-port> <model-variant 1|2> <weights-file>
    ser = serial.Serial(sys.argv[1], 115200)
    cap = cv2.VideoCapture(0)
    try:
        # Check whether the user-selected camera opened successfully.
        if not cap.isOpened():
            print("Could not open video device")
            sys.exit(1)
        ret, frame = cap.read()  # initial read; the result is not used below
        print('env ready!')
        actor = Actor(sys.argv[3])
        print('actor ready!')
        print('beginning episode loop')
        total_reward = 0
        end_episode = False
        control = True
        # The viewer figure is created here, but nothing in this file draws to it.
        plt.ion()
        plt.figure('viewer', figsize=(16, 6))
        while not end_episode:
            frame = get_frame(cap)
            # Reverse the channel axis (presumably BGR -> RGB; confirm what
            # get_frame returns).
            frame = frame[:, :, ::-1]
            action = actor.get_action(frame)
            if control:
                # One 3-byte command: 0xFF, a fixed 0x80 (128), then steering.
                # NOTE(review): the second byte used to sit behind a branch on
                # the sub-second phase of time.time(), but both arms sent 0x80.
                ser.write(bytes([0xFF, 0b10000000, steering_to_byte(action[0])]))
    finally:
        # Release the camera and serial port on any exit, including Ctrl-C.
        cap.release()
        ser.close()