From 7a765ed2789f91fc0b9dd8dd0020cb1c9c296f4a Mon Sep 17 00:00:00 2001 From: pjwozny Date: Fri, 3 Mar 2023 10:59:29 +0100 Subject: [PATCH 1/3] removed construct bar chart from eval submissions --- scripts/evaluate_submission.py | 48 ---------------------------------- 1 file changed, 48 deletions(-) diff --git a/scripts/evaluate_submission.py b/scripts/evaluate_submission.py index c95e054..1c3aa1c 100644 --- a/scripts/evaluate_submission.py +++ b/scripts/evaluate_submission.py @@ -40,54 +40,6 @@ import matplotlib.pyplot as plt import pandas as pd -def constructStackedBarChart(global_states, - num_discrete_actions = 10, - field = "minimum_mitigation_rate_all_regions"): - - """ - Constructs a Stacked Bar Chart for each timestep of a given run - Note: this is only useful where there is a single value per country, - ie self-directed country actions, such as mitigation. - - Args: - global_state(dict): the global state dictionary of a completed run - field(str): the name of the field to extract, defualt is minimum_mitigation_rates - wandb(wandb-object): an already initialized and open wandb object - indicating where the plot should be sent - """ - - - rates_over_time = global_states[field] - possible_rates = range(0,num_discrete_actions-1) - to_plot = {rate:[] for rate in possible_rates} - - #per timestep get rate counts - for timestep in range(rates_over_time.shape[0]): - current_rates = rates_over_time[timestep,:] - current_counter = Counter(current_rates) - for rate in possible_rates: - #count countries with a particular rate - if rate in current_counter.keys(): - to_plot[rate].append(current_counter[rate]) - #if no countries have that particular rate - else: - to_plot[rate].append(0) - - pdf = pd.DataFrame(to_plot) - pdf.plot(kind='bar', stacked=True).legend(loc='center left',bbox_to_anchor=(1.0, 0.5)) - plt.xlabel(f"Countries of a Given {field}") - plt.ylabel("Timesteps") - plt.title(f"{field} Distribution") - plt.show() - # wandb.log({f"{field}":plt}) - - - - - - - - # Set logger level e.g., DEBUG, INFO, WARNING, ERROR. logging.getLogger().setLevel(logging.ERROR) From 2ef9441dd4593cad8081ba56383187766a694a65 Mon Sep 17 00:00:00 2001 From: pjwozny Date: Fri, 3 Mar 2023 11:03:46 +0100 Subject: [PATCH 2/3] adhered to style guidelines and changed return object of viz function --- scripts/evaluate_submission.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/scripts/evaluate_submission.py b/scripts/evaluate_submission.py index 1c3aa1c..3e52bcc 100644 --- a/scripts/evaluate_submission.py +++ b/scripts/evaluate_submission.py @@ -23,6 +23,7 @@ import yaml import pickle as pkl from pathlib import Path +from visualizeOutputs import construct_stacked_bar_chart _path = Path(os.path.abspath(__file__)) @@ -289,12 +290,12 @@ def compute_metrics(fetch_episode_states, trainer, framework, submission_file, e pkl.dump(episode_states[0], f, protocol=pkl.HIGHEST_PROTOCOL) #log mitigation rate counts of each country over time - constructStackedBarChart(episode_states[0],wandb, - field="mitigation_rate_all_regions") + wandb.log({"mitigation_rate Counts Across Time":construct_stacked_bar_chart(episode_states[0], + field="mitigation_rate_all_regions")}) - #log minimum mitigation rate counts of each country over time - constructStackedBarChart(episode_states[0], wandb, - field="minimum_mitigation_rate_all_regions") + #log mitigation rate counts of each country over time + wandb.log({"minimum_mitigation_rate Counts Across Time":construct_stacked_bar_chart(episode_states[0], + field="minimum_mitigation_rate_all_regions")}) for feature in desired_outputs: feature_values = [None for _ in range(num_episodes)] From abc93747b8c30f27a0eb7e96106d016bdbb1ffce Mon Sep 17 00:00:00 2001 From: pjwozny Date: Fri, 3 Mar 2023 11:05:07 +0100 Subject: [PATCH 3/3] added log check --- scripts/evaluate_submission.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/scripts/evaluate_submission.py b/scripts/evaluate_submission.py index 3e52bcc..6a99824 100644 --- a/scripts/evaluate_submission.py +++ b/scripts/evaluate_submission.py @@ -288,14 +288,16 @@ def compute_metrics(fetch_episode_states, trainer, framework, submission_file, e with open("episode_states.pkl", "wb") as f: pkl.dump(episode_states[0], f, protocol=pkl.HIGHEST_PROTOCOL) + + if log_config and log_config["enabled"]: - #log mitigation rate counts of each country over time - wandb.log({"mitigation_rate Counts Across Time":construct_stacked_bar_chart(episode_states[0], - field="mitigation_rate_all_regions")}) + #log mitigation rate counts of each country over time + wandb.log({"mitigation_rate Counts Across Time":construct_stacked_bar_chart(episode_states[0], + field="mitigation_rate_all_regions")}) - #log mitigation rate counts of each country over time - wandb.log({"minimum_mitigation_rate Counts Across Time":construct_stacked_bar_chart(episode_states[0], - field="minimum_mitigation_rate_all_regions")}) + #log mitigation rate counts of each country over time + wandb.log({"minimum_mitigation_rate Counts Across Time":construct_stacked_bar_chart(episode_states[0], + field="minimum_mitigation_rate_all_regions")}) for feature in desired_outputs: feature_values = [None for _ in range(num_episodes)]