Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
kapi0okapi committed Oct 24, 2024
2 parents c5b20dc + d00469e commit e6bfef7
Show file tree
Hide file tree
Showing 9 changed files with 232 additions and 129 deletions.
Binary file added docs/images/logo light blue name white.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
204 changes: 149 additions & 55 deletions gui_with_pygame.py

Large diffs are not rendered by default.

68 changes: 28 additions & 40 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from genericpath import exists
from src.environments.debug_env import env_debug_init, run_game_debug
from src.utils.config import Config
from src.genetics.NEAT import NEAT
from src.utils.utils import save_fitness, save_best_genome, load_best_genome, save_neat, load_neat, save_fitness_data, get_fitnesses_from_file
from src.utils.utils import read_fitness_file, save_fitness, save_best_genome, load_best_genome, save_neat, load_neat, save_fitness_graph_file
import warnings
import cProfile
import pstats
Expand All @@ -18,12 +17,13 @@ def play_genome(args):
else:
from_gen = 0
test_genome(from_gen, to_gen)

neat_name = "latest"
if args.neat_name != '':
neat_name = args.neat_name

if args.generation is not None:
generation_num = args.generation
else:
generation_num = -1
genome = load_best_genome(generation_num)
generation_num = args.generation if args.generation is not None else -1
genome = load_best_genome(generation_num, neat_name)
env, state = env_debug_init()
run_game_debug(env, state, genome, 0, visualize=False)

Expand All @@ -35,11 +35,11 @@ def test_genome(from_gen: int, to_gen: int):
fitness = run_game_debug(env, state, genome, i)
print(fitness)

def collect_fitnesses(genomes, generation, min_fitnesses, avg_fitnesses, best_fitnesses):
def collect_fitnesses(genomes, generation, min_fitnesses, avg_fitnesses, best_fitnesses, neat_name):
fitnesses = [genome.fitness_value for genome in genomes]

best_genome = max(genomes, key=lambda genome: genome.fitness_value)
save_best_genome(best_genome, generation)
save_best_genome(best_genome, generation, neat_name)

