From 2f0798db71510a5908239b789fa259b652fc524d Mon Sep 17 00:00:00 2001 From: Vetlets05 Date: Thu, 24 Oct 2024 19:25:35 +0200 Subject: [PATCH] feat: merge conflict incoming (add graph and arguments) --- gui_with_pygame.py | 31 ++++++++++++++++++----- src/environments/mario_env.py | 2 +- src/utils/utils.py | 3 +++ src/visualization/colors_visualization.py | 2 ++ src/visualization/visualize_genome.py | 2 ++ 5 files changed, 33 insertions(+), 7 deletions(-) diff --git a/gui_with_pygame.py b/gui_with_pygame.py index 17148cf..300b0ac 100644 --- a/gui_with_pygame.py +++ b/gui_with_pygame.py @@ -3,6 +3,7 @@ import sys import threading import os +import argparse import src.utils.config as conf import pickle @@ -404,7 +405,7 @@ def __init__(self): self.fitness_graph = ImageSprite("data/fitness/fitness_plot.png", (700, 100)) except: print("ERROR: Could not find image path") - self.fitness_graph = 0 + self.fitness_graph = ImageSprite("genome_frames/genome_0.png", (700, 100)) @@ -454,7 +455,7 @@ def __init__(self): self.watch_genes_visualize = ImageSprite("genome_frames/genome_0.png", (700, 50), (600, 400)) ## Run button - self.run_button = Button(600, 20, 200, 50, "Run Selected Genomes", st.font, st.text_color, st.button_color, st.hover_color, st.pressed_color, self.start_watching_process) + self.run_button = Button(600, 20, 200, 50, "Run Selected Genomes", st.font, st.text_color, st.button_color, st.hover_color, st.pressed_color, self.run_selected_genomes) self.watch_back_button = Button(200, 450, 200, 50, "Back", st.font, st.text_color, st.button_color, st.hover_color, st.pressed_color, self.main_menu_scene) #Visualize best genome @@ -586,7 +587,20 @@ def start_training(self): generations = self.training_input_fields[2].text print(f"Starting training with Population: {population}, Mutation Rate: {mutation_rate}, Generations: {generations}") # Insert your training code here (e.g. NEAT training) - neat_test_file.main() + + + parser = argparse.ArgumentParser(description="Train or Test Genomes") + subparsers = parser.add_subparsers(dest="command", help="Choose 'train', 'test', 'graph', or 'play'") + + # Train command (runs main()) + train_parser = subparsers.add_parser('train', help="Run the training process") + train_parser.add_argument('-n', '--neat_name', type=str, default='', help="The name of the NEAT object to load from 'trained_population/'") + train_parser.add_argument('-g', '--n_generations', type=int, default=-1, help="The number of generations to train for") + + sim_args = ["train", "-g", generations] + args=parser.parse_args(sim_args) + + neat_test_file.main(args=args) def draw_settings_scene(self): @@ -617,11 +631,16 @@ def update_screen(self): element.draw(self.screen) for element in self.training_UI: element.draw(self.screen) - #self.fitness_graph.draw(self.screen) + + try: + self.fitness_graph = ImageSprite("data/fitness/fitness_plot.png", (700, 100)) + except: + pass + self.fitness_graph.draw(self.screen) gen_data,best_fitness_data,avg_fitness_data,min_fitness_data = util.read_fitness_file("data/fitness/fitness_values.txt") - self.new_fitness_graph = Graph(self.screen, (700, 100), (400, 400), best_fitness_data) + #self.new_fitness_graph = Graph(self.screen, (700, 100), (400, 400), best_fitness_data) - self.new_fitness_graph.draw() + #self.new_fitness_graph.draw() elif st.sc_selector == 2: diff --git a/src/environments/mario_env.py b/src/environments/mario_env.py index 0336938..969f18c 100644 --- a/src/environments/mario_env.py +++ b/src/environments/mario_env.py @@ -6,7 +6,7 @@ from gym.spaces import Box from gym.core import ObservationWrapper import numpy as np -# import cv2 +import cv2 from skimage.transform import resize class StepResult(NamedTuple): diff --git a/src/utils/utils.py b/src/utils/utils.py index f1cf860..971f88e 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -1,5 +1,8 @@ from src.genetics.genome import Genome from src.utils.config import Config +import matplotlib +matplotlib.use("Agg") + import matplotlib.pyplot as plt import numpy as np import os diff --git a/src/visualization/colors_visualization.py b/src/visualization/colors_visualization.py index 9956aaa..af05ec5 100644 --- a/src/visualization/colors_visualization.py +++ b/src/visualization/colors_visualization.py @@ -1,6 +1,8 @@ from src.utils.utils import normalize_negative_values, normalize_positive_values import numpy as np from typing import List +import matplotlib +matplotlib.use("Agg") import matplotlib.cm as cm from src.genetics.node import Node diff --git a/src/visualization/visualize_genome.py b/src/visualization/visualize_genome.py index 79d29a5..5a3734d 100644 --- a/src/visualization/visualize_genome.py +++ b/src/visualization/visualize_genome.py @@ -1,4 +1,6 @@ import networkx as nx +import matplotlib +matplotlib.use("Agg") import matplotlib.pyplot as plt from matplotlib.axes import Axes from src.genetics.genome import Genome