Skip to content

Commit

Permalink
Merge pull request #31 from CogitoNTNU/feat/improved-main
Browse files Browse the repository at this point in the history
Feat/improved main
  • Loading branch information
BrageHK authored Oct 22, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
2 parents de31df6 + 79eca55 commit 3093d26
Showing 4 changed files with 46 additions and 26 deletions.
Binary file removed data/fitness/fitness_plot.png
Binary file not shown.
Binary file added data_copy/fitness/fitness_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
44 changes: 26 additions & 18 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
from numpy import save
from src.genetics.genome import Genome
from src.environments.debug_env import env_debug_init, run_game_debug
from src.utils.config import Config
from src.genetics.NEAT import NEAT
from src.utils.utils import save_fitness, save_best_genome, load_best_genome, save_neat, load_neat, save_fitness_data
from src.utils.utils import save_fitness, save_best_genome, load_best_genome, save_neat, load_neat, save_fitness_data, get_fitnesses_from_file
import warnings
import cProfile
import pstats
import argparse
from typing import Dict

warnings.filterwarnings("ignore", category=UserWarning, message=".*Gym version v0.24.1.*")

def play_genome():
genome = load_best_genome(0)
genome = load_best_genome(-1)
env, state = env_debug_init()
run_game_debug(env, state, genome, 0, visualize=False)

