Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Hako2807 committed Oct 24, 2024
2 parents ecd9d11 + 512fe42 commit e978805
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
13 changes: 6 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from src.environments.debug_env import env_debug_init, run_game_debug
from src.utils.config import Config
from src.genetics.NEAT import NEAT
from src.utils.utils import read_fitness_file, save_fitness, save_best_genome, load_best_genome, save_neat, load_neat, save_fitness_data
from src.utils.utils import read_fitness_file, save_fitness, save_best_genome, load_best_genome, save_neat, load_neat, save_fitness_graph_file
import warnings
import cProfile
import pstats
Expand Down Expand Up @@ -51,16 +51,16 @@ def collect_fitnesses(genomes, generation, min_fitnesses, avg_fitnesses, best_fi

def main(args):
neat_name = args.neat_name
print("\nTraining NEAT with name: ", neat_name)
print()
profiler = cProfile.Profile()
profiler.enable()
min_fitnesses, avg_fitnesses, best_fitnesses = [], [], []

neat = load_neat(neat_name)
if neat is not None: # TODO: Add option to insert new config into NEAT object.
generation_nums, best_fitnesses, avg_fitnesses, min_fitnesses = read_fitness_file(neat_name)
print(f"Generation numbs: {generation_nums}")
from_generation = generation_nums[-1] + 1
print(f"From generation: {from_generation}")
#config_instance = neat.config
else:
neat = NEAT(Config())
Expand All @@ -69,7 +69,7 @@ def main(args):

generations = (neat.config.generations) if args.n_generations == -1 else args.n_generations

print(f"Training from generation {from_generation} to generation {from_generation + generations}")
print(f"Training from generation {from_generation} to generation {from_generation + generations}\n")

try:
for generation in range(from_generation, from_generation + generations):
Expand All @@ -90,7 +90,7 @@ def main(args):
for genome in neat.genomes:
if not genome.elite:
neat.add_mutation(genome)
save_fitness_data(neat_name)
save_fitness_graph_file(neat_name)
except KeyboardInterrupt:
print("\nProcess interrupted! Saving fitness data...")
finally:
Expand All @@ -115,7 +115,6 @@ def command_line_interface():

# Train command (runs main())
train_parser = subparsers.add_parser('train', help="Run the training process")
train_parser.add_argument('-n', '--neat_name', type=str, default='', help="The name of the NEAT object to load from 'trained_population/'")
train_parser.add_argument('-g', '--n_generations', type=int, default=-1, help="The number of generations to train for")

graph_parser = subparsers.add_parser('graph', help="Graph the fitness data")
Expand All @@ -130,7 +129,7 @@ def command_line_interface():
if args.command == "train":
main(args)
elif args.command == "graph":
save_fitness_data(args.neat_name, show=True)
save_fitness_graph_file(args.neat_name, show=True)
elif args.command == "play":
play_genome(args)
else:
Expand Down
2 changes: 1 addition & 1 deletion src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,6 @@ def plot_fitness_data(generations: list, best_values: list, avg_values: list, mi
if show:
plt.show()

def save_fitness_data(name, show=False):
def save_fitness_graph_file(name, show=False):
generations, best_values, avg_values, min_values = read_fitness_file(name)
plot_fitness_data(generations, best_values, avg_values, min_values, name, show=show)

0 comments on commit e978805

Please sign in to comment.