Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Vetlets05 committed Oct 30, 2024
2 parents b8aafaf + 98251d6 commit 9e3848b
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 49 deletions.
15 changes: 10 additions & 5 deletions gui_with_pygame.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,13 @@ def __init__(self, x, y, width, height, items, font, text_color, bg_color, selec
self.scroll_offset = 0 # This tracks where we are in the list
self.padding = padding # Distance between items
# Create a base list of items without positioning them
print(items[0])

self.list_items = [
SelectableListItem(
x, y, width, (height - (visible_count - 1) * padding) // visible_count, # Adjust height to account for padding

genome_id=item[1], fitness=item[0].fitness_value, font=font, text_color=text_color,
bg_color=bg_color, selected_color=selected_color
) for item in items
) for item in self.items
]

def draw(self, screen):
Expand Down Expand Up @@ -171,6 +170,12 @@ def handle_event(self, event):
self.scroll("up")
elif event.y < 0:
self.scroll("down")


def get_selected_genomes(self):
# Return a list of selected genome IDs
print(f"Item: {self.items[0]}")
return [item[1] for item in self.items if item[0].selected]


class GenomeViewer:
Expand Down Expand Up @@ -445,7 +450,7 @@ def __init__(self):
## Genome viewer
self.genome_viewer = GenomeViewer(self.genomes, st.normal_font, st.text_color, st.input_field_bg, st.input_field_active_bg)

self.watch_genes_visualize = ImageSprite("data/genome_frames/genome_0.png", (600, 100), (600, 400))
self.watch_genes_visualize = ImageSprite("data/latest/genome_frames/genome_0.png", (600, 100), (600, 400))

## Run button
self.run_button = Button(600, 20, 200, 50, "Run Selected Genomes", st.normal_font, st.text_color, st.button_color, st.hover_color, st.pressed_color, self.run_selected_genomes)
Expand Down Expand Up @@ -498,7 +503,7 @@ def visualize_genome_scene(self):

def run_selected_genomes(self):
# Get the selected genome IDs and run them
selected_genomes = self.genome_viewer.get_selected_genomes()
selected_genomes = self.scrollable_list.get_selected_genomes()

print(f"Running genomes: {selected_genomes}")
# Add logic here to execute the selected genomes (e.g., visualize, simulate, etc.)
Expand Down
27 changes: 13 additions & 14 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from src.environments.debug_env import env_debug_init, run_game_debug
from src.utils.config import Config
from src.genetics.NEAT import NEAT
from src.utils.utils import read_fitness_file, save_fitness, save_best_genome, load_best_genome, save_neat, load_neat, save_fitness_graph_file
from src.utils.utils import read_fitness_file, save_fitness, save_best_genome, load_best_genome, save_neat, load_neat, save_fitness_graph_file, get_latest_generation
import warnings
import cProfile
import pstats
Expand All @@ -10,29 +10,28 @@
warnings.filterwarnings("ignore", category=UserWarning, message=".*Gym version v0.24.1.*")

def play_genome(args):
neat_name = "latest"
if args.neat_name != '':
neat_name = args.neat_name
neat_name = args.neat_name if args.neat_name != '' else 'latest'

if args.to_gen is not None:
to_gen = args.to_gen
if args.from_gen is not None:
from_gen = args.from_gen
else:
from_gen = 0
test_genome(from_gen, to_gen, neat_name)
from_gen = args.from_gen if args.from_gen is not None else 0
test_genome(from_gen, args.to_gen, neat_name)
return

if args.from_gen is not None:
latest_gen = get_latest_generation(neat_name)
test_genome(args.from_gen, latest_gen, neat_name)
return

generation_num = args.generation if args.generation is not None else -1
genome = load_best_genome(generation_num, neat_name)
genome = load_best_genome(args.generation if args.generation is not None else -1, neat_name)
env, state = env_debug_init()
run_game_debug(env, state, genome, 0, visualize=True)
run_game_debug(env, state, genome, neat_name, visualize=True)

def test_genome(from_gen: int, to_gen: int, neat_name: str):
for i in range(from_gen, to_gen + 1):
print(f"Testing genome {i}...")
genome = load_best_genome(i, neat_name)
env, state = env_debug_init()
fitness = run_game_debug(env, state, genome, i)
fitness = run_game_debug(env, state, genome, neat_name)
print(fitness)

def collect_fitnesses(genomes, generation, min_fitnesses, avg_fitnesses, best_fitnesses, neat_name):
Expand Down
53 changes: 53 additions & 0 deletions output.txt

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions src/environments/debug_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def env_debug_init() -> Tuple[MarioJoypadSpace, np.ndarray]:
state = env.reset() # Good practice to reset the env before using it.
return env, state

def run_game_debug(env: MarioJoypadSpace, initial_state: np.ndarray, genome: Genome, num: int, visualize: bool = True) -> float:
def run_game_debug(env: MarioJoypadSpace, initial_state: np.ndarray, genome: Genome, neat_name: str, visualize: bool = True) -> float:

