Skip to content

Commit

Permalink
Merge pull request #18 from CogitoNTNU/visualize_genome
Browse files Browse the repository at this point in the history
Visualize genome
  • Loading branch information
ChristianFredrikJohnsen authored Sep 19, 2024
2 parents da864df + 2d9adb5 commit 0c80470
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class Node:
A node in a neural network.
Has a unique id and a type.
"""

def __init__(self, id: int, type: str, value: float = 0.0):
self.id = id
self.type = type
Expand Down
106 changes: 106 additions & 0 deletions src/visualize_genome.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import networkx as nx
import matplotlib.pyplot as plt
from src.nodes import Genome
import random

# Adjust the size of the visualization whiteboard for the NN:
GRAPH_XMIN = -1.5
GRAPH_XMAX = 17
GRAPH_YMIN = -20
GRAPH_YMAX = 3

def get_position_dict(layers):
"""
Creates a custom layout for the graph G, ensuring nodes are separated by layers.
:param G: The directed graph (DiGraph) representing the neural network or genome.
:param layers: A list of lists, where each inner list contains the nodes in that layer.
:return: A dictionary with node positions suitable for visualization.
"""
pos = {}
layer_gap = 5 # Horizontal gap between layers
node_gap = 2 # Vertical gap between nodes in the same layer

# Total number of layers
total_layers = len(layers)

# Loop through layers and assign positions
for layer_idx, layer in enumerate(layers):
# Special case for the input layer (first layer)
if layer_idx == 0:
x_pos = 0 # Input layer starts at the far left
# Organize the first layer into a 20x10 grid, starting from top-left (0,0)
for i, node in enumerate(layer):
row = i // 20 # There are 10 rows, so row is determined by i // 20
col = i % 20 # Columns are determined by i % 20
pos[node] = (x_pos + col * 0.5, -row * node_gap) # Adjust x (columns) and y (rows)
elif layer_idx == total_layers - 1: # Output layer case
x_pos = total_layers * layer_gap # Place output nodes at the farthest right
y_start = -(len(layer) - 1) * node_gap * 2 # Center the output nodes vertically
for i, node in enumerate(layer):
pos[node] = (x_pos, y_start + i * node_gap) # Place nodes vertically
else:
# Hidden layers are placed regularly between the input and output layers
y_start = -(len(layer) - 1) * node_gap / 2 # Center the layer vertically
for i, node in enumerate(layer):
x_pos = round(random.uniform(10.5, 14.5), 2)
pos[node] = (x_pos, y_start + i * node_gap) # Place nodes vertically

return pos



def visualize_genome(genome: Genome):
G = nx.DiGraph()
add_nodes_to_graph(G, genome)

for connection in genome.connections:
if connection.is_enabled:
G.add_edge(connection.in_node.id, connection.out_node.id, weight = connection.weight)

colors_node = [get_color(node.type, node.value) for node in genome.nodes]

layers = [[] for _ in range(3)]
for node in genome.nodes:
if node.type == 'Input':
layers[0].append(node.id)
elif node.type == 'Hidden':
layers[1].append(node.id)
else:
layers[2].append(node.id)
pos = get_position_dict(layers)
nx.draw(G, pos, with_labels=True, edge_color='b', node_size=500, font_size=8, font_color='w', font_weight='bold', node_color=colors_node)

plt.xlim(GRAPH_XMIN, GRAPH_XMAX)
plt.ylim(GRAPH_YMIN, GRAPH_YMAX)
plt.show()

def add_nodes_to_graph(graph: nx.DiGraph, genome: Genome):
"""
Takes a graph and genome as input, and adds all of the nodes connected to that genome to the graph.
"""
for node in genome.nodes:
if node.type == 'Input':
graph.add_node(node.id, layer_number = 0)
elif node.type == 'Hidden':
graph.add_node(node.id, layer_number = 1)
elif node.type == 'Output':
graph.add_node(node.id, layer_number = 2)

def get_color(type: str, value: float) -> str:
"""
Takes a value which is assumed to be in range [0, 1],
and returns a simple string like 'r' which representsn the color.
"""
if type == 'Input':
if value < 0.25:
return 'b'
elif value < 0.5:
return 'g'
elif value < 0.75:
return 'y'
else:
return 'r'

else:
return 'g'
40 changes: 40 additions & 0 deletions test/visualize_genome_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from src.visualize_genome import visualize_genome
from src.nodes import Genome, Node, ConnectionGene
import random

def generate_nodes():
list_of_nodes = []
for i in range(200):
color = random.random()
list_of_nodes.append(Node(i, "Input", color))
for i in range(200, 202):
color = random.random()
list_of_nodes.append(Node(i, "Hidden", color))
for i in range(202, 207):
color = random.random()
list_of_nodes.append(Node(i, "Output", color))
return list_of_nodes

def create_connections(list_of_nodes):
list_of_connections = []
for i in range(200):
for j in range(200, 202):
list_of_connections.append(ConnectionGene(list_of_nodes[i], list_of_nodes[j], 1, True, 1))
for i in range(202, 207):
for j in range(200, 202):
list_of_connections.append(ConnectionGene(list_of_nodes[j], list_of_nodes[i], 1, True, 1))
return list_of_connections

def test_visualize_genome():
list_of_nodes = generate_nodes()
list_of_connections = create_connections(list_of_nodes)

genome = Genome(1)
for i in list_of_nodes:
genome.add_node(i)
for i in list_of_connections:
genome.add_connection(i)
visualize_genome(genome)



0 comments on commit 0c80470

Please sign in to comment.