Skip to content

Commit

Permalink
feat: saves actions and makes it possible to play genome back in all …
Browse files Browse the repository at this point in the history
…environments
  • Loading branch information
Vetlets05 committed Nov 5, 2024
1 parent 883bb80 commit 55085aa
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 6 deletions.
37 changes: 37 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from src.environments.debug_env import env_debug_init, run_game_debug
from src.environments.playback_env import env_playback_init, run_game_playback
from src.utils.config import Config
from src.genetics.NEAT import NEAT
from src.utils.utils import read_fitness_file, save_fitness, save_best_genome, load_best_genome, save_neat, load_neat, save_fitness_graph_file, get_latest_generation
Expand Down Expand Up @@ -34,6 +35,34 @@ def test_genome(from_gen: int, to_gen: int, neat_name: str):
fitness = run_game_debug(env, state, genome, neat_name)
print(fitness)


def playback_genomes(args):
    """Dispatch the ``playback`` sub-command from the parsed CLI arguments.

    Three modes, checked in this order:

    1. ``args.to_gen`` is given: replay every generation from ``args.from_gen``
       (0 when omitted) up to ``args.to_gen``.
    2. Only ``args.from_gen`` is given: replay from that generation up to the
       value returned by ``get_latest_generation``.
    3. Neither is given: replay a single genome, loaded with
       ``args.generation`` (or ``-1`` when omitted), without visualization.

    An empty ``args.neat_name`` is replaced by ``'latest'``.
    """
    neat_name = 'latest' if args.neat_name == '' else args.neat_name

    # Mode 1: an explicit upper bound was supplied.
    if args.to_gen is not None:
        first_gen = 0 if args.from_gen is None else args.from_gen
        playback_genome(first_gen, args.to_gen, neat_name, args.environment)
        return

    # Mode 2: open-ended range, run up to the newest stored generation.
    if args.from_gen is not None:
        last_gen = get_latest_generation(neat_name)
        playback_genome(args.from_gen, last_gen, neat_name, args.environment)
        return

    # Mode 3: a single genome.
    # NOTE(review): -1 presumably selects the latest generation; confirm in load_best_genome.
    generation = -1 if args.generation is None else args.generation
    genome = load_best_genome(generation, neat_name)
    env, state = env_playback_init(args.environment)
    run_game_playback(env, state, genome, neat_name, visualize=False)

def playback_genome(from_gen: int, to_gen: int, neat_name: str, environment: int):
    """Replay the best genome of each generation in ``from_gen``..``to_gen``.

    Both bounds are inclusive. A new playback environment is created for every
    generation, and the fitness returned by the replay is printed.

    Args:
        from_gen: First generation to replay.
        to_gen: Last generation to replay (inclusive).
        neat_name: Name of the NEAT run the genomes are loaded from.
        environment: Environment number forwarded to ``env_playback_init``.
    """
    for generation in range(from_gen, to_gen + 1):
        print(f"Playback genome {generation}...")
        best_genome = load_best_genome(generation, neat_name)
        env, state = env_playback_init(environment)
        print(run_game_playback(env, state, best_genome, neat_name))



def collect_fitnesses(genomes, generation, min_fitnesses, avg_fitnesses, best_fitnesses, neat_name):
fitnesses = [genome.fitness_value for genome in genomes]

Expand Down Expand Up @@ -133,6 +162,12 @@ def command_line_interface():
play_parser.add_argument('-g', '--generation', type=int, help="The generation of the genome to play")
play_parser.add_argument('-f', '--from_gen', type=int, help="The starting genome to test")
play_parser.add_argument('-t', '--to_gen', type=int, help="The ending genome to test (exclusive)")

playback_parser = subparsers.add_parser('playback', help="Play back the best genome from the lastest generation on an environment of your choice")
playback_parser.add_argument('-g', '--generation', type=int, help="The generation of the genome to play")
playback_parser.add_argument('-f', '--from_gen', type=int, help="The starting genome to test")
playback_parser.add_argument('-t', '--to_gen', type=int, help="The ending genome to test (exclusive)")
playback_parser.add_argument('-e', '--environment', type=int, help="The environment to play back the actions (0,1,2,3)")

args = parser.parse_args()

Expand All @@ -142,6 +177,8 @@ def command_line_interface():
save_fitness_graph_file(args.neat_name, show=True)
elif args.command == "play":
play_genome(args)
elif args.command == "playback":
playback_genomes(args)
else:
parser.print_help()

Expand Down
55 changes: 55 additions & 0 deletions src/environments/playback_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import numpy as np
import warnings
import time
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from typing import Tuple
from src.genetics.genome import Genome
from src.genetics.traverse import Traverse
from src.environments.fitness_function import Fitness
from src.environments.mario_env import MarioJoypadSpace
from src.visualization.visualize_genome import visualize_genome
from src.utils.utils import save_state_as_png
from src.utils.utils import insert_input


