diff --git a/scripts/evaluate_submission.py b/scripts/evaluate_submission.py index c95e054..6a99824 100644 --- a/scripts/evaluate_submission.py +++ b/scripts/evaluate_submission.py @@ -23,6 +23,7 @@ import yaml import pickle as pkl from pathlib import Path +from visualizeOutputs import construct_stacked_bar_chart _path = Path(os.path.abspath(__file__)) @@ -40,54 +41,6 @@ import matplotlib.pyplot as plt import pandas as pd -def constructStackedBarChart(global_states, - num_discrete_actions = 10, - field = "minimum_mitigation_rate_all_regions"): - - """ - Constructs a Stacked Bar Chart for each timestep of a given run - Note: this is only useful where there is a single value per country, - ie self-directed country actions, such as mitigation. - - Args: - global_state(dict): the global state dictionary of a completed run - field(str): the name of the field to extract, defualt is minimum_mitigation_rates - wandb(wandb-object): an already initialized and open wandb object - indicating where the plot should be sent - """ - - - rates_over_time = global_states[field] - possible_rates = range(0,num_discrete_actions-1) - to_plot = {rate:[] for rate in possible_rates} - - #per timestep get rate counts - for timestep in range(rates_over_time.shape[0]): - current_rates = rates_over_time[timestep,:] - current_counter = Counter(current_rates) - for rate in possible_rates: - #count countries with a particular rate - if rate in current_counter.keys(): - to_plot[rate].append(current_counter[rate]) - #if no countries have that particular rate - else: - to_plot[rate].append(0) - - pdf = pd.DataFrame(to_plot) - pdf.plot(kind='bar', stacked=True).legend(loc='center left',bbox_to_anchor=(1.0, 0.5)) - plt.xlabel(f"Countries of a Given {field}") - plt.ylabel("Timesteps") - plt.title(f"{field} Distribution") - plt.show() - # wandb.log({f"{field}":plt}) - - - - - - - - # Set logger level e.g., DEBUG, INFO, WARNING, ERROR. logging.getLogger().setLevel(logging.ERROR) @@ -335,14 +288,16 @@ def compute_metrics(fetch_episode_states, trainer, framework, submission_file, e with open("episode_states.pkl", "wb") as f: pkl.dump(episode_states[0], f, protocol=pkl.HIGHEST_PROTOCOL) + + if log_config and log_config["enabled"]: - #log mitigation rate counts of each country over time - constructStackedBarChart(episode_states[0],wandb, - field="mitigation_rate_all_regions") + #log mitigation rate counts of each country over time + wandb.log({"mitigation_rate Counts Across Time":construct_stacked_bar_chart(episode_states[0], + field="mitigation_rate_all_regions")}) - #log minimum mitigation rate counts of each country over time - constructStackedBarChart(episode_states[0], wandb, - field="minimum_mitigation_rate_all_regions") + #log mitigation rate counts of each country over time + wandb.log({"minimum_mitigation_rate Counts Across Time":construct_stacked_bar_chart(episode_states[0], + field="minimum_mitigation_rate_all_regions")}) for feature in desired_outputs: feature_values = [None for _ in range(num_episodes)]