@@ -39,23 +36,31 @@ def collect_fitnesses(genomes, generation, min_fitnesses, avg_fitnesses, best_fi
print(f"Generation: {generation} - Best: {max_fitness} - Avg: {avg_fitness} - Min: {min_fitness}")


def main(neat_name: str = '', to_generations: int = 0):
def main(args):
neat_name = args.neat_name
profiler = cProfile.Profile()
profiler.enable()
min_fitnesses, avg_fitnesses, best_fitnesses = [], [], []
to_generations = args.extra_number

if neat_name == '':
config_instance = Config()
neat_name = "latest"

neat, exists = load_neat(neat_name)
config_instance = Config()
if exists:
generation_nums, best_fitnesses, avg_fitnesses, min_fitnesses = get_fitnesses_from_file("fitness_values")
print(f"Generation numbs: {generation_nums}")
from_generation = generation_nums[-1] + 1
print(f"From generation: {from_generation}")
#config_instance = neat.config
else:
neat = NEAT(config_instance)
neat.initiate_genomes()
from_generation = 0
generations = config_instance.generations if to_generations == 0 else to_generations
else:
neat = load_neat(neat_name)
config_instance = neat.config
# from_generation = neat.generation
from_generation = 0
generations = (config_instance.generations - from_generation) if to_generations == 0 else to_generations

generations = (config_instance.generations) if to_generations == 0 else to_generations

min_fitnesses, avg_fitnesses, best_fitnesses = [], [], []
print(f"Training from generation {from_generation} to generation {from_generation + generations}")

try:
@@ -82,11 +87,12 @@ def main(neat_name: str = '', to_generations: int = 0):
for genome in neat.genomes:
if not genome.elite:
neat.add_mutation(genome)
save_fitness_data()
except KeyboardInterrupt:
print("\nProcess interrupted! Saving fitness data...")
finally:
# Always save fitness data before exiting, whether interrupted or completed
save_fitness(best_fitnesses, avg_fitnesses, min_fitnesses)
save_fitness(best_fitnesses, avg_fitnesses, min_fitnesses, exists)
save_neat(neat, "latest")
print("Fitness data saved.")

@@ -115,12 +121,14 @@ def command_line_interface():

graph_parser = subparsers.add_parser('graph', help="Graph the fitness data")

play_parser = subparsers.add_parser('play', help="Play the best genome")
play_parser = subparsers.add_parser('play', help="Play the best genome from the lastest generation")
play_parser.add_argument('-g', '--generation', type=int, help="The generation of the genome to play")
play_parser.add_argument('-b', '--best', action='store_true', help="Play the best genome")

args = parser.parse_args()

if args.command == "train":
main(neat_name=args.neat_name, to_generations=args.extra_number)
main(args) #neat_name=args.neat_name, to_generations=args.extra_number)
elif args.command == "test":
test_genome(args.from_gen, args.to_gen)
elif args.command == "graph":
28 changes: 20 additions & 8 deletions src/utils/utils.py
Original file line number Diff line number Diff line change
@@ -13,10 +13,10 @@

def save_state_as_png(i, state: np.ndarray) -> None:
"""Save a frame."""
directory = "./mario_frames"
directory = "./data/mario_frames"
if not os.path.exists(directory):
os.makedirs(directory)
plt.imsave(f"./mario_frames/frame{i}.png", state, cmap='gray', vmin=0, vmax=1)
plt.imsave(f"./data/mario_frames/frame{i}.png", state, cmap='gray', vmin=0, vmax=1)

def normalize_positive_values(positive_vals: np.ndarray) -> None:
"""Takes an ndarray with positive floats as inputs,
@@ -47,6 +47,7 @@ def insert_input(genome:Genome, state: np.ndarray) -> None:
node.value = state[i//num_columns][i % num_columns]

def save_fitness(best: list, avg: list, min: list, append: bool = False):
os.makedirs('data/fitness', exist_ok=True)
with open("data/fitness/fitness_values.txt", "w") as f:
for i in range(len(best)):
f.write(f"Generation: {i} - Best: {best[i]} - Avg: {avg[i]} - Min: {min[i]}\n")
@@ -57,6 +58,11 @@ def save_best_genome(genome: Genome, generation: int):
pickle.dump(genome, f) # type: ignore

def load_best_genome(generation: int):
if generation == -1:
files = os.listdir('data/good_genomes')
pattern = re.compile(r'best_genome_(\d+).obj')
generation = max([int(pattern.match(file).group(1)) for file in files])
print("loading best genome from generation: ", generation)
with open(f'data/good_genomes/best_genome_{generation}.obj', 'rb') as f:
return pickle.load(f)

@@ -66,15 +72,19 @@ def save_neat(neat: 'NEAT', name: str):
pickle.dump(neat, f) # type: ignore

def load_neat(name: str):
# Check if file exists first
if not os.path.exists(f'data/trained_population/neat_{name}.obj'):
return None, False
with open(f'data/trained_population/neat_{name}.obj', 'rb') as f:
return pickle.load(f) # type: ignore
return pickle.load(f), True # type: ignore

# Function to read and parse the file
def read_fitness_file(filename):
generations = []
best_values = []
avg_values = []
min_values = []
os.makedirs('data/fitness', exist_ok=True)

# Open the file and extract data
with open(filename, 'r') as file:
@@ -90,6 +100,8 @@ def read_fitness_file(filename):

# Function to plot the data
def plot_fitness_data(generations, best_values, avg_values, min_values):
plt.clf()

plt.plot(generations, best_values, label='Best')
plt.plot(generations, avg_values, label='Avg')
plt.plot(generations, min_values, label='Min')
@@ -98,15 +110,15 @@ def plot_fitness_data(generations, best_values, avg_values, min_values):
plt.ylabel('Values')
plt.title('Generation vs Best, Avg, and Min')

# Ensure x-axis ticks are integers
plt.xticks(generations) # Set x-ticks to be the generation numbers (integers)

plt.legend()
plt.grid(True)
plt.savefig('data/fitness/fitness_plot.png')
plt.show()

def save_fitness_data():
filename = 'data/fitness/fitness_values.txt' # Make sure the file is named 'fitness.txt' and is in the same directory
generations, best_values, avg_values, min_values = read_fitness_file(filename)
plot_fitness_data(generations, best_values, avg_values, min_values)
plot_fitness_data(generations, best_values, avg_values, min_values)

def get_fitnesses_from_file(f_name: str):
filename = f'data/fitness/{f_name}.txt' # Make sure the file is named 'fitness.txt' and is in the same directory
return read_fitness_file(filename)

0 comments on commit 3093d26

Please sign in to comment.