Skip to content

Commit

Permalink
Merge pull request #11 from pjwozny/vizclean
Browse files Browse the repository at this point in the history
Vizclean
  • Loading branch information
brenting authored Mar 5, 2023
2 parents a8e2878 + abc9374 commit cd1db49
Showing 1 changed file with 9 additions and 54 deletions.
63 changes: 9 additions & 54 deletions scripts/evaluate_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import yaml
import pickle as pkl
from pathlib import Path
from visualizeOutputs import construct_stacked_bar_chart

_path = Path(os.path.abspath(__file__))

Expand All @@ -40,54 +41,6 @@
import matplotlib.pyplot as plt
import pandas as pd

def constructStackedBarChart(global_states,
num_discrete_actions = 10,
field = "minimum_mitigation_rate_all_regions"):

"""
Constructs a Stacked Bar Chart for each timestep of a given run
Note: this is only useful where there is a single value per country,
ie self-directed country actions, such as mitigation.
Args:
global_state(dict): the global state dictionary of a completed run
field(str): the name of the field to extract, defualt is minimum_mitigation_rates
wandb(wandb-object): an already initialized and open wandb object
indicating where the plot should be sent
"""


rates_over_time = global_states[field]
possible_rates = range(0,num_discrete_actions-1)
to_plot = {rate:[] for rate in possible_rates}

#per timestep get rate counts
for timestep in range(rates_over_time.shape[0]):
current_rates = rates_over_time[timestep,:]
current_counter = Counter(current_rates)
for rate in possible_rates:
#count countries with a particular rate
if rate in current_counter.keys():
to_plot[rate].append(current_counter[rate])
#if no countries have that particular rate
else:
to_plot[rate].append(0)

pdf = pd.DataFrame(to_plot)
pdf.plot(kind='bar', stacked=True).legend(loc='center left',bbox_to_anchor=(1.0, 0.5))
plt.xlabel(f"Countries of a Given {field}")
plt.ylabel("Timesteps")
plt.title(f"{field} Distribution")
plt.show()
# wandb.log({f"{field}":plt})








# Set logger level e.g., DEBUG, INFO, WARNING, ERROR.
logging.getLogger().setLevel(logging.ERROR)

Expand Down Expand Up @@ -335,14 +288,16 @@ def compute_metrics(fetch_episode_states, trainer, framework, submission_file, e

with open("episode_states.pkl", "wb") as f:
pkl.dump(episode_states[0], f, protocol=pkl.HIGHEST_PROTOCOL)

if log_config and log_config["enabled"]:

#log mitigation rate counts of each country over time
constructStackedBarChart(episode_states[0],wandb,
field="mitigation_rate_all_regions")
#log mitigation rate counts of each country over time
wandb.log({"mitigation_rate Counts Across Time":construct_stacked_bar_chart(episode_states[0],
field="mitigation_rate_all_regions")})

#log minimum mitigation rate counts of each country over time
constructStackedBarChart(episode_states[0], wandb,
field="minimum_mitigation_rate_all_regions")
#log mitigation rate counts of each country over time
wandb.log({"minimum_mitigation_rate Counts Across Time":construct_stacked_bar_chart(episode_states[0],
field="minimum_mitigation_rate_all_regions")})

for feature in desired_outputs:
feature_values = [None for _ in range(num_episodes)]
Expand Down

0 comments on commit cd1db49

Please sign in to comment.