-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.py
192 lines (167 loc) · 7.19 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
from PIL import Image
import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
import matplotlib.colors as mplc
import time
import sys
import os
import sklearn.manifold
from classes.ga.ga import GA
from classes.pso.pso import PSO
from classes.operators import selection, replacement, crossover, velocity_update, topology
# Create the preview window up front; the main loop shows each generation in a window with this title
RESULT_WINDOW = 'Result'
cv.namedWindow(RESULT_WINDOW)
# Params
SAMPLE = 'mona_lisa.jpg'  # File name, looked up inside the samples/ directory
ALGORITHM = GA  # GA or PSO
INTERPOLATION_SIZE = 5  # Number of interpolated frames to save for PSO results. Set to 1 to disable interpolation
VIDEO_INIT_GEN, VIDEO_FRAME_GEN = 3000, 500  # Number of generations to run for the first and for the other frames, respectively

# Optional command-line overrides: main.py [SAMPLE [ALGORITHM]]
if len(sys.argv) > 1:
    SAMPLE = sys.argv[1]
if len(sys.argv) > 2:
    # The algorithm argument is optional: any value other than 'PSO' selects GA
    ALGORITHM = PSO if sys.argv[2] == 'PSO' else GA

# Split on the LAST dot only, so names with several dots (or none) do not break the unpacking.
# The extension is stored without the dot and lower-cased, which is the form the loader compares against.
sample_name, sample_ext = os.path.splitext(SAMPLE)
sample_ext = sample_ext[1:].lower()
# Load image or video
if sample_ext in ['jpg', 'jpeg', 'png']:
    isvideo = False
    # Force 3-channel RGB first so grayscale, palette and alpha images all convert to BGR;
    # the context manager closes the underlying file once the pixels are copied into the array
    with Image.open(f'samples/{SAMPLE}') as pil_img:
        img = cv.cvtColor(np.array(pil_img.convert('RGB')), cv.COLOR_RGB2BGR)
    fps = 30
elif sample_ext in ['mp4', 'gif']:
    isvideo = True
    video = cv.VideoCapture(f'samples/{SAMPLE}')
    ret, img = video.read()
    if not ret:
        # Fail here with a clear message instead of crashing later on img.shape (img is None)
        video.release()
        raise OSError(f'Could not read a frame from "samples/{SAMPLE}"')
    frame_count = 0  # Number of times the target has been advanced to a new video frame
    # CAP_PROP_FPS is reported as 0 when the backend does not know it; fall back to the still-image rate
    fps = video.get(cv.CAP_PROP_FPS) or 30
else:
    raise ValueError(f'File extension "{sample_ext}" not supported')
# Prepare the output video: one mp4 per algorithm/sample pair, sized like the target image
os.makedirs('results/videos', exist_ok=True)
frame_height, frame_width = img.shape[:2]
fourcc = cv.VideoWriter_fourcc(*'mp4v')
out = cv.VideoWriter(
    f'results/videos/{ALGORITHM.__name__}_{sample_name}.mp4',
    fourcc,
    fps,
    (frame_width, frame_height),  # VideoWriter takes (width, height), the reverse of img.shape[:2]
)
# Genetic algorithm
# NOTE: this optimizer is built unconditionally, even when ALGORITHM is PSO.

# Alternatives: selection.RouletteWheelSelection(), selection.RankBasedSelection(),
#               selection.TruncatedSelection(.1), selection.TournamentSelection(10)
ga_selection = selection.TruncatedSelection(.1)
# Alternatives: replacement.CommaReplacement(), replacement.PlusReplacement(), replacement.CrowdingReplacement(4)
ga_replacement = replacement.CommaReplacement()
# Alternatives: crossover.OnePointCrossover(), crossover.UniformCrossover(), crossover.ArithmeticCrossover()
ga_crossover = crossover.UniformCrossover()

ga = GA(
    img,
    pop_size=100,
    n_poly=100,
    n_vertex=3,
    # If False, initialize the color of each polygon based on the average pixel values of the target image.
    # If True, the initial colors are assigned randomly.
    random_init_color=False,
    selection_strategy=ga_selection,
    replacement_strategy=ga_replacement,
    crossover_type=ga_crossover,
    self_adaptive=False,  # Self-adaptation of mutation step-sizes
    mutation_rates=(0.02, 0.02, 0.02),  # Not used if self_adaptive is True
    mutation_step_sizes=(0.1, 0.1, 0.1),  # Not used if self_adaptive is True
)
# Particle swarm optimization
# NOTE: this optimizer is built unconditionally, even when ALGORITHM is GA.

# Alternatives: velocity_update.Standard(), velocity_update.FullyInformed(), velocity_update.ComprehensiveLearning()
pso_velocity_rule = velocity_update.Standard()
# Alternatives: topology.DistanceTopology(), topology.RingTopology(), topology.StarTopology()
pso_topology = topology.StarTopology()

pso = PSO(
    img,
    swarm_size=500,
    line_length=20,
    velocity_update_rule=pso_velocity_rule,
    neighborhood_topology=pso_topology,
    neighborhood_size=3,
    # Inertia (0.7 - 0.8), cognitive coeff/social coeff (1.5 - 1.7)
    coeffs=(0.2, 1.5, 1.5),
    min_distance=2,
    max_velocity=50,
)
# Per-generation fitness history; favg and fworst are only filled when running GA
fbest = []
favg = []
fworst = []
# GA diversity tracking: one summed value per measurement, plus the latest raw result of ga.diversity()
diversities = []
dist = None
# Swarm snapshot taken at the previously saved PSO frame, used to interpolate between saved frames
prev_npswarm = None
# Resolve the algorithm choice once; ALGORITHM is not reassigned inside the loop
using_ga = ALGORITHM is GA
using_pso = ALGORITHM is PSO

