From 95c6d6ddbb1974972a2091488be6cacdd0aabf57 Mon Sep 17 00:00:00 2001 From: IcyBroom Date: Fri, 18 Oct 2024 00:12:40 -0400 Subject: [PATCH 1/4] Update fl.py Updated fl.py file to run plot time using tensorboard --- src/algos/fl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/algos/fl.py b/src/algos/fl.py index 625a61ae..029f78ec 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -45,6 +45,9 @@ def local_train(self, round: int, **kwargs: Any): self.log_utils.log_tb( f"train_accuracy/client{self.node_id}", avg_accuracy, round ) + self.log_utils.log_tb( + f"time/client{self.node_id}", time_taken, round + ) def local_test(self, **kwargs: Any): """ @@ -177,6 +180,7 @@ def run_protocol(self): loss, acc, time_taken = self.test() self.log_utils.log_tb("test_acc/clients", acc, round) self.log_utils.log_tb("test_loss/clients", loss, round) + self.log_utils.log_tb("time/clients", time_taken, round) self.log_utils.log_console( "Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format( round, acc, loss, time_taken From 9f675a6d3abc8a77b0716041efc6e1e2c5408117 Mon Sep 17 00:00:00 2001 From: IcyBroom Date: Fri, 18 Oct 2024 00:15:32 -0400 Subject: [PATCH 2/4] Update fl.py --- src/algos/fl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/algos/fl.py b/src/algos/fl.py index 029f78ec..ef765d25 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -168,6 +168,9 @@ def single_round(self): self.set_representation(avg_wts) def run_protocol(self): + """ + Run the federated averaging protocol + """ self.log_utils.log_console("Starting clients federated averaging") start_rounds = self.config.get("start_rounds", 0) total_rounds = self.config["rounds"] From 3aa705a563f98971601ec060091023b3ba8be7b0 Mon Sep 17 00:00:00 2001 From: IcyBroom Date: Sun, 10 Nov 2024 09:54:18 -0500 Subject: [PATCH 3/4] Update log_utils.py --- src/utils/log_utils.py | 241 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 240 insertions(+), 1 deletion(-) diff --git a/src/utils/log_utils.py b/src/utils/log_utils.py index 62f015aa..beed28b2 100644 --- a/src/utils/log_utils.py +++ b/src/utils/log_utils.py @@ -17,7 +17,10 @@ import pandas as pd from utils.types import ConfigType import json - +import networkx as nx +from networkx import Graph +import matplotlib.pyplot as plt +import imageio def deprocess(img: torch.Tensor) -> torch.Tensor: """ @@ -123,7 +126,91 @@ def __init__(self, config: ConfigType) -> None: self.init_npy() self.init_summary() self.init_csv() + self.init_nx_graph(config) + self.nx_layout = None + + def init_nx_graph(self, config: ConfigType): + """ + Initialize the networkx graph for the topology. + + Args: + config (ConfigType): Configuration dictionary. + rank (int): Rank of the current node. + """ + self.topology = config["topology"] + self.num_users = config["num_users"] + self.graph = nx.Graph() + + + # def generate_graph(self): + # """ + # Generate the graph using the networkX library + # and store it in the self.graph attribute. + # NetworkX has a lot of built-in functions to generate graphs. + # Use this url - https://networkx.org/documentation/stable/reference/generators.html + # """ + # if self.topology["name"] == "centralized": + # self.graph = nx.complete_graph(self.num_users) + # else: + # raise ValueError("Invalid topology name") + + def log_nx_graph(self, graph: Graph, iteration: int, directory: str|None = None): + """ + Log the networkx graph to a file. + """ + # print(graph) + if directory: + nx.write_adjlist(graph, f"{directory}/graph_{iteration}.adjlist", comments='#', delimiter=' ', encoding='utf-8') # type: ignore + else: + nx.write_adjlist(graph, f"{self.log_dir}/graph_{iteration}.adjlist", comments='#', delimiter=' ', encoding='utf-8') # type: ignore + + + def log_nx_graph_image(self, graph: Graph, iteration: int, directory: str|None = None): + """ + Log the networkx graph as an image. + """ + # Generate a layout for the graph + if self.nx_layout is None: # type: ignore + self.nx_layout = nx.spring_layout(graph) # type: ignore + + # pos = nx.spring_layout(graph) + + # Draw the graph with labels + nx.draw(graph, self.nx_layout, with_labels=True, node_size=500, node_color="skyblue", font_size=10, font_weight="bold", edge_color="gray")# type: ignore + + # Save the plot as an image + if directory : + plt.savefig(f"{directory}/graph_{iteration}.png", format="png") # type: ignore + else: + plt.savefig(f"{self.log_dir}/graph_{iteration}.png", format="png")# type: ignore + + # Close the plot to free up memory + plt.close()# type: ignore + + def log_nx_graph_edge_weights(self, graph: Graph, iteration: int, directory: str|None = None): + """ + Log the networkx graph with edge weights as an image. + """ + # Define position of nodes for layout + pos = nx.spring_layout(graph) + + # Get the edge weights + edge_weights = nx.get_edge_attributes(graph, "weight") + + # Draw the graph with edge weights + nx.draw(graph, pos, with_labels=True, node_size=500, node_color="skyblue", font_size=10, font_weight="bold", edge_color="gray", width=[float(edge_weights[edge]) for edge in graph.edges()])# type: ignore + + # Save the plot as an image + if directory: + plt.savefig(f"{directory}/graph_{iteration}.png", format="png") + else: + plt.savefig(f"{self.log_dir}/graph_{iteration}.png", format="png") + + # Close the plot to free up memory + plt.close() + + def log_config(self, config: ConfigType): """ Log the configuration to a json file. @@ -165,6 +252,15 @@ def init_csv(self): if not os.path.exists(csv_path) or not os.path.isdir(csv_path): os.makedirs(csv_path) + parent = os.path.dirname(self.log_dir) + "/csv" # type: ignore + if not os.path.exists(parent) or not os.path.isdir(parent): # type: ignore + os.makedirs(parent) # type: ignore + + imgs = parent + "/imgs" + if not os.path.exists(imgs) or not os.path.isdir(imgs): # type: ignore + os.makedirs(imgs) # type: ignore + + def log_summary(self, text: str): """ Add summary text to the summary file for logging. @@ -239,6 +335,149 @@ def log_csv(self, key: str, value: Any, iteration: int): # Append the metrics to the CSV file df.to_csv(log_file, mode='a', header=not file_exists, index=False) + #make a global file to store all the neighbors of each round + if key == "neighbors": + self.log_global_csv(iteration, key, value) + + def log_global_csv(self, iteration: int, key: str, value: Any): + """ + Log a value to a CSV file. + """ + parent = os.path.dirname(self.log_dir) # type: ignore + log_file = f"{parent}/csv/neighbors_{iteration}.csv" + node = self.log_dir.split("_")[-1] # type: ignore + row = {"iteration": iteration, "node": node , key: value} + df = pd.DataFrame([row]) + file_exists = os.path.isfile(log_file) + df.to_csv(log_file, mode='a', header=not file_exists, index=False) + + if len(pd.read_csv(log_file)) == self.num_users: + adjacency_list = self.create_adjacency_list(log_file) + graph = nx.Graph(adjacency_list) + # create the /img directory + self.log_nx_graph_image(graph, iteration, f"{parent}/csv/imgs") + self.log_nx_graph(graph, iteration, f"{parent}/csv") + self.combine_graphs_with_edge_frequency(f"{parent}/csv") + + + def combine_graphs_with_edge_frequency(self, directory: str): + """ + Combine the adjacency lists of all the rounds and calculate the edge frequency. + """ + # Get all the adjacency lists + adjacency_lists = glob(f"{directory}/*.csv") + # Initialize the edge frequency dictionary + edge_frequency = {} + # Initialize the adjacency list + adjacency_list = {} + # Iterate over all the adjacency lists + for adj_list in adjacency_lists: + # Load the adjacency list + data = pd.read_csv(adj_list) + # Populate the adjacency list + for _, row in data.iterrows(): + node = row["node"] + # Convert string representation of list to actual list + neighbors = eval(row["neighbors"]) + if node not in adjacency_list: + adjacency_list[node] = neighbors + else: + adjacency_list[node].extend(neighbors) + # Calculate the edge frequency + for node, neighbors in adjacency_list.items(): + for neighbor in neighbors: + if (node, neighbor) in edge_frequency: + edge_frequency[(node, neighbor)] += 1 + else: + edge_frequency[(node, neighbor)] = 1 + + # create a graph with edges in the edge frequency with higher frequency edges having thicker lines + G = nx.Graph() + for edge, freq in edge_frequency.items(): + # print(edge, freq) + G.add_edge(edge[0], edge[1], weight=freq) + + self.log_nx_graph_edge_weights(G, -1, directory) + + # create the /img directory + #log graph image with edge weights + self.log_nx_graph(G, -1, directory) + + # create video of the graphs + self.create_video(directory + "/imgs") + print(edge_frequency) + #create a heatmap of the edge frequency + # self.create_heatmap(edge_frequency, directory) # type: ignore + + def create_video(self, directory: str): + """ + Create a video of the graphs using imageIO where each image is a second long. + """ + images = [] + for filename in sorted(glob(f"{directory}/graph_*.png")): + images.append(imageio.imread(filename)) + # make the gif loop + imageio.mimsave(f"{directory}/graph_video.gif", images, fps =1, loop =0) + + + # def create_heatmap(self, edge_frequency, directory: str): + # """ + # Create a heatmap of the edge frequency. + # """ + # # Initialize the edge frequency matrix + # edge_frequency_matrix = np.zeros((self.num_users, self.num_users)) + + # # Populate the edge frequency matrix + # for edge, freq in edge_frequency.items(): + # edge_frequency_matrix[edge[0], edge[1]] = freq + # edge_frequency_matrix[edge[1], edge[0]] = freq # Ensure symmetry for undirected graph + + # edge_frequency_matrix = np.log(edge_frequency_matrix + 1) # Log scale for better visualization + # # Create the heatmap + # plt.imshow(edge_frequency_matrix, cmap="hot", interpolation="nearest") + # # plt.colorbar(label="Frequency of Communication") + + # # Save the heatmap + # plt.savefig(f"{directory}/edge_frequency_heatmap.png") + # plt.close() + + + + + def create_adjacency_list(self, file_path: str) -> Dict[str, list]: # type: ignore + # Load the CSV file + """ + Load the CSV file, populate the adjacency list and return it. + + Parameters + ---------- + file_path : str + The path to the CSV file + + Returns + ------- + adjacency_list : Dict[str, list] + The adjacency list + """ + data = pd.read_csv(str(file_path)) # type: ignore + + # Initialize the adjacency list + adjacency_list : Dict[str, list] = {} # type: ignore + + # Populate the adjacency list + for _, row in data.iterrows(): # type: ignore + node = row["node"] # type: ignore + # Convert string representation of list to actual list + neighbors = eval(row["neighbors"]) # type: ignore + + if node not in adjacency_list: + adjacency_list[node] = neighbors + else: + adjacency_list[node].extend(neighbors) # type: ignore + + return adjacency_list # type: ignore + + def log_max_stats_per_client( self, stats_per_client: np.ndarray, round_step: int, metric: str From 02074896e5cabc1c2414f70c2717f57e1f2ab78f Mon Sep 17 00:00:00 2001 From: IcyBroom Date: Mon, 11 Nov 2024 03:39:17 -0500 Subject: [PATCH 4/4] Logging --- src/algos/base_class.py | 7 ++ src/utils/log_utils.py | 195 ++++++++++++------------------------ src/utils/plot_utils.py | 216 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 286 insertions(+), 132 deletions(-) diff --git a/src/algos/base_class.py b/src/algos/base_class.py index 7fcd1292..f923946b 100644 --- a/src/algos/base_class.py +++ b/src/algos/base_class.py @@ -307,6 +307,13 @@ def log_metrics(self, stats: Dict[str, Any], iteration: int) -> None: imgs=stats["images"], key="sample_images", iteration=iteration ) + #config["num_users"] + #Check if the file neighbors_{iteration}.csv exists in logs/csv + dir = os.path.dirname(self.log_utils.log_dir) + "/csv" + if os.path.exists(f"{dir}/graph_{iteration}.adjlist"): + self.plot_utils.log_nx_graph_image(iteration, f"{dir}/neighbors_{iteration}.csv") + self.plot_utils.combine_graphs_with_edge_frequency(dir, iteration) + @abstractmethod def receive_and_aggregate(self): """Add docstring here""" diff --git a/src/utils/log_utils.py b/src/utils/log_utils.py index beed28b2..767d17a7 100644 --- a/src/utils/log_utils.py +++ b/src/utils/log_utils.py @@ -106,6 +106,7 @@ class LogUtils: """ Utility class for logging and saving experiment data. """ + # nx_layout = None def __init__(self, config: ConfigType) -> None: log_dir = config["log_path"] @@ -126,8 +127,8 @@ def __init__(self, config: ConfigType) -> None: self.init_npy() self.init_summary() self.init_csv() - self.init_nx_graph(config) self.nx_layout = None + self.init_nx_graph(config) def init_nx_graph(self, config: ConfigType): """ @@ -137,22 +138,12 @@ def init_nx_graph(self, config: ConfigType): config (ConfigType): Configuration dictionary. rank (int): Rank of the current node. """ - self.topology = config["topology"] + if "topology" in config: + self.topology = config["topology"] self.num_users = config["num_users"] - self.graph = nx.Graph() + self.graph = nx.DiGraph() - # def generate_graph(self): - # """ - # Generate the graph using the networkX library - # and store it in the self.graph attribute. - # NetworkX has a lot of built-in functions to generate graphs. - # Use this url - https://networkx.org/documentation/stable/reference/generators.html - # """ - # if self.topology["name"] == "centralized": - # self.graph = nx.complete_graph(self.num_users) - # else: - # raise ValueError("Invalid topology name") def log_nx_graph(self, graph: Graph, iteration: int, directory: str|None = None): """ @@ -165,51 +156,80 @@ def log_nx_graph(self, graph: Graph, iteration: int, directory: str|None = None) nx.write_adjlist(graph, f"{self.log_dir}/graph_{iteration}.adjlist", comments='#', delimiter=' ', encoding='utf-8') # type: ignore - def log_nx_graph_image(self, graph: Graph, iteration: int, directory: str|None = None): - """ - Log the networkx graph as an image. - """ - # Generate a layout for the graph - if self.nx_layout is None: # type: ignore - self.nx_layout = nx.spring_layout(graph) # type: ignore - - # pos = nx.spring_layout(graph) + def log_nx_graph_image(self, graph: Graph, iteration: int, directory: str | None = None): + """ + Log the networkx directed graph as an image with non-overlapping edges. + """ + # Generate a layout with more spacing + if self.nx_layout is None: + self.nx_layout = nx.shell_layout(graph) + + # Draw nodes with larger size and smaller font + nx.draw_networkx_nodes(graph, self.nx_layout, node_size=700, node_color="skyblue") + + # Draw each edge with curved lines for side-by-side display + edges = list(graph.edges()) + for i, (u, v) in enumerate(edges): + rad = 0.2 if i % 2 == 0 else -0.2 # Increase rad for more curvature + nx.draw_networkx_edges( + graph, + self.nx_layout, + edgelist=[(u, v)], + connectionstyle=f"arc3,rad={rad}", + arrows=True, + # arrowstyle="-|>", # Customize arrow style for better separation + arrowsize=20 # Increase arrow size for visibility + ) + + # Draw labels with smaller font + nx.draw_networkx_labels(graph, self.nx_layout, font_size=8, font_weight="bold") - # Draw the graph with labels - nx.draw(graph, self.nx_layout, with_labels=True, node_size=500, node_color="skyblue", font_size=10, font_weight="bold", edge_color="gray")# type: ignore - # Save the plot as an image - if directory : - plt.savefig(f"{directory}/graph_{iteration}.png", format="png") # type: ignore + if directory: + plt.savefig(f"{directory}/graph_{iteration}.png", format="png") else: - plt.savefig(f"{self.log_dir}/graph_{iteration}.png", format="png")# type: ignore - + plt.savefig(f"{self.log_dir}/graph_{iteration}.png", format="png") + # Close the plot to free up memory - plt.close()# type: ignore + plt.close() - def log_nx_graph_edge_weights(self, graph: Graph, iteration: int, directory: str|None = None): + + + def log_nx_graph_edge_weights(self, graph: Graph, iteration: int, directory: str | None = None): """ - Log the networkx graph with edge weights as an image. + Log the directed graph with edge weights as an image with non-overlapping edges. """ - # Define position of nodes for layout - pos = nx.spring_layout(graph) - # Get the edge weights + if self.nx_layout is None: + self.nx_layout = nx.shell_layout(graph) + pos = self.nx_layout edge_weights = nx.get_edge_attributes(graph, "weight") - # Draw the graph with edge weights - nx.draw(graph, pos, with_labels=True, node_size=500, node_color="skyblue", font_size=10, font_weight="bold", edge_color="gray", width=[float(edge_weights[edge]) for edge in graph.edges()])# type: ignore + for i, (u, v) in enumerate(graph.edges()): + rad = 0.2 if i % 2 == 0 else -0.2 + nx.draw_networkx_edges( + graph, + pos, + edgelist=[(u, v)], + connectionstyle=f"arc3,rad={rad}", + arrows=True, + width=edge_weights.get((u, v), 1.0), # Set width based on weight + arrowsize=20 + ) + + # Draw nodes and labels + nx.draw_networkx_nodes(graph, pos, node_size=700, node_color="skyblue") + nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold") - # Save the plot as an image if directory: - plt.savefig(f"{directory}/graph_{iteration}.png", format="png") + plt.savefig(f"{directory}/weighted_graph_{iteration}.png", format="png") else: - plt.savefig(f"{self.log_dir}/graph_{iteration}.png", format="png") + plt.savefig(f"{self.log_dir}/weighted_graph_{iteration}.png", format="png") - # Close the plot to free up memory plt.close() + def log_config(self, config: ConfigType): """ @@ -256,10 +276,6 @@ def init_csv(self): if not os.path.exists(parent) or not os.path.isdir(parent): # type: ignore os.makedirs(parent) # type: ignore - imgs = parent + "/imgs" - if not os.path.exists(imgs) or not os.path.isdir(imgs): # type: ignore - os.makedirs(imgs) # type: ignore - def log_summary(self, text: str): """ @@ -353,93 +369,8 @@ def log_global_csv(self, iteration: int, key: str, value: Any): if len(pd.read_csv(log_file)) == self.num_users: adjacency_list = self.create_adjacency_list(log_file) - graph = nx.Graph(adjacency_list) - # create the /img directory - self.log_nx_graph_image(graph, iteration, f"{parent}/csv/imgs") + graph = nx.DiGraph(adjacency_list) self.log_nx_graph(graph, iteration, f"{parent}/csv") - self.combine_graphs_with_edge_frequency(f"{parent}/csv") - - - def combine_graphs_with_edge_frequency(self, directory: str): - """ - Combine the adjacency lists of all the rounds and calculate the edge frequency. - """ - # Get all the adjacency lists - adjacency_lists = glob(f"{directory}/*.csv") - # Initialize the edge frequency dictionary - edge_frequency = {} - # Initialize the adjacency list - adjacency_list = {} - # Iterate over all the adjacency lists - for adj_list in adjacency_lists: - # Load the adjacency list - data = pd.read_csv(adj_list) - # Populate the adjacency list - for _, row in data.iterrows(): - node = row["node"] - # Convert string representation of list to actual list - neighbors = eval(row["neighbors"]) - if node not in adjacency_list: - adjacency_list[node] = neighbors - else: - adjacency_list[node].extend(neighbors) - # Calculate the edge frequency - for node, neighbors in adjacency_list.items(): - for neighbor in neighbors: - if (node, neighbor) in edge_frequency: - edge_frequency[(node, neighbor)] += 1 - else: - edge_frequency[(node, neighbor)] = 1 - - # create a graph with edges in the edge frequency with higher frequency edges having thicker lines - G = nx.Graph() - for edge, freq in edge_frequency.items(): - # print(edge, freq) - G.add_edge(edge[0], edge[1], weight=freq) - - self.log_nx_graph_edge_weights(G, -1, directory) - - # create the /img directory - #log graph image with edge weights - self.log_nx_graph(G, -1, directory) - - # create video of the graphs - self.create_video(directory + "/imgs") - print(edge_frequency) - #create a heatmap of the edge frequency - # self.create_heatmap(edge_frequency, directory) # type: ignore - - def create_video(self, directory: str): - """ - Create a video of the graphs using imageIO where each image is a second long. - """ - images = [] - for filename in sorted(glob(f"{directory}/graph_*.png")): - images.append(imageio.imread(filename)) - # make the gif loop - imageio.mimsave(f"{directory}/graph_video.gif", images, fps =1, loop =0) - - - # def create_heatmap(self, edge_frequency, directory: str): - # """ - # Create a heatmap of the edge frequency. - # """ - # # Initialize the edge frequency matrix - # edge_frequency_matrix = np.zeros((self.num_users, self.num_users)) - - # # Populate the edge frequency matrix - # for edge, freq in edge_frequency.items(): - # edge_frequency_matrix[edge[0], edge[1]] = freq - # edge_frequency_matrix[edge[1], edge[0]] = freq # Ensure symmetry for undirected graph - - # edge_frequency_matrix = np.log(edge_frequency_matrix + 1) # Log scale for better visualization - # # Create the heatmap - # plt.imshow(edge_frequency_matrix, cmap="hot", interpolation="nearest") - # # plt.colorbar(label="Frequency of Communication") - - # # Save the heatmap - # plt.savefig(f"{directory}/edge_frequency_heatmap.png") - # plt.close() diff --git a/src/utils/plot_utils.py b/src/utils/plot_utils.py index 6cbd0e23..bda07702 100644 --- a/src/utils/plot_utils.py +++ b/src/utils/plot_utils.py @@ -3,6 +3,15 @@ import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec +from typing import Dict +import networkx as nx +from networkx import Graph +from glob import glob +from utils.types import ConfigType +import imageio +import pandas as pd + + from mpl_toolkits.axes_grid1 import AxesGrid import math @@ -19,6 +28,213 @@ def __init__(self, config, with_title=True) -> None: self.config = config self.with_title = with_title + self.nx_layout = None + # self.init_nx_graph(config) + + + + # def init_nx_graph(self, config: ConfigType): + # """ + # Initialize the networkx graph for the topology. + + # Args: + # config (ConfigType): Configuration dictionary. + # rank (int): Rank of the current node. + # """ + # self.num_users = config["num_users"] + # self.graph = nx.DiGraph() + + def create_adjacency_list(self, file_path: str) -> Dict[str, list]: # type: ignore + # Load the CSV file + """ + Load the CSV file, populate the adjacency list and return it. + + Parameters + ---------- + file_path : str + The path to the CSV file + + Returns + ------- + adjacency_list : Dict[str, list] + The adjacency list + """ + data = pd.read_csv(str(file_path)) # type: ignore + + # Initialize the adjacency list + adjacency_list : Dict[str, list] = {} # type: ignore + + # Populate the adjacency list + for _, row in data.iterrows(): # type: ignore + node = row["node"] # type: ignore + # Convert string representation of list to actual list + neighbors = eval(row["neighbors"]) # type: ignore + + if node not in adjacency_list: + adjacency_list[node] = neighbors + else: + adjacency_list[node].extend(neighbors) # type: ignore + + return adjacency_list # type: ignore + + def log_nx_graph_image(self, iteration: int, directory: str, output_dir: str| None = None): + """ + Log the networkx directed graph as an image with non-overlapping edges. + """ + + adjacency_list = self.create_adjacency_list(directory) + graph = nx.DiGraph(adjacency_list) + # Generate a layout with more spacing + if self.nx_layout is None: + self.nx_layout = nx.shell_layout(graph, nlist=[range(1, self.num_users+1)]) # Increase spacing with `k` + + # Draw nodes with larger size and smaller font + nx.draw_networkx_nodes(graph, self.nx_layout, node_size=700, node_color="skyblue") + + # Draw each edge with curved lines for side-by-side display + edges = list(graph.edges()) + for i, (u, v) in enumerate(edges): + rad = 0.2 if i % 2 == 0 else -0.2 # Increase rad for more curvature + nx.draw_networkx_edges( + graph, + self.nx_layout, + edgelist=[(u, v)], + connectionstyle=f"arc3,rad={rad}", + arrows=True, + # arrowstyle="-|>", # Customize arrow style for better separation + arrowsize=20 # Increase arrow size for visibility + ) + + # Draw labels with smaller font + nx.draw_networkx_labels(graph, self.nx_layout, font_size=8, font_weight="bold") + + # Save the plot as an image + if output_dir: + plt.savefig(f"{directory}/graph_{iteration}.png", format="png") + else: + plt.savefig(f"{self.plot_dir}/graph_{iteration}.png", format="png") + + # Close the plot to free up memory + plt.close() + + + + def log_nx_graph_edge_weights(self, graph: Graph, iteration: int, directory: str | None = None): + """ + Log the directed graph with edge weights as an image with non-overlapping edges. + """ + + if self.nx_layout is None: + self.nx_layout = nx.shell_layout(graph , nlist=[range(1, self.num_users+1)]) + pos = self.nx_layout + edge_weights = nx.get_edge_attributes(graph, "weight") + + for i, (u, v) in enumerate(graph.edges()): + rad = 0.2 if i % 2 == 0 else -0.2 + nx.draw_networkx_edges( + graph, + pos, + edgelist=[(u, v)], + connectionstyle=f"arc3,rad={rad}", + arrows=True, + width=edge_weights.get((u, v), 1.0), # Set width based on weight + arrowsize=20 + ) + + # Draw nodes and labels + nx.draw_networkx_nodes(graph, pos, node_size=700, node_color="skyblue") + nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold") + + if directory: + plt.savefig(f"{directory}/weighted_graph_{iteration}.png", format="png") + else: + plt.savefig(f"{self.plot_dir}/weighted_graph_{iteration}.png", format="png") + + plt.close() + + def combine_graphs_with_edge_frequency(self, directory: str, iteration: int, output_dir: str | None = None): + """ + Combine the adjacency lists of all the rounds and calculate the edge frequency. + """ + # Get all the adjacency lists + adjacency_lists = glob(f"{directory}/*.csv") + # Initialize the edge frequency dictionary + edge_frequency = {} + # Initialize the adjacency list + adjacency_list = {} + # Iterate over all the adjacency lists + for adj_list in adjacency_lists: + # Load the adjacency list + data = pd.read_csv(adj_list) + # Populate the adjacency list + for _, row in data.iterrows(): + node = row["node"] + # Convert string representation of list to actual list + neighbors = eval(row["neighbors"]) + if node not in adjacency_list: + adjacency_list[node] = neighbors + else: + adjacency_list[node].extend(neighbors) + # Calculate the edge frequency + for node, neighbors in adjacency_list.items(): + for neighbor in neighbors: + if (node, neighbor) in edge_frequency: + edge_frequency[(node, neighbor)] += 1 + else: + edge_frequency[(node, neighbor)] = 1 + + # create a graph with edges in the edge frequency with higher frequency edges having thicker lines + G = nx.DiGraph() + for edge, freq in edge_frequency.items(): + # print(edge, freq) + G.add_edge(edge[0], edge[1], weight=freq) + + self.log_nx_graph_edge_weights(G, iteration) + + # create video of the graphs + if output_dir: + self.create_video(output_dir) + self.create_heatmap(edge_frequency, output_dir) + else: + self.create_video(self.plot_dir) + self.create_heatmap(edge_frequency, self.plot_dir) + + def create_video(self, directory: str): + """ + Create a video of the graphs using imageIO where each image is a second long. + """ + images = [] + for filename in sorted(glob(f"{directory}/graph_*.png")): + images.append(imageio.imread(filename)) + # make the gif loop + imageio.mimsave(f"{directory}/graph_video.gif", images, fps =1, loop =0) + + + def create_heatmap(self, edge_frequency, directory: str): + """ + Create a heatmap of the edge frequency. + """ + # Initialize the edge frequency matrix + edge_frequency_matrix = np.zeros((self.num_users+1, self.num_users+1)) + + # Populate the edge frequency matrix + for edge, freq in edge_frequency.items(): + edge_frequency_matrix[edge[0]][edge[1]] = freq + # edge_frequency_matrix[edge[1], edge[0]] = freq # Ensure symmetry for undirected graph + + edge_frequency_matrix = np.log(edge_frequency_matrix + 1) # Log scale for better visualization + # Create the heatmap + plt.imshow(edge_frequency_matrix, cmap="hot", interpolation="nearest") + plt.colorbar(label="Frequency of Communication") + plt.title("Edge Frequency Heatmap") + plt.xlabel("Node") + plt.ylabel("Node") + plt.xticks(range(1, self.num_users+1)) + plt.yticks(range(1, self.num_users+1)) + + # Save the heatmap + plt.savefig(f"{directory}/edge_frequency_heatmap.png") + plt.close() def get_dataset_config_string(self): if isinstance(self.config["dset"], dict):