From 043c408dbaa731d3ccdceb11f238ae18bace1386 Mon Sep 17 00:00:00 2001
From: JaiDhyani <jaiwithani@gmail.com>
Date: Fri, 23 Feb 2024 03:12:58 -0800
Subject: [PATCH] Add calc_model_group_stats function to calculate useful stats
 for visualization

---
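A minimal usage sketch for the new function (placed after the "---" so it is
not part of the commit). The toy token ids, labels, and logprob values below
are illustrative assumptions chosen only to mirror the input shapes described
in the docstring; they do not come from the real datasets.

    from datasets import Dataset
    from delphi.eval.calc_model_group_stats import calc_model_group_stats

    # toy corpus: one document with three tokens (ids are made up)
    tokenized_corpus = Dataset.from_list([{"tokens": [5, 7, 7]}])
    # per-model, per-document, per-token logprobs (position 0 is skipped)
    logprob_datasets = {"llama2": [[0.0, -0.4, -1.2]]}
    # token id -> {label: bool} membership map, plus the list of labels
    token_groups = {
        5: {"Is Noun": True, "Is Verb": False},
        7: {"Is Noun": False, "Is Verb": True},
    }
    token_labels = ["Is Noun", "Is Verb"]

    stats = calc_model_group_stats(
        tokenized_corpus, logprob_datasets, token_groups, ["llama2"], token_labels
    )
    # token 7 appears at both predicted positions, so:
    # stats[("llama2", "Is Verb")]["mean"] == -0.8
    # no noun occurs at a predicted position, so ("llama2", "Is Noun") is absent
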
 src/delphi/eval/calc_model_group_stats.py | 56 +++++++++++++++++++++++
 1 file changed, 56 insertions(+)
 create mode 100644 src/delphi/eval/calc_model_group_stats.py

diff --git a/src/delphi/eval/calc_model_group_stats.py b/src/delphi/eval/calc_model_group_stats.py
new file mode 100644
index 00000000..bdec6bf9
--- /dev/null
+++ b/src/delphi/eval/calc_model_group_stats.py
@@ -0,0 +1,56 @@
+from datasets import Dataset
+import numpy as np
+
+
+def calc_model_group_stats(
+    tokenized_corpus_dataset: Dataset,
+    logprob_datasets: dict[str, list[list[float]]],
+    token_groups: dict[int, dict[str, bool]],
+    models: list[str],
+    token_labels: list[str],
+) -> dict[tuple[str, str], dict[str, float]]:
+    """
+    For each (model, token group) pair, calculate useful stats (for visualization)
+
+    args:
+    - tokenized_corpus_dataset: the tokenized corpus dataset, e.g. load_dataset(constants.tokenized_corpus_dataset)["validation"]
+    - logprob_datasets: a dict of lists of logprobs, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]}
+    - token_groups: a dict of token groups, e.g. {0: {"Is Noun": True, "Is Verb": False, ...}, 1: {...}, ...}
+    - models: a list of model names, e.g. ["llama2", "gpt2", ...]
+    - token_labels: a list of token group descriptions, e.g. ["Is Noun", "Is Verb", ...]
+
+    returns: a dict mapping (model, token group) pairs to a dict of stats,
+        e.g. {("llama2", "Is Noun"): {"mean": -0.5, "median": -0.4, "min": -0.9, "max": -0.1, "25th": -0.7, "75th": -0.3}, ...}
+
+    Technically `models` and `token_labels` are redundant: `models` duplicates the keys of `logprob_datasets`,
+    and `token_labels` duplicates the keys of the per-token dicts in `token_groups`. They are kept explicit for clarity.
+
+    stats calculated: mean, median, min, max, 25th percentile, 75th percentile
+    """
+    model_group_stats = {}
+    for model in models:
+        group_logprobs = {}
+        print(f"Processing model {model}")
+        dataset = logprob_datasets[model]
+        for ix_doc_lp, document_lps in enumerate(dataset):
+            tokens = tokenized_corpus_dataset[ix_doc_lp]["tokens"]
+            for ix_token, token in enumerate(tokens):
+                if ix_token == 0:  # skip the first token, which isn't predicted
+                    continue
+                logprob = document_lps[ix_token]
+                for token_group_desc in token_labels:
+                    if token_groups[token][token_group_desc]:
+                        if token_group_desc not in group_logprobs:
+                            group_logprobs[token_group_desc] = []
+                        group_logprobs[token_group_desc].append(logprob)
+        for token_group_desc in token_labels:
+            if token_group_desc in group_logprobs:
+                model_group_stats[(model, token_group_desc)] = {
+                    "mean": np.mean(group_logprobs[token_group_desc]),
+                    "median": np.median(group_logprobs[token_group_desc]),
+                    "min": np.min(group_logprobs[token_group_desc]),
+                    "max": np.max(group_logprobs[token_group_desc]),
+                    "25th": np.percentile(group_logprobs[token_group_desc], 25),
+                    "75th": np.percentile(group_logprobs[token_group_desc], 75),
+                }
+    return model_group_stats