def env_playback_init(environment) -> Tuple[MarioJoypadSpace, np.ndarray]:
    """Create and reset a Super Mario Bros environment for action playback.

    Args:
        environment: Version number used to build the environment id
            ``SuperMarioBros-v<environment>``. ``None`` (the value argparse
            yields when ``-e`` is omitted) falls back to version 0.

    Returns:
        A tuple ``(env, state)`` where ``env`` is the environment wrapped in
        ``MarioJoypadSpace`` with ``SIMPLE_MOVEMENT`` actions and ``state`` is
        the value returned by ``env.reset()``.
    """
    # Without this fallback the id would be "SuperMarioBros-vNone", which is
    # not a registered environment.
    if environment is None:
        environment = 0

    env_name = f"SuperMarioBros-v{environment}"
    # NOTE: this silences *all* warnings process-wide, not only gym's.
    warnings.filterwarnings("ignore")
    env = gym_super_mario_bros.make(env_name)
    env = MarioJoypadSpace(env, SIMPLE_MOVEMENT)  # Select available actions for AI
    env.metadata['render_modes'] = "rgb_array"
    env.metadata['render_fps'] = 144

    state = env.reset()  # Good practice to reset the env before using it.
    return env, state


def run_game_playback(env: MarioJoypadSpace, initial_state: np.ndarray, genome: Genome, neat_name: str, visualize: bool = True, frame_queue=None) -> float:
    """Replay the actions recorded on ``genome`` and return the resulting fitness.

    Every action in ``genome.actions`` is stepped through ``env`` in order,
    including the last one. After each step the running fitness is printed and
    the environment is rendered. The environment is always closed before
    returning, even if a step raises.

    Args:
        env: Environment to replay the actions in.
        initial_state: Unused; kept for signature parity with the other runners.
        genome: Genome whose recorded ``actions`` are replayed.
        neat_name: Name of the NEAT run, forwarded to the visualization helpers.
        visualize: When true, save each frame as a PNG and visualize the genome
            after every step.
        frame_queue: Unused.

    Returns:
        The value of ``Fitness.get_fitness()`` after the last replayed action.
    """
    fitness = Fitness()

    try:
        for step, action in enumerate(genome.actions):
            time.sleep(0.01)  # Slow the replay down so it can be watched.
            sr = env.step(action)  # State, Reward, Done, Info

            if visualize:
                save_state_as_png(step, sr.state, neat_name)
                visualize_genome(genome, neat_name, 0)

            fitness.calculate_fitness(sr.info, action)
            print(fitness.get_fitness())

            env.render()
            insert_input(genome, sr.state)
    finally:
        # Release the emulator on normal completion and on errors alike.
        env.close()

    return fitness.get_fitness()
3 changes: 2 additions & 1 deletion src/environments/train_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def run_game(env: MarioJoypadSpace, initial_state: np.ndarray, genome: Genome):
action = forward.traverse()
# time.sleep(0.001)
sr = env.step(action) # State, Reward, Done, Info
genome.add_action(action)

fitness.calculate_fitness(sr.info, action)

Expand All @@ -49,6 +50,6 @@ def run_game(env: MarioJoypadSpace, initial_state: np.ndarray, genome: Genome):

if sr.info["life"] == 1 or stagnation_counter > 150:
env.close()
return fitness.get_fitness()
return fitness.get_fitness(), genome.actions

insert_input(genome, sr.state)
7 changes: 4 additions & 3 deletions src/genetics/NEAT.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def train_genome(self, genome: Genome):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, message=".*Gym version v0.24.")
env, state = env_init()
fitness = run_game(env, state, genome)
return genome.id, fitness # Return the genome's ID and its fitness
fitness, actions = run_game(env, state, genome)
return genome.id, fitness, actions # Return the genome's ID and its fitness

def rank_species(self, specie: Species) -> float:
"""Combine best genome fitness and average fitness to rank species."""
Expand All @@ -123,10 +123,11 @@ def train_genomes(self):

# results = [self.train_genome(genome) for genome in self.genomes] # Uncomment for single process

for genome_id, fitness in results: # Update genomes with the returned fitness values
for genome_id, fitness, actions in results: # Update genomes with the returned fitness values
for genome in self.genomes:
if genome.id == genome_id: # Match the genome by its ID
genome.fitness_value = fitness # Assign the fitness value
genome.actions = actions
break # Move to the next result once a match is found

def sort_species(self, genomes: List[Genome]):
Expand Down
7 changes: 6 additions & 1 deletion src/genetics/genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def __init__(self, id: int):
self.fitness_value: float = 0.0
self.elite = False

self.actions = []

def add_node(self, node: Node):
if node.type == 'input':
self.input_nodes.append(node)
Expand Down Expand Up @@ -136,12 +138,15 @@ def check_existing_connection(self, node1: Node, node2: Node):
return connection
return None

def add_action(self, action):
self.actions.append(action)

@property
def nodes(self) -> List[Node]:
""" Returns all nodes in the genome. """
return self.output_nodes + self.input_nodes + self.hidden_nodes

def __repr__(self):
return (f"Genome(id={self.id}, hidden nodes={[node.id for node in self.hidden_nodes]}, "
f"connections={[connection for connection in self.connections]}, fitness_value={self.fitness_value})")
f"connections={[connection for connection in self.connections]}, fitness_value={self.fitness_value}, actions={self.actions})")

2 changes: 1 addition & 1 deletion src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class Config:
c2: float = 1.5
c3: float = 0.4
genomic_distance_threshold: float = 2.69
population_size: int = 1 # 56 cores on IDUN
population_size: int = 4 # 56 cores on IDUN
generations: int = 2 # A bunch of iterations

connection_weight_mutation_chance: float = 0.8
Expand Down

0 comments on commit 55085aa

Please sign in to comment.