try:  # Press ctrl+c to exit loop
    print(f'\nRunning {ALGORITHM.__name__} algorithm over "{SAMPLE}".\nPress ctrl+c to terminate the execution.\n')
    while True:
        start_time = time.time()

        # Compute next generation and record its fitness statistics
        additional_info = ''
        if using_ga:
            gen, population = ga.next()
            # The first individual is reported as the best one, the last as the worst
            best = population[0]
            additional_info = f' ({best.fitness_perc * 100:.2f}%), polygons: {best.n_poly}'
            fitness = best.fitness
            fbest.append(fitness)
            favg.append(np.mean([individual.fitness for individual in population]))
            fworst.append(population[-1].fitness)
            if gen % 20 == 0:  # Measure diversity every 20 generations
                dist = ga.diversity()
                diversity = dist.sum()
                diversities.append(diversity)
                additional_info += f', diversity: {int(diversity)}'
        elif using_pso:
            gen, fitness = pso.next()
            fbest.append(fitness)

        # Report how long the generation took (the statistics above are included in the timing)
        elapsed_ms = round((time.time() - start_time) * 1000)
        print(f'{gen:04d}) {elapsed_ms:04d}ms, fitness: {fitness:.4f}{additional_info}')

        # Obtain current best solution
        if using_ga:
            best_img = best.draw()
        elif using_pso:
            best_img = pso.draw()

        # Show target and current best side by side, scaled down to 60%
        target_size = img.shape[1::-1]  # (width, height)
        preview = np.hstack([img, cv.resize(best_img, target_size)])
        cv.imshow('Result', cv.resize(preview, None, fx=.6, fy=.6))

        # Decide whether this generation goes into the output video:
        #  - video input: once every VIDEO_FRAME_GEN generations (both algorithms)
        #  - image input: every 10 generations for GA, every generation for PSO
        if isvideo:
            should_record = gen % VIDEO_FRAME_GEN == 0
        elif using_ga:
            should_record = gen % 10 == 0
        else:
            should_record = True

        if should_record:
            if using_ga:
                frames = [best_img.copy()]
            else:
                if prev_npswarm is not None and not isvideo:
                    # Interpolate frames for better visualization
                    frames = pso.draw_interpolated(prev_npswarm, INTERPOLATION_SIZE)
                else:
                    frames = [best_img.copy()]
                prev_npswarm = pso.npswarm
            for frame in frames:
                # frame = cv.putText(frame, f'{gen}', (2, 16), cv.FONT_HERSHEY_PLAIN, 1.4, (0, 0, 255), 2) # Print generation number
                out.write(frame)

        # Key press: space pauses until another key is pressed
        pressed = cv.waitKey(1) & 0xFF
        if pressed == ord(' '):
            cv.waitKey(0)

        # Update the target, in case of video input. The first frame is kept until gen exceeds
        # VIDEO_INIT_GEN; afterwards a new frame is read whenever gen is a multiple of VIDEO_FRAME_GEN.
        if isvideo:
            if frame_count == 0:
                advance = gen > VIDEO_INIT_GEN
            else:
                advance = gen % VIDEO_FRAME_GEN == 0
            if advance:
                ret, img = video.read()
                if not ret:
                    break  # No more frames: stop optimizing
                frame_count += 1
                if using_ga:
                    ga.update_target(img)
                elif using_pso:
                    pso.update_target(img)
except KeyboardInterrupt:
    # ctrl+c is the intended way to stop; fall through to saving and plotting
    pass
# Save final individual image
os.makedirs('results/images', exist_ok=True)
try:
    final_img = best_img
except NameError:
    # The run was interrupted before the first generation was rendered, so there is nothing to save
    final_img = None
if final_img is None:
    print('No generation was rendered: skipping the final image.')
else:
    final_path = f'results/images/{ALGORITHM.__name__}_{sample_name}.jpg'
    # imwrite reports failure through its return value rather than by raising
    if not cv.imwrite(final_path, final_img):
        print(f'Could not write the final image to "{final_path}".')
# Clear all
cv.destroyAllWindows()
out.release()  # Close the output video writer
if isvideo:
    # The input capture only exists for video samples; release it as well
    video.release()
# Plots
# Fitness plots: the best curve is always drawn, average and worst only when they were recorded
fitness_fig, fitness_ax = plt.subplots()
fitness_fig.suptitle('Fitness trends')
generations = range(len(fbest))
fitness_ax.plot(generations, fbest, c='r', label='best')
for series, colour, name in ((favg, 'b', 'average'), (fworst, 'g', 'worst')):
    if series:
        fitness_ax.plot(generations, series, c=colour, label=name)
fitness_ax.legend()
# Diversity plots: trend of the summed diversity over its measurements
if diversities:
    fig, ax = plt.subplots()
    fig.suptitle('Diversity')
    ax.plot(range(len(diversities)), diversities, c='b', label='diversity')
    ax.legend()

# 2-D projection of the last ga.diversity() result
# NOTE(review): metric='precomputed' assumes dist is a square pairwise-distance matrix - confirm against GA.diversity()
if dist is not None:
    fig, ax = plt.subplots()
    fig.suptitle('Diversity scatter plot')
    # init is set explicitly: with metric='precomputed' recent scikit-learn rejects its default init='pca'
    tsne = sklearn.manifold.TSNE(metric='precomputed', init='random', perplexity=7, random_state=0)
    dist_proj = tsne.fit_transform(dist)
    ax.scatter(dist_proj[:, 0], dist_proj[:, 1], s=4, c='r')

# Blocks until all figure windows are closed
plt.show()