
Commit

Logs processing and README update (#58)
DT6A authored Jun 14, 2023
1 parent 52aa3cf commit 33eba94
Showing 12 changed files with 3,943 additions and 69 deletions.
226 changes: 157 additions & 69 deletions README.md


6 changes: 6 additions & 0 deletions results/README.md
@@ -0,0 +1,6 @@
# Reproducing figures and tables

To reproduce all figures and tables from the paper, follow these steps:
1. Run `get_{offline, finetune}_urls.py` if needed. These scripts collect the URLs of all wandb runs into .csv files and save them in the `runs_tables` folder. We already provide these tables, but you can recollect them.
2. Run `get_{offline, finetune}_scores.py` if needed. These scripts download the data for the runs listed in the .csv files and save the evaluation scores (and, for offline-to-online runs, the regret) as pickled files in the `bin` folder (see the loading sketch after this list). We provide the pickled data, but you can modify the scripts if you need to extract more data.
3. Run `get_{offline, finetune}_tables_and_plots.py`. These scripts use the pickled data, print all the tables, and save all figures into the `out` directory.
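
The pickled files are ordinary Python objects, so they can also be inspected directly. Below is a minimal sketch, assuming the layout written by `get_finetune_scores.py` (a dict mapping algorithm name to dataset name to a list of per-run results); the offline pickle is assumed to follow the same nesting.

```python
import pickle

# Load the fine-tuning scores produced in step 2.
with open("bin/finetune_scores.pickle", "rb") as handle:
    finetune_scores = pickle.load(handle)

# Top-level keys are algorithm names, second-level keys are dataset names;
# each leaf is a list with one entry per run.
for algo, datasets in finetune_scores.items():
    for dataset, runs in datasets.items():
        print(f"{algo} / {dataset}: {len(runs)} runs")
```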
Binary file added results/bin/finetune_scores.pickle
Binary file added results/bin/offline_scores.pickle
59 changes: 59 additions & 0 deletions results/get_finetune_scores.py
@@ -0,0 +1,59 @@
import os
import pickle

import pandas as pd
from tqdm import tqdm
import wandb

# Table of wandb run URLs collected by get_finetune_urls.py.
dataframe = pd.read_csv("runs_tables/finetune_urls.csv")

api = wandb.Api(timeout=29)


def get_run_scores(run_id, is_dt=False):
    run = api.run(run_id)
    score_key = None
    full_scores = []
    regret = None
    max_dt = -1e10

    # Find the logged key that holds the normalized evaluation score.
    # DT logs one score per target return-to-go, so keep the key with the
    # highest target; other algorithms have a single score key.
    for k in run.history().keys():
        if "normalized" in k and "score" in k and "std" not in k:
            if is_dt:
                st = k
                if "eval/" in st:
                    st = st.replace("eval/", "")
                target = float(st.split("_")[0])
                if target > max_dt:
                    max_dt = target
                    score_key = k
            else:
                score_key = k
                break
    # Download the full history of normalized scores for the selected key.
    for _, row in run.history(keys=[score_key], samples=5000).iterrows():
        full_scores.append(row[score_key])
    # Keep the last logged online regret value, if present.
    for _, row in run.history(keys=["eval/regret"], samples=5000).iterrows():
        if "eval/regret" in row:
            regret = row["eval/regret"]
    # Split the history in half: the first half is treated as offline
    # pretraining evaluations, the second as online fine-tuning evaluations.
    offline_iters = len(full_scores) // 2
    return full_scores[:offline_iters], full_scores[offline_iters:], regret


def process_runs(df):
    algorithms = df["algorithm"].unique()
    datasets = df["dataset"].unique()
    # full_scores[algorithm][dataset] is a list with one entry per run:
    # (offline scores, online scores, regret).
    full_scores = {algo: {ds: [] for ds in datasets} for algo in algorithms}
    for _, row in tqdm(
        df.iterrows(), desc="Runs scores downloading", position=0, leave=True
    ):
        full_scores[row["algorithm"]][row["dataset"]].append(
            get_run_scores(row["url"], row["algorithm"] == "DT")
        )
    return full_scores


full_scores = process_runs(dataframe)

# Cache the collected scores so tables and plots can be built without
# re-downloading the wandb histories.
os.makedirs("bin", exist_ok=True)
with open("bin/finetune_scores.pickle", "wb") as handle:
    pickle.dump(full_scores, handle, protocol=pickle.HIGHEST_PROTOCOL)
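
For reference, each stored entry can be unpacked back into its three components downstream. A small sketch under the same assumptions; the algorithm and dataset names below are hypothetical placeholders, and the tables/plots script in this commit remains the authoritative consumer of this file.

```python
import pickle

import numpy as np

with open("bin/finetune_scores.pickle", "rb") as handle:
    scores = pickle.load(handle)

# Hypothetical keys for illustration; use names from runs_tables/finetune_urls.csv.
runs = scores["CQL"]["halfcheetah-medium-v2"]

# Each run is (offline scores, online scores, regret): average the final
# online evaluation across seeds.
final_online = [online[-1] for _, online, _ in runs]
print(np.mean(final_online), np.std(final_online))
```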