forward = Traverse(genome)
fitness = Fitness()
Expand All @@ -40,8 +40,8 @@ def run_game_debug(env: MarioJoypadSpace, initial_state: np.ndarray, genome: Gen
env.render()
# timeout = 600 + sr.info["x_pos"]
if visualize and i % 10000 == 0:
save_state_as_png(0, sr.state)
visualize_genome(genome, 0)
save_state_as_png(i, sr.state, neat_name)
visualize_genome(genome, neat_name, i)

fitness.calculate_fitness(sr.info, action)

Expand Down
6 changes: 1 addition & 5 deletions src/genetics/NEAT.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,18 +217,14 @@ def check_individual_impovements(self):
print("Best specie is not removed")
break
print("Specie has stagnated, removing it!")
self.species.remove(specie)
# TODO: Ludvig, kan du forklare hvorfor det er sånn? En liten specie som er shit burde vel ikke overleve videre?
# Bare slett kommentarene hvis det ikke skal være der.
"""

if len(self.species) > 2:
self.species.remove(specie)
else:
# if there are only two species left, reset their imrovement counters
self.improvement_counter = 0
for specie in self.species:
specie.improvement_counter = 0
"""
if self.species == []:
print("All species have been removed - RIP")
self.initiate_genomes()
Expand Down
34 changes: 16 additions & 18 deletions src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
# These imports will only be used for type hinting, not at runtime
from src.genetics.NEAT import NEAT

def save_state_as_png(i, state: np.ndarray) -> None:
def save_state_as_png(i, state: np.ndarray, neat_name: str) -> None:
"""Save a frame."""
directory = "./data/mario_frames"
directory = f"./data/{neat_name}/mario_frames"
if not os.path.exists(directory):
os.makedirs(directory)
plt.imsave(f"./data/mario_frames/frame{i}.png", state, cmap='gray', vmin=0, vmax=1)
plt.imsave(f"{directory}/frame{i}.png", state, cmap='gray', vmin=0, vmax=1)

def normalize_positive_values(positive_vals: np.ndarray) -> None:
"""Takes an ndarray with positive floats as inputs,
Expand Down Expand Up @@ -61,22 +61,20 @@ def save_best_genome(genome: Genome, generation: int, name: str):
with open(f'{path}/best_genome_{generation}.obj', 'wb') as f:
pickle.dump(genome, f) # type: ignore

def get_latest_generation(name: str) -> int:
"""Finds the latest generation number from saved genomes for a given NEAT name."""
files = os.listdir(f'data/{name}/good_genomes')
pattern = re.compile(r'best_genome_(\d+).obj')
generations = [int(match.group(1)) for file in files if (match := pattern.match(file))]
if not generations:
raise FileNotFoundError("No valid genome files found in 'data/good_genomes'.")
return max(generations)

def load_best_genome(generation: int, name: str) -> Genome:
"""Loads the best genome from the given generation. If -1 is passed as argument, the latest generation is displayed."""
if generation == -1: # Find the genome from the latest generation.
files = os.listdir(f'data/{name}/good_genomes')
pattern = re.compile(r'best_genome_(\d+).obj')
generations = []
for file in files:
match = pattern.match(file)
if match:
generations.append(int(match.group(1)))
if generations:
generation = max(generations)
print("Loading best genome from generation:", generation)
else:
raise FileNotFoundError("No valid genome files found in 'data/good_genomes'.")

"""Loads the best genome from a given generation or the latest if -1 is provided."""
if generation == -1:
generation = get_latest_generation(name)
print(f"Loading best genome from generation: {generation}")
with open(f'data/{name}/good_genomes/best_genome_{generation}.obj', 'rb') as f:
return pickle.load(f)

Expand Down
6 changes: 3 additions & 3 deletions src/visualization/visualize_genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def add_labels_to_output_nodes(pos_dict: Dict[int, Tuple[float, float]], genome:
ax.text(x, y + 0.50, '+'.join(movement_list[output_node.id]), fontsize=10, fontweight='bold', color='blue') # Add label to the right


def visualize_genome(genome: Genome, frame_number: int):
def visualize_genome(genome: Genome, neat_name: str, frame_number: int):
fig = plt.figure(figsize=(15, 10))
ax = fig.add_subplot(111)

Expand Down Expand Up @@ -88,10 +88,10 @@ def visualize_genome(genome: Genome, frame_number: int):

plt.xlim(GRAPH_XMIN, GRAPH_XMAX)
plt.ylim(GRAPH_YMIN, GRAPH_YMAX)
directory = "./data/genome_frames"
directory = f"./data/{neat_name}/genome_frames"
if not os.path.exists(directory):
os.makedirs(directory)
plt.savefig(f'./data/genome_frames/genome_{frame_number}.png')
plt.savefig(f'{directory}/genome_{frame_number}.png')
plt.close()


2 changes: 1 addition & 1 deletion test/visualization/visualize_genome_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ def test_visualize_genome():
set_input_node_values(state, genome.nodes)
traverse = Traverse(genome)
traverse.traverse()
visualize_genome(genome, -1)
visualize_genome(genome, "test", -1)

0 comments on commit 9e3848b

Please sign in to comment.