min_fitness, avg_fitness, max_fitness = min(fitnesses), sum(fitnesses) / len(fitnesses), max(fitnesses)
best_fitnesses.append(max_fitness)
Expand All @@ -51,69 +51,56 @@ def collect_fitnesses(genomes, generation, min_fitnesses, avg_fitnesses, best_fi

def main(args):
neat_name = args.neat_name
print("\nTraining NEAT with name: ", neat_name)
print()
profiler = cProfile.Profile()
profiler.enable()
min_fitnesses, avg_fitnesses, best_fitnesses = [], [], []

if neat_name == '':
neat_name = "latest"

neat = load_neat(neat_name)
config_instance = Config()
if neat is None:
exists = False
generation_nums, best_fitnesses, avg_fitnesses, min_fitnesses = get_fitnesses_from_file("fitness_values")
print(f"Generation numbs: {generation_nums}")
if neat is not None: # TODO: Add option to insert new config into NEAT object.
generation_nums, best_fitnesses, avg_fitnesses, min_fitnesses = read_fitness_file(neat_name)
from_generation = generation_nums[-1] + 1
print(f"From generation: {from_generation}")
#config_instance = neat.config
else:
exists = True
neat = NEAT(config_instance)
neat = NEAT(Config())
neat.initiate_genomes()
from_generation = 0

generations = (config_instance.generations) if args.n_generations == -1 else args.n_generations
generations = (neat.config.generations) if args.n_generations == -1 else args.n_generations

print(f"Training from generation {from_generation} to generation {from_generation + generations}")
print(f"Training from generation {from_generation} to generation {from_generation + generations}\n")

try:
for generation in range(from_generation, from_generation + generations):
neat.train_genomes()
collect_fitnesses(neat.genomes, generation, min_fitnesses, avg_fitnesses, best_fitnesses)
collect_fitnesses(neat.genomes, generation, min_fitnesses, avg_fitnesses, best_fitnesses, neat_name)

neat.sort_species(neat.genomes)
neat.check_population_improvements()
neat.check_individual_impovements() # Check if the species are improving, remove the ones that are not after 15 generations
neat.adjust_fitness()

save_fitness(best_fitnesses, avg_fitnesses, min_fitnesses)
save_fitness(best_fitnesses, avg_fitnesses, min_fitnesses, neat_name)
neat.calculate_number_of_children_of_species()
new_genomes_list = []
for specie in neat.species:
new_genomes_list.append(neat.generate_offspring(specie))

flattened_genomes = [genome for sublist in new_genomes_list for genome in sublist]
neat.genomes = flattened_genomes
neat.genomes = [genome for specie in neat.species for genome in neat.generate_offspring(specie)] # Generate offspring in each specie

print(f"new generation size: {len(neat.genomes)}" )

for genome in neat.genomes:
if not genome.elite:
neat.add_mutation(genome)
save_fitness_data()
save_fitness_graph_file(neat_name)
except KeyboardInterrupt:
print("\nProcess interrupted! Saving fitness data...")
finally:
# Always save fitness data before exiting, whether interrupted or completed
save_fitness(best_fitnesses, avg_fitnesses, min_fitnesses, exists)
save_neat(neat, "latest")
save_fitness(best_fitnesses, avg_fitnesses, min_fitnesses, neat_name)
save_neat(neat, neat_name)
print("Fitness data saved.")

profiler.disable()

# Create a stats object to print out profiling results
stats = pstats.Stats(profiler).sort_stats('cumtime')
stats = pstats.Stats(profiler).sort_stats('cumtime') # Create a stats object to print out profiling results
stats.print_stats()

return neat.genomes
Expand All @@ -123,9 +110,11 @@ def command_line_interface():

subparsers = parser.add_subparsers(dest="command", help="Choose 'train', 'test', 'graph', or 'play'")

# Global arguments for all functions
parser.add_argument('-n', '--neat_name', type=str, default='latest', help="The name of the NEAT object to load from 'trained_population/'")

# Train command (runs main())
train_parser = subparsers.add_parser('train', help="Run the training process")
train_parser.add_argument('-n', '--neat_name', type=str, default='', help="The name of the NEAT object to load from 'trained_population/'")
train_parser.add_argument('-g', '--n_generations', type=int, default=-1, help="The number of generations to train for")

graph_parser = subparsers.add_parser('graph', help="Graph the fitness data")
Expand All @@ -140,12 +129,11 @@ def command_line_interface():
if args.command == "train":
main(args)
elif args.command == "graph":
save_fitness_data(show=True)
save_fitness_graph_file(args.neat_name, show=True)
elif args.command == "play":
play_genome(args)
else:
parser.print_help()

if __name__ == "__main__":
# main()
command_line_interface()
command_line_interface()
2 changes: 1 addition & 1 deletion src/environments/mario_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from gym.spaces import Box
from gym.core import ObservationWrapper
import numpy as np
# import cv2
import cv2
from skimage.transform import resize

class StepResult(NamedTuple):
Expand Down
19 changes: 16 additions & 3 deletions src/genetics/NEAT.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def check_population_improvements(self):
return
self.improvement_counter += 1
print(f"Improvement counter: {self.improvement_counter}")
if self.improvement_counter > 20:
if self.improvement_counter >= 20:
print("No improvements in 20 generations - RIP")
self.species.sort(key=lambda x: self.rank_species(x), reverse=True)
if len(self.species) > 2:
Expand All @@ -194,23 +194,36 @@ def check_population_improvements(self):

def check_individual_impovements(self):
""" Check each species if they are improving, remove the ones that are not after 15 generations"""
best_fitness = 0
for specie in self.species:
specie.genomes.sort(key=lambda x: x.fitness_value, reverse=True)
if specie.best_genome_fitness < specie.genomes[0].fitness_value:
specie.best_genome_fitness = specie.genomes[0].fitness_value
specie.improvement_counter = 0
else:
specie.improvement_counter += 1
if specie.best_genome_fitness > best_fitness:
best_fitness = specie.best_genome_fitness
for specie in self.species:
print(f"Specie number: {specie.species_number}, Best fitness: {specie.best_genome_fitness}, Improvement counter: {specie.improvement_counter}")
if specie.improvement_counter > 20:
print(f"Specie number: {specie.species_number}, Best fitness: {specie.best_genome_fitness}, Stagnation counter: {specie.improvement_counter}")
if specie.improvement_counter >= 15:
if specie.best_genome_fitness == best_fitness:
specie.improvement_counter = 0
print("Best specie is not removed")
break
print("Specie has stagnated, removing it!")
self.species.remove(specie)
# TODO: Ludvig, kan du forklare hvorfor det er sånn? En liten specie som er shit burde vel ikke overleve videre?
# Bare slett kommentarene hvis det ikke skal være der.
"""
if len(self.species) > 2:
self.species.remove(specie)
else:
# if there are only two species left, reset their imrovement counters
self.improvement_counter = 0
for specie in self.species:
specie.improvement_counter = 0
"""
if self.species == []:
print("All species have been removed - RIP")
self.initiate_genomes()
Expand Down
6 changes: 3 additions & 3 deletions src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ class Config:
c2: float = 1.5
c3: float = 0.4
genomic_distance_threshold: float = 2.69
population_size: int = 20
generations: int = 1
population_size: int = 120
generations: int = 50

connection_weight_mutation_chance: float = 0.8
# if mutate gene:
Expand All @@ -34,7 +34,7 @@ class Config:
# Activation function
# What we use: ReLU
# Paper: 1/(1+exp(-0.49*x))
activation_func: str = "sigmoid"
activation_func: str = "tanh"

elitism_rate: float = 0.2 # percentage of the best genomes are copied to the next generation
remove_worst_percentage: float = 0.3 # percentage of the worst genomes are removed from the population when breeding
Expand Down
58 changes: 31 additions & 27 deletions src/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from src.genetics.genome import Genome
from src.utils.config import Config
import matplotlib
matplotlib.use("Agg")

import matplotlib.pyplot as plt
import numpy as np
import os
Expand Down Expand Up @@ -46,20 +49,22 @@ def insert_input(genome:Genome, state: np.ndarray) -> None:
for i, node in enumerate(genome.nodes[start_idx_input_node:start_idx_input_node+num_input_nodes]): # get all input nodes
node.value = state[i//num_columns][i % num_columns]

def save_fitness(best: list, avg: list, min: list, append: bool = False):
os.makedirs('data/fitness', exist_ok=True)
with open("data/fitness/fitness_values.txt", "w") as f:
def save_fitness(best: list, avg: list, min: list, name: str):
os.makedirs(f'data/{name}/fitness', exist_ok=True)
with open(f"data/{name}/fitness/fitness_values.txt", "w") as f:
for i in range(len(best)):
f.write(f"Generation: {i} - Best: {best[i]} - Avg: {avg[i]} - Min: {min[i]}\n")

def save_best_genome(genome: Genome, generation: int):
os.makedirs('data/good_genomes', exist_ok=True)
with open(f'data/good_genomes/best_genome_{generation}.obj', 'wb') as f:
def save_best_genome(genome: Genome, generation: int, name: str):
path = f'data/{name}/good_genomes'
os.makedirs(path, exist_ok=True)
with open(f'{path}/best_genome_{generation}.obj', 'wb') as f:
pickle.dump(genome, f) # type: ignore

def load_best_genome(generation: int):
if generation == -1:
files = os.listdir('data/good_genomes')
def load_best_genome(generation: int, name: str) -> Genome:
"""Loads the best genome from the given generation. If -1 is passed as argument, the latest generation is displayed."""
if generation == -1: # Find the genome from the latest generation.
files = os.listdir(f'data/{name}/good_genomes')
pattern = re.compile(r'best_genome_(\d+).obj')
generations = []
for file in files:
Expand All @@ -72,31 +77,35 @@ def load_best_genome(generation: int):
else:
raise FileNotFoundError("No valid genome files found in 'data/good_genomes'.")

with open(f'data/good_genomes/best_genome_{generation}.obj', 'rb') as f:
with open(f'data/{name}/good_genomes/best_genome_{generation}.obj', 'rb') as f:
return pickle.load(f)

def save_neat(neat: 'NEAT', name: str):
os.makedirs('data/trained_population', exist_ok=True)
with open(f'data/trained_population/neat_{name}.obj', 'wb') as f:
os.makedirs(f'data/{name}/trained_population', exist_ok=True)
with open(f'data/{name}/trained_population/neat_{name}.obj', 'wb') as f:
pickle.dump(neat, f) # type: ignore

def load_neat(name: str):
# Check if file exists first
if not os.path.exists(f'data/trained_population/neat_{name}.obj'):
if not os.path.exists(f'data/{name}/trained_population/neat_{name}.obj'):
return None
with open(f'data/trained_population/neat_{name}.obj', 'rb') as f:
with open(f'data/{name}/trained_population/neat_{name}.obj', 'rb') as f:
return pickle.load(f) # type: ignore

# Function to read and parse the file
def read_fitness_file(filename):
def read_fitness_file(name: str):
"""
name - Name of the neat instance.
"""
generations = []
best_values = []
avg_values = []
min_values = []
os.makedirs('data/fitness', exist_ok=True)
filename = f'data/{name}/fitness' # Make sure the file is named 'fitness.txt' and is in the same directory
os.makedirs(filename, exist_ok=True)

# Open the file and extract data
with open(filename, 'r') as file:
with open(f"{filename}/fitness_values.txt", 'r') as file:
for line in file:
match = re.match(r"Generation: (\d+) - Best: ([\d\.]+) - Avg: ([\d\.]+) - Min: ([\d\.]+)", line)
if match:
Expand All @@ -108,7 +117,7 @@ def read_fitness_file(filename):
return generations, best_values, avg_values, min_values

# Function to plot the data
def plot_fitness_data(generations, best_values, avg_values, min_values, show=False):
def plot_fitness_data(generations: list, best_values: list, avg_values: list, min_values: list, name: str, show=False):
plt.clf()

plt.plot(generations, best_values, label='Best')
Expand All @@ -121,15 +130,10 @@ def plot_fitness_data(generations, best_values, avg_values, min_values, show=Fal

plt.legend()
plt.grid(True)
plt.savefig('data/fitness/fitness_plot.png')
plt.savefig(f'data/{name}/fitness/fitness_plot.png')
if show:
plt.show()

def save_fitness_data(show=False):
filename = 'data/fitness/fitness_values.txt' # Make sure the file is named 'fitness.txt' and is in the same directory
generations, best_values, avg_values, min_values = read_fitness_file(filename)
plot_fitness_data(generations, best_values, avg_values, min_values, show=show)

def get_fitnesses_from_file(f_name: str):
filename = f'data/fitness/{f_name}.txt' # Make sure the file is named 'fitness.txt' and is in the same directory
return read_fitness_file(filename)
def save_fitness_graph_file(name, show=False):
generations, best_values, avg_values, min_values = read_fitness_file(name)
plot_fitness_data(generations, best_values, avg_values, min_values, name, show=show)
2 changes: 2 additions & 0 deletions src/visualization/colors_visualization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from src.utils.utils import normalize_negative_values, normalize_positive_values
import numpy as np
from typing import List
import matplotlib
matplotlib.use("Agg")
import matplotlib.cm as cm
from src.genetics.node import Node

Expand Down
2 changes: 2 additions & 0 deletions src/visualization/visualize_genome.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import networkx as nx
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from src.genetics.genome import Genome
Expand Down

0 comments on commit e6bfef7

Please sign in to comment.