-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSaveLoadDisplay.py
115 lines (91 loc) · 3.42 KB
/
SaveLoadDisplay.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import pickle
import main
import neat
def _drop_rocket(index, rockets, nets, gen):
    """Remove the rocket at *index* together with its network and genome.

    ``rockets``, ``nets`` and ``gen`` are parallel lists, so they must always
    be popped together to stay aligned.
    """
    rockets.pop(index)
    nets.pop(index)
    gen.pop(index)


def _steer_rocket(rocket, net, asteroid):
    """Ask *net* for an action and apply it to *rocket*.

    Class 0 calls ``rocket.flyup()``, class 1 ``rocket.flydown()`` and
    class 2 ``rocket.stay()``.
    """
    # NOTE(review): the third input compares rocket.y with asteroid.xcoord
    # (mixed axes). It is kept unchanged on purpose: the loaded network was
    # presumably trained on exactly these inputs in main.py -- confirm there
    # before "fixing" it here.
    net_output = net.activate((
        rocket.y,
        abs(rocket.y - asteroid.ycoord),
        abs(rocket.y - asteroid.xcoord),
    ))
    t_softmax_result = main.t_softmax(net_output)
    # Pick the first class whose score equals the maximum score.
    class_output = main.np.argmax(((t_softmax_result / main.torch.max(t_softmax_result)) == 1))
    if class_output == 0:
        rocket.flyup()
    elif class_output == 1:
        rocket.flydown()
    elif class_output == 2:
        rocket.stay()


def show(genomes, config):
    """Replay NEAT genomes in the rocket game window until every rocket is gone.

    Args:
        genomes: Iterable of ``(genome_id, genome)`` pairs. Each genome's
            ``fitness`` is reset to 0 and then updated while the replay runs.
        config: The NEAT config used to build a feed-forward network for each
            genome.

    Returns None once all rockets have crashed or left the screen. Closing the
    window shuts pygame down and raises ``SystemExit``.
    """
    nets = []
    gen = []
    rockets = []
    for _, g in genomes:
        nets.append(neat.nn.FeedForwardNetwork.create(g, config))
        rockets.append(main.Rocket(100, 200))
        g.fitness = 0
        gen.append(g)

    bg = main.BGMove(0)
    asteroids = [main.Asteroid(800)]
    win = main.pygame.display.set_mode((main.WIDTH, main.HEIGHT))
    score = 0
    clock = main.pygame.time.Clock()

    while True:
        clock.tick(60)
        for event in main.pygame.event.get():
            if event.type == main.pygame.QUIT:
                main.pygame.quit()
                # Same effect as the interactive-only site helper quit(),
                # but always available to programs.
                raise SystemExit

        if not rockets:
            # Every rocket has crashed or left the screen: the replay is over.
            break

        # Once the lead rocket has cleared the first asteroid, aim at the next one.
        asteroid_index = 0
        if len(asteroids) > 1 and rockets[0].x > asteroids[0].xcoord + asteroids[0].obj.get_width():
            asteroid_index = 1

        for x, rocket in enumerate(rockets):
            gen[x].fitness += 0.1
            rocket.move()
            _steer_rocket(rocket, nets[x], asteroids[asteroid_index])

        add_asteroid = False
        remove = []
        for asteroid in asteroids:
            # Walk the indices backwards so popping one rocket never makes the
            # loop skip the rocket that slides into its slot.
            for x in reversed(range(len(rockets))):
                rocket = rockets[x]
                if asteroid.collide(rocket):
                    gen[x].fitness -= 5
                    _drop_rocket(x, rockets, nets, gen)
                if not asteroid.passed and asteroid.xcoord < rocket.x:
                    asteroid.passed = True
                    add_asteroid = True
            if asteroid.xcoord + asteroid.obj.get_width() < 0:
                remove.append(asteroid)
            asteroid.move()

        if add_asteroid:
            score += 1
            print(score)
            for g in gen:
                g.fitness += 0.5
            asteroids.append(main.Asteroid(800))

        for r in remove:
            asteroids.remove(r)

        # Drop rockets that touched the bottom edge or went above the top edge
        # (reverse order again, for the same reason as above).
        for x in reversed(range(len(rockets))):
            rocket = rockets[x]
            if rocket.y + rocket.img.get_height() >= main.HEIGHT or rocket.y < 0:
                _drop_rocket(x, rockets, nets, gen)

        bg.move()
        main.draw_window(win, rockets, asteroids, bg, score)
def runSaveLoad(config_path, genome_path=None):
    """Load a pickled genome and replay it with ``show``.

    Args:
        config_path: Path to the NEAT configuration file.
        genome_path: Path of the pickled genome to replay. Defaults to
            ``exported_pkls/win.pkl`` resolved against the current working
            directory, which is where this script has always looked.
    """
    if genome_path is None:
        genome_path = main.os.path.join("exported_pkls", "win.pkl")

    # Rebuild the NEAT config with the stock genome/reproduction/species/
    # stagnation classes (presumably the same ones main.py trained with).
    config = neat.config.Config(neat.DefaultGenome, neat.DefaultReproduction,
                                neat.DefaultSpeciesSet, neat.DefaultStagnation,
                                config_path)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load genome files you produced yourself.
    with open(genome_path, "rb") as f:
        genome = pickle.load(f)

    # show() takes NEAT-style (genome_id, genome) pairs; wrap the single winner.
    show([(1, genome)], config)
if __name__ == "__main__":
    # Read config.txt from the directory that holds this script, not from
    # whatever the current working directory happens to be.
    script_dir = main.os.path.dirname(__file__)
    runSaveLoad(main.os.path.join(script_dir, "config.txt"))