diff --git a/reference/README.md b/reference/README.md
new file mode 100644
index 0000000..494b350
--- /dev/null
+++ b/reference/README.md
@@ -0,0 +1,7 @@
+# Reference Implementations
+
+This directory contains reference implementations of the experiments run in the FanOutQA paper. These implementations
+are provided "as-is", with no guarantee of usability or support. Note that the code quality of these implementations
+may be lower than that of the library code, as they are earlier explorations written a) on a deadline and b) to
+develop the interfaces that were eventually exposed in this library.
+
diff --git a/reference/paper_benchmarks/.gitignore b/reference/paper_benchmarks/.gitignore
new file mode 100644
index 0000000..e9a9a46
--- /dev/null
+++ b/reference/paper_benchmarks/.gitignore
@@ -0,0 +1,161 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintainted in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
+
+**.DS_Store
+results/extra-*
+results/prompts-*
+results-old/
+.wikicache/
+tmp*/
+BLEURT-20/
+.llmcache/
diff --git a/reference/paper_benchmarks/README.md b/reference/paper_benchmarks/README.md
new file mode 100644
index 0000000..95eeb8c
--- /dev/null
+++ b/reference/paper_benchmarks/README.md
@@ -0,0 +1,28 @@
+# benchmarks
+
+make sure to run `python -m spacy download en_core_web_sm` first
+
+closedbook: just the question and a zero-shot prompt asking for a short text answer
+
+openbook: like closedbook, but the model can additionally search Wikipedia for pages to read
+
+wiki provided: chunk the known evidence 1024 characters at a time (preferring splits at paragraph and sentence
+boundaries), then use BM25 to retrieve as many passages as fit into the context window
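+
+A minimal sketch of that retrieval step, built on the helpers in `fanout/retrieval.py` (the evidence entry and
+question below are only illustrative; real runs use each question's `necessary_evidence` field):
+
+```python
+from fanout.retrieval import Corpus
+
+# hypothetical evidence entry; fetching the page content requires the Wikipedia cache or network access
+evidence = [{"title": "Python (programming language)", "pageid": 23862}]
+
+# chunk the page into <=1024-character fragments and index them with BM25+
+corpus = Corpus(evidence, doc_len=1024)
+
+# fragments are yielded best-first; the runners keep appending them until the prompt
+# would no longer fit in the model's context window
+for fragment in corpus.best("When was Python first released?"):
+    print(fragment.title, fragment.content[:80])
+    break
+```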
+
+run:
+
+```shell
+WIKI_CACHE_DIR=/nlp/data/andrz/fanoutqa-bench/.wikicache python run_openbook.py -m mistral-chat -n 3
+for i in slurm/*; do sbatch $i; done
+```
+
+# eval
+
+install:
+
+```shell
+wget https://storage.googleapis.com/bleurt-oss-21/BLEURT-20.zip
+unzip BLEURT-20.zip
+rm BLEURT-20.zip
+python -m spacy download en_core_web_sm
+```
\ No newline at end of file
diff --git a/reference/paper_benchmarks/bench/__init__.py b/reference/paper_benchmarks/bench/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/reference/paper_benchmarks/bench/engines.py b/reference/paper_benchmarks/bench/engines.py
new file mode 100644
index 0000000..64ae8af
--- /dev/null
+++ b/reference/paper_benchmarks/bench/engines.py
@@ -0,0 +1,111 @@
+import os
+
+import aiolimiter
+import torch
+from kani import AIFunction, ChatMessage
+from kani.engines.anthropic import AnthropicEngine
+from kani.engines.base import BaseEngine
+from kani.engines.huggingface.llama2 import LlamaEngine
+from kani.engines.openai import OpenAIEngine
+from kani.engines.openai.models import ChatCompletion
+
+# openai
+ft_org_id = "org-dxkPAvuWroRmYw7z0VhqMC5S"
+lab_org_id = "org-bgAXfs8WdU5942SLngg0OGpd"
+
+# hf
+common_hf_model_args = dict(
+ device_map="auto",
+ torch_dtype=torch.float16,
+)
+
+if mcd := os.getenv("MODEL_CACHE_DIR"):
+ common_hf_model_args["cache_dir"] = mcd
+
+
+def can_parallel(engine):
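+    # only OpenAI-backed runs are batched and sent in parallel; other engines generate one completion at a time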
+ return isinstance(engine, OpenAIEngine)
+
+
+def get_engine(name: str, ctx_size=None) -> BaseEngine:
+ if name == "gpt-4":
+ return RatelimitedOpenAIEngine(
+ model="gpt-4-0613",
+ organization=lab_org_id,
+ max_context_size=ctx_size or 8192,
+ tpm_limiter=aiolimiter.AsyncLimiter(300000),
+ temperature=0,
+ )
+ elif name == "gpt-4-turbo":
+ return RatelimitedOpenAIEngine(
+ model="gpt-4-0125-preview",
+ organization=ft_org_id,
+ max_context_size=ctx_size or 128000,
+ tpm_limiter=aiolimiter.AsyncLimiter(400000),
+ temperature=0,
+ )
+ elif name == "gpt-3.5-turbo":
+ return RatelimitedOpenAIEngine(
+ model="gpt-3.5-turbo-1106",
+ organization=ft_org_id,
+ max_context_size=ctx_size or 16000,
+ tpm_limiter=aiolimiter.AsyncLimiter(2000000),
+ temperature=0,
+ )
+ elif name == "llama-chat":
+ return LlamaEngine(
+ "meta-llama/Llama-2-70b-chat-hf",
+ max_context_size=ctx_size or 4096,
+ model_load_kwargs=common_hf_model_args,
+ use_auth_token=True,
+ do_sample=False,
+ )
+ elif name == "mistral-chat":
+ return LlamaEngine(
+ "mistralai/Mistral-7B-Instruct-v0.2",
+ max_context_size=ctx_size or 32768,
+ model_load_kwargs=common_hf_model_args,
+ do_sample=False,
+ )
+ elif name == "mixtral":
+ return LlamaEngine(
+ "mistralai/Mixtral-8x7B-Instruct-v0.1",
+ max_context_size=ctx_size or 32768,
+ model_load_kwargs=common_hf_model_args,
+ do_sample=False,
+ )
+ elif name == "longllama":
+ return LlamaEngine(
+ "syzymon/long_llama_3b_instruct",
+ max_context_size=ctx_size or 512000,
+ tokenizer_kwargs={"trust_remote_code": True},
+ model_load_kwargs={"trust_remote_code": True, **common_hf_model_args},
+ do_sample=False,
+ )
+ elif name == "claude":
+ return AnthropicEngine(
+ model="claude-2.1",
+ max_context_size=ctx_size or 200000,
+ temperature=0,
+ )
+ else:
+ raise ValueError("Invalid model name")
+
+
+class RatelimitedOpenAIEngine(OpenAIEngine):
+ def __init__(
+ self, *args, rpm_limiter: aiolimiter.AsyncLimiter = None, tpm_limiter: aiolimiter.AsyncLimiter = None, **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.rpm_limiter = rpm_limiter
+ self.tpm_limiter = tpm_limiter
+
+ async def predict(
+ self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
+ ) -> ChatCompletion:
+ if self.rpm_limiter:
+ await self.rpm_limiter.acquire()
+ if self.tpm_limiter:
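+            # estimate this request's prompt tokens so the limiter can deduct them from the TPM budget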
+ n_toks = self.function_token_reserve(functions) + sum(self.message_len(m) for m in messages)
+ await self.tpm_limiter.acquire(n_toks)
+ return await super().predict(messages, functions, **hyperparams)
diff --git a/reference/paper_benchmarks/bench/runner.py b/reference/paper_benchmarks/bench/runner.py
new file mode 100644
index 0000000..294dcc4
--- /dev/null
+++ b/reference/paper_benchmarks/bench/runner.py
@@ -0,0 +1,198 @@
+import argparse
+import asyncio
+import json
+import logging
+from collections import namedtuple
+
+import tqdm
+from kani import Kani
+
+from bench.engines import can_parallel, get_engine
+from bench.utils import REPO_ROOT, load_questions
+from fanout.utils import batched
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "-m",
+ "--model",
+ choices=[
+ "gpt-4",
+ "gpt-4-turbo",
+ "gpt-3.5-turbo",
+ "llama-chat",
+ "mistral-chat",
+ "mixtral",
+ "longllama",
+ "claude",
+ ],
+ required=True,
+)
+parser.add_argument("-n", "--count", type=int, default=None)
+parser.add_argument("--ignore-config", type=bool, default=False)
+parser.add_argument("--no-cache-prompt", type=bool, default=False)
+parser.add_argument("--ctx-size", type=int, default=None)
+
+
+def argstokwargs(args):
+ return dict(
+ model_name=args.model,
+ n_questions=args.count,
+ ignore_config=args.ignore_config,
+ no_cache_prompt=args.no_cache_prompt,
+ ctx_size=args.ctx_size,
+ )
+
+
+log = logging.getLogger(__name__)
+
+TOKEN_RESERVE = 520 # 512 for gen + 8 for formatting etc
+BATCH_SIZE = 20
+
+GenerationResult = namedtuple("GenerationResult", "success content extra")
+
+
+class Runner:
+ def __init__(
+ self,
+ bench_setting: str,
+ model_name: str = None,
+ n_questions: int = None,
+ ignore_config=False,
+ no_cache_prompt=False,
+ ctx_size: int = None,
+ ):
+ if ctx_size is not None:
+ model_key = model_name
+ model_name = f"{model_name}-ctx-{ctx_size}"
+ else:
+ model_key = model_name
+ self.bench_setting = bench_setting
+ self.model_name = model_name
+ self.n_questions = n_questions
+ self.ignore_config = ignore_config
+ self.no_cache_prompt = no_cache_prompt
+ self.prompts = []
+ self.engine = get_engine(model_key, ctx_size=ctx_size)
+ self.questions = []
+ self.write_lock = asyncio.Lock()
+ self.results_file = open(REPO_ROOT / f"results/results-{self.bench_setting}-{self.model_name}.jsonl", "a")
+ self.extras_file = open(REPO_ROOT / f"results/extra-{self.bench_setting}-{self.model_name}.jsonl", "a")
+
+ def get_questions(self):
+ qs = load_questions()
+
+ # filter to only ids in questions
+ if not self.ignore_config:
+ cfg_fp = REPO_ROOT / f"config/{self.bench_setting}-{self.model_name}.txt"
+ if cfg_fp.exists():
+ with open(cfg_fp) as f:
+ ids = set(l.strip() for l in f if l.strip())
+ qs = [q for q in qs if q["id"] in ids]
+
+ if self.n_questions:
+ return qs[: self.n_questions]
+ return qs
+
+ def load(self, *args, **kwargs):
+ self.questions = self.get_questions()
+ self.prompts = self.get_prompts(self.questions)
+
+ def get_prompt(self, question, *args, **kwargs) -> str:
+ raise NotImplementedError
+
+ def get_prompts(self, questions, *args, **kwargs) -> list[str]:
+ # load existing prompts
+ prompt_fp = REPO_ROOT / f"results/prompts-{self.bench_setting}-{self.model_name}.json"
+ existing_prompts = {} # id to prompt
+ if prompt_fp.exists() and not self.no_cache_prompt:
+ with open(prompt_fp) as f:
+ prompts = json.load(f)
+ existing_prompts = {p["id"]: p["prompt"] for p in prompts}
+
+ # gen prompts that are missing
+ out = []
+ for q in tqdm.tqdm(questions):
+ if q["id"] in existing_prompts:
+ prompt = existing_prompts[q["id"]]
+ else:
+ prompt = self.get_prompt(q, *args, **kwargs)
+ existing_prompts[q["id"]] = prompt
+ out.append(prompt)
+
+ # save prompts, merge with existing
+ with open(prompt_fp, "w") as f:
+ data = [{"id": p_id, "prompt": p} for p_id, p in existing_prompts.items()]
+ json.dump(data, f)
+
+ return out
+
+ # def save(self, results: list[GenerationResult]):
+ # output = []
+ # for idx, gen in enumerate(results):
+ # q = self.questions[idx]
+ # data = {"id": q["id"], "answer": gen.content, "question": q, "prompt": self.prompts[idx]}
+ # if gen.extra:
+ # data["extra"] = gen.extra
+ # output.append(data)
+ #
+ # with open(REPO_ROOT / f"results/{self.bench_setting}-{self.model_name}.json", "w") as f:
+ # json.dump(output, f, indent=2)
+
+ def save_one(self, result: GenerationResult, idx: int):
+ q = self.questions[idx]
+ data = {"id": q["id"], "answer": result.content, "question": q}
+ self.results_file.write(json.dumps(data))
+ self.results_file.write("\n")
+ # extras
+ data = {
+ "id": q["id"],
+ "answer": result.content,
+ "question": q,
+ "prompt": self.prompts[idx],
+ "extra": result.extra,
+ }
+ self.extras_file.write(json.dumps(data))
+ self.extras_file.write("\n")
+
+ # runners
+ async def run_one(self, prompt, question, kani_cls=Kani, **kwargs) -> GenerationResult:
+ ai = kani_cls(self.engine, **kwargs)
+ try:
+ resp = await ai.chat_round_str(prompt)
+ return GenerationResult(success=True, content=resp, extra=None)
+ except Exception:
+ log.exception("Error getting response")
+ return GenerationResult(success=False, content=None, extra=None)
+
+ async def _run_one(self, *args, idx: int, **kwargs):
+ result = await self.run_one(*args, **kwargs)
+ async with self.write_lock:
+ self.save_one(result, idx)
+ return result
+
+ async def _run_series(self, *args, **kwargs):
+ results = []
+ for idx, prompt in tqdm.tqdm(enumerate(self.prompts), total=len(self.prompts)):
+ question = self.questions[idx]
+ result = await self._run_one(prompt, question, *args, idx=idx, **kwargs)
+ results.append(result)
+ return results
+
+ async def _run_parallel(self, *args, **kwargs):
+ results = []
+ for batch in tqdm.tqdm(batched(enumerate(self.prompts), BATCH_SIZE), total=len(self.prompts) // BATCH_SIZE + 1):
+ r = await asyncio.gather(
+ *(self._run_one(prompt, self.questions[idx], *args, idx=idx, **kwargs) for idx, prompt in batch)
+ )
+ results.extend(r)
+ return results
+
+ async def run(self, kani_cls=Kani, **kwargs) -> list[GenerationResult]:
+ if can_parallel(self.engine):
+ return await self._run_parallel(kani_cls=kani_cls, **kwargs)
+ return await self._run_series(kani_cls=kani_cls, **kwargs)
+
+ async def close(self):
+ await self.engine.close()
+ self.results_file.close()
+ self.extras_file.close()
diff --git a/reference/paper_benchmarks/bench/score.py b/reference/paper_benchmarks/bench/score.py
new file mode 100644
index 0000000..6b72e31
--- /dev/null
+++ b/reference/paper_benchmarks/bench/score.py
@@ -0,0 +1,50 @@
+import itertools
+import re
+from collections import namedtuple
+
+from fanout.norm import normalize
+
+AccuracyResult = namedtuple("AccuracyResult", "found score missing")
+
+
+def answer_in_text(reference, candidate: str) -> AccuracyResult:
+ """Is the answer present in the text?"""
+ if isinstance(reference, list):
+ missing = []
+ for a in reference:
+ result = answer_in_text(a, candidate)
+ missing.extend(result.missing)
+ n_found = len(reference) - len(missing)
+ return AccuracyResult(found=n_found == len(reference), score=n_found / len(reference), missing=missing)
+ elif isinstance(reference, dict):
+ missing = []
+ vals = itertools.chain(reference.keys(), reference.values())
+ for a in vals:
+ result = answer_in_text(a, candidate)
+ missing.extend(result.missing)
+ n_ref = len(reference) * 2
+ n_found = n_ref - len(missing) # kvs
+ return AccuracyResult(found=n_found == n_ref, score=n_found / n_ref, missing=missing)
+ else:
+ if isinstance(reference, bool):
+ reference = "yes" if reference else "no"
+ # primitive
+ norm_ans = normalize(reference)
+ norm_cand = normalize(candidate)
+ # ensure the answer is surrounded by word boundaries
+ if not re.search(rf"\b{re.escape(norm_ans)}\b", norm_cand):
+ return AccuracyResult(found=False, score=0, missing=[norm_ans])
+ return AccuracyResult(found=True, score=1, missing=[])
+
+
+def str_answer(ans) -> str:
+ """Ensure the answer is a string for string-based metrics like ROUGE. Don't normalize it otherwise."""
+ if isinstance(ans, list):
+ return "\n".join(map(str_answer, ans))
+ elif isinstance(ans, dict):
+ return "\n".join(f"{k} - {str_answer(v)}" for k, v in ans.items())
+ elif isinstance(ans, bool):
+ return "yes" if ans else "no"
+ elif ans is None:
+ return ""
+ return str(ans)
diff --git a/reference/paper_benchmarks/bench/utils.py b/reference/paper_benchmarks/bench/utils.py
new file mode 100644
index 0000000..5d10f32
--- /dev/null
+++ b/reference/paper_benchmarks/bench/utils.py
@@ -0,0 +1,20 @@
+import json
+from pathlib import Path
+
+REPO_ROOT = Path(__file__).parents[1]
+
+
+def load_questions(fp: Path = None):
+ if fp is None:
+ fp = REPO_ROOT / "fanout-final-test.json"
+ with open(fp) as f:
+ return json.load(f)
+
+
+def load_results(fp: Path):
+ """Load results from a jsonl file"""
+ results = []
+ with open(fp) as f:
+ for line in f:
+ results.append(json.loads(line))
+ return results
diff --git a/reference/paper_benchmarks/fanout/__init__.py b/reference/paper_benchmarks/fanout/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/reference/paper_benchmarks/fanout/config.py b/reference/paper_benchmarks/fanout/config.py
new file mode 100644
index 0000000..342195f
--- /dev/null
+++ b/reference/paper_benchmarks/fanout/config.py
@@ -0,0 +1,8 @@
+import os
+from pathlib import Path
+
+if cache_dir := os.getenv("WIKI_CACHE_DIR"):
+ CACHE_DIR = Path(cache_dir)
+else:
+ CACHE_DIR = Path(__file__).parent / "../../ccbqa-processing/.wikicache"
+CACHE_DIR.mkdir(parents=True, exist_ok=True)
diff --git a/reference/paper_benchmarks/fanout/norm.py b/reference/paper_benchmarks/fanout/norm.py
new file mode 100644
index 0000000..98798a3
--- /dev/null
+++ b/reference/paper_benchmarks/fanout/norm.py
@@ -0,0 +1,79 @@
+import logging
+import re
+from decimal import Decimal
+
+import ftfy
+import spacy
+
+log = logging.getLogger(__name__)
+nlp = spacy.load("en_core_web_sm")
+
+
+def normalize(text, remove_stopwords=False):
+ """
+ - ftfy
+ - normalize numbers
+ - lemmatize
+ - remove stopwords (optional)
+ - remove punctuation
+ - remove redundant whitespace
+ """
+ text = str(text).lower()
+ text = ftfy.fix_text(text)
+ text = normalize_numbers(text)
+ text = lemmatize(text, remove_stopwords=remove_stopwords)
+ text = remove_punct(text)
+ text = normalize_whitespace(text)
+ return text
+
+
+def normalize_numbers(text: str, do_text_sub=False):
+ """Use regex to normalize numbers like 5.2 billion, and numbers with commas"""
+ # numbers with commas
+ comma_sub_text = re.sub(r"(\d+,)+\d+(\.\d+)?", lambda m: m[0].replace(",", ""), text)
+
+ if not do_text_sub:
+ return comma_sub_text
+
+ # numbers with text
+ def number_text_sub(match: re.Match):
+ n = Decimal(match[1]) # for precision
+ muls = match[2].strip()
+ for mul in muls.split():
+ match mul.lower():
+ case "thousand":
+ n *= 1_000
+ case "million":
+ n *= 1_000_000
+ case "billion":
+ n *= 1_000_000_000
+ case "trillion":
+ n *= 1_000_000_000_000
+ return str(n.normalize())
+
+ textual_number_sub_text = re.sub(
+ r"(\d+(?:\.\d+)?)((?:\s*(?:thousand|million|billion|trillion))+)",
+ number_text_sub,
+ comma_sub_text,
+ flags=re.IGNORECASE,
+ )
+
+ return textual_number_sub_text
+
+
+def lemmatize(text: str, remove_stopwords=False):
+ """Return a normalized string with each word replaced by its lemmatized version."""
+ doc = nlp(text)
+ if remove_stopwords:
+ return " ".join(tok.lemma_ for tok in doc if not tok.is_stop)
+ return " ".join(tok.lemma_ for tok in doc)
+
+
+def remove_punct(text: str):
+ """Remove all punctuation from the string."""
+ return re.sub(r"[,.?!:;]", "", text)
+
+
+def normalize_whitespace(text: str):
+ """Replace all whitespace with a single space."""
+ return re.sub(r"\s+", " ", text)
diff --git a/reference/paper_benchmarks/fanout/retrieval.py b/reference/paper_benchmarks/fanout/retrieval.py
new file mode 100644
index 0000000..c7462df
--- /dev/null
+++ b/reference/paper_benchmarks/fanout/retrieval.py
@@ -0,0 +1,75 @@
+from collections import namedtuple
+from typing import Iterable
+
+import numpy as np
+from rank_bm25 import BM25Plus
+
+from fanout.norm import normalize
+from fanout.wiki import get_page_markdown
+
+RetrievalResult = namedtuple("RetrievalResult", "title content")
+
+
+class Corpus:
+ """A corpus of wiki docs. Indexes the docs on creation, normalizing the text beforehand with lemmatization."""
+
+ def __init__(self, documents: list[dict], doc_len: int):
+ """
+ :param documents: The list of evidences to index
+ :param doc_len: The maximum length, in characters, of each chunk
+ """
+ self.documents = []
+ normalized_corpus = []
+ for doc in documents:
+ title = doc["title"]
+ content = get_page_markdown(doc["pageid"])
+ for chunk in chunk_text(content, max_chunk_size=doc_len):
+ self.documents.append(RetrievalResult(title, chunk))
+ normalized_corpus.append(self.tokenize(chunk))
+
+ self.index = BM25Plus(normalized_corpus)
+
+ def tokenize(self, text: str):
+ """Split the text into words, lemmatize, remove stopwords."""
+ return normalize(text).split(" ")
+
+ def best(self, q) -> Iterable[RetrievalResult]:
+ """Yield the best matching fragments to the given query."""
+ tok_q = self.tokenize(q)
+ scores = self.index.get_scores(tok_q)
+ idxs = np.argsort(scores)[::-1]
+ for idx in idxs:
+ yield self.documents[idx]
+
+
+def chunk_text(text, max_chunk_size=1024, chunk_on=("\n\n", "\n", ". ", ", ", " "), chunker_i=0):
+ """
+ Recursively chunks *text* into a list of str, with each element no longer than *max_chunk_size*.
+ Prefers splitting on the elements of *chunk_on*, in order.
+ """
+
+ if len(text) <= max_chunk_size: # the chunk is small enough
+ return [text]
+ if chunker_i >= len(chunk_on): # we have no more preferred chunk_on characters
+ # optimization: instead of merging a thousand characters, just use list slicing
+ return [text[:max_chunk_size], *chunk_text(text[max_chunk_size:], max_chunk_size, chunk_on, chunker_i + 1)]
+
+ # split on the current character
+ chunks = []
+ split_char = chunk_on[chunker_i]
+ for chunk in text.split(split_char):
+ chunk = f"{chunk}{split_char}"
+ if len(chunk) > max_chunk_size: # this chunk needs to be split more, recurse
+ chunks.extend(chunk_text(chunk, max_chunk_size, chunk_on, chunker_i + 1))
+ elif chunks and len(chunk) + len(chunks[-1]) <= max_chunk_size: # this chunk can be merged
+ chunks[-1] += chunk
+ else:
+ chunks.append(chunk)
+
+ # if the last chunk is just the split_char, yeet it
+ if chunks[-1] == split_char:
+ chunks.pop()
+
+ # remove extra split_char from last chunk
+ chunks[-1] = chunks[-1][: -len(split_char)]
+ return chunks
diff --git a/reference/paper_benchmarks/fanout/utils.py b/reference/paper_benchmarks/fanout/utils.py
new file mode 100644
index 0000000..231b84d
--- /dev/null
+++ b/reference/paper_benchmarks/fanout/utils.py
@@ -0,0 +1,47 @@
+from itertools import islice
+
+from markdownify import MarkdownConverter
+
+
+def batched(iterable, n):
+ # batched('ABCDEFG', 3) --> ABC DEF G
+ if n < 1:
+ raise ValueError("n must be at least one")
+ it = iter(iterable)
+ while batch := tuple(islice(it, n)):
+ yield batch
+
+
+# markdownification
+def yeet(*_):
+ return ""
+
+
+def is_valid_url(x):
+ if not x:
+ return False
+ return not x.startswith("data:")
+
+
+class MDConverter(MarkdownConverter):
+ def convert_img(self, el, text, convert_as_inline):
+ alt = el.attrs.get("alt", None) or ""
+ return f"![{alt}](image)"
+
+ def convert_a(self, el, text, convert_as_inline):
+ return text
+
+ # noinspection PyMethodMayBeStatic,PyUnusedLocal
+ def convert_div(self, el, text, convert_as_inline):
+ content = text.strip()
+ if not content:
+ return ""
+ return f"{content}\n"
+
+ # sometimes these appear inline and are just annoying
+ convert_script = yeet
+ convert_style = yeet
+
+
+def markdownify(html: str):
+ return MDConverter(heading_style="atx").convert(html)
diff --git a/reference/paper_benchmarks/fanout/wiki.py b/reference/paper_benchmarks/fanout/wiki.py
new file mode 100644
index 0000000..a5ac1de
--- /dev/null
+++ b/reference/paper_benchmarks/fanout/wiki.py
@@ -0,0 +1,134 @@
+"""Utils for working with Wikipedia"""
+
+import datetime
+import json
+import logging
+import re
+
+import mediawiki
+from bs4 import BeautifulSoup
+from mediawiki import MediaWiki, MediaWikiPage
+
+from fanout.config import CACHE_DIR
+from fanout.norm import normalize
+from fanout.utils import markdownify
+
+USER_AGENT = "fanoutqa/bench (andrz@seas.upenn.edu) pymediawiki/0.7.4"
+DATASET_EPOCH = datetime.datetime(year=2023, month=11, day=20, tzinfo=datetime.timezone.utc)
+
+log = logging.getLogger(__name__)
+wikipedia = MediaWiki(user_agent=USER_AGENT)
+
+
+# ==== classes ====
+class DatedPage(MediaWikiPage):
+ """To query contents as of the right date, we override some of the request params"""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._dated_html = False
+
+ def get_dated_html(self):
+ """Get the page content as of the dataset epoch."""
+ cache_filename = CACHE_DIR / f"{self.pageid}-dated.html"
+ if self._dated_html is False:
+ if cache_filename.exists():
+ self._dated_html = cache_filename.read_text()
+ else:
+ self._dated_html = ""
+ query_params = {
+ "prop": "revisions",
+ "rvprop": "content",
+ "rvlimit": 1,
+ "rvparse": "",
+ "titles": self.title,
+ # added here
+ "rvstart": DATASET_EPOCH.isoformat(),
+ }
+ request = self.mediawiki.wiki_request(query_params)
+ page = request["query"]["pages"][self.pageid]
+ try:
+ self._dated_html = page["revisions"][0]["*"]
+ except KeyError:
+ log.warning(f"Could not find dated revision of {self.title} - maybe the page did not exist yet?")
+ pass
+ cache_filename.write_text(self._dated_html)
+ return self._dated_html
+
+ def get_dated_revid(self):
+ query_params = {
+ "prop": "revisions",
+ "rvprop": "ids",
+ "rvlimit": 1,
+ "rvparse": "",
+ "titles": self.title,
+ # added here
+ "rvstart": DATASET_EPOCH.isoformat(),
+ }
+ request = self.mediawiki.wiki_request(query_params)
+ page = request["query"]["pages"][self.pageid]
+ return page["revisions"][0]["revid"]
+
+ def get_backlinks(self):
+ """Cached version of page.backlinks"""
+ cache_filename = CACHE_DIR / f"{self.pageid}-backlinks.json"
+ if cache_filename.exists():
+ with open(cache_filename) as f:
+ return json.load(f)
+ backlinks = self.backlinks
+ with open(cache_filename, "w") as f:
+ json.dump(backlinks, f)
+ return backlinks
+
+
+class WikiError(Exception):
+ pass
+
+
+# ==== main ====
+def get_wikipedia_page_by_id(pageid):
+ try:
+ return DatedPage(wikipedia, pageid=pageid, redirect=True, preload=False)
+ except (mediawiki.PageError, mediawiki.DisambiguationError) as e:
+ log.exception("Got PageError:")
+ raise WikiError(repr(e)) from e
+
+
+def get_page_all_text(pageid: int, norm=True, remove_citations_and_notes=False) -> str:
+ """
+ Render the page HTML **as of the dataset epoch** and retrieve all visible text (including tables and infoboxes),
+ without markup.
+ """
+ if norm:
+ cache_filename = CACHE_DIR / f"{pageid}-dated-norm.txt"
+ else:
+ cache_filename = CACHE_DIR / f"{pageid}-dated.txt"
+
+ if cache_filename.exists():
+ text = cache_filename.read_text()
+ else:
+ page = get_wikipedia_page_by_id(pageid)
+ html = page.get_dated_html()
+ soup = BeautifulSoup(html, "html.parser")
+ text = soup.get_text(separator=" ", strip=True)
+ if norm:
+ text = normalize(text)
+ cache_filename.write_text(text)
+
+ if remove_citations_and_notes:
+ text = re.sub(r"\s\[ (((note\s+)?\d+)|edit) ]\s", " ", text)
+ return text
+
+
+def get_page_markdown(pageid: int):
+ """Get the page content in markdown, including tables and infoboxes, appropriate for displaying to an LLM"""
+ cache_filename = CACHE_DIR / f"{pageid}-dated.md"
+
+ if cache_filename.exists():
+ return cache_filename.read_text()
+
+ page = get_wikipedia_page_by_id(pageid)
+ html = page.get_dated_html()
+ text = markdownify(html)
+ cache_filename.write_text(text)
+ return text
diff --git a/reference/paper_benchmarks/gen_slurm.py b/reference/paper_benchmarks/gen_slurm.py
new file mode 100644
index 0000000..7ec042d
--- /dev/null
+++ b/reference/paper_benchmarks/gen_slurm.py
@@ -0,0 +1,76 @@
+SETTINGS = (
+ "openbook",
+ "closedbook",
+ "wiki_provided",
+)
+MODELS = {
+ "gpt-4": "small",
+ "gpt-4-turbo": "small",
+ "gpt-3.5-turbo": "small",
+ "llama-chat": "large",
+ "mistral-chat": "med",
+ "mixtral": "large",
+ # "longllama": "large",
+ "claude": "small",
+}
+TEMPLATE = """\
+#!/bin/bash
+#
+#SBATCH --partition=p_nlp
+#SBATCH --job-name=foqa-{model}{fn_extra}-{setting}
+#SBATCH --output=/nlp/data/andrz/logs/%j.%x.log
+#SBATCH --error=/nlp/data/andrz/logs/%j.%x.log
+#SBATCH --time=7-0
+#SBATCH -c {cpus}
+#SBATCH --mem={mem}
+#SBATCH --gpus={gpus}
+{gpuconstraint}
+export WIKI_CACHE_DIR=/nlp/data/andrz/fanoutqa-bench/.wikicache
+srun python run_{setting}.py -m {model}{extra}
+"""
+
+
+def write_slurm_file(setting, model, extra="", fn_extra=""):
+ size = MODELS[model]
+ if size == "small":
+ cpus = 4
+ mem = "32G"
+ gpus = 0
+ gpuconstraint = ""
+ elif size == "med":
+ cpus = 4
+ mem = "64G"
+ gpus = 1
+ gpuconstraint = "#SBATCH --constraint=48GBgpu"
+ else:
+ cpus = 16
+ mem = "128G"
+ gpus = 3
+ gpuconstraint = "#SBATCH --constraint=48GBgpu"
+
+ content = TEMPLATE.format(
+ setting=setting,
+ model=model,
+ cpus=cpus,
+ mem=mem,
+ gpus=gpus,
+ gpuconstraint=gpuconstraint,
+ extra=extra,
+ fn_extra=fn_extra,
+ )
+ with open(f"slurm/{model}{fn_extra}-{setting}.sh", "w") as f:
+ f.write(content)
+
+
+def main():
+ for s in SETTINGS:
+ for m in MODELS:
+ write_slurm_file(s, m)
+
+ for m in MODELS:
+ write_slurm_file("openbook", m, extra=" --ctx-size 4096", fn_extra="-short")
+ write_slurm_file("wiki_provided", m, extra=" --ctx-size 4096", fn_extra="-short")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/reference/paper_benchmarks/get_10_questions.py b/reference/paper_benchmarks/get_10_questions.py
new file mode 100644
index 0000000..080b163
--- /dev/null
+++ b/reference/paper_benchmarks/get_10_questions.py
@@ -0,0 +1,15 @@
+"""Output 10 questions for the human eval."""
+
+import random
+
+from bench.utils import REPO_ROOT, load_questions
+
+
+def main():
+ qs = load_questions(REPO_ROOT / "fanout-final-dev.json")
+ for i, q in enumerate(random.sample(qs, 10)):
+ print(f"**Question {i + 1} ({q['id']})**: {q['question']}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/reference/paper_benchmarks/requirements.txt b/reference/paper_benchmarks/requirements.txt
new file mode 100644
index 0000000..fef2470
--- /dev/null
+++ b/reference/paper_benchmarks/requirements.txt
@@ -0,0 +1,21 @@
+accelerate~=0.26.1
+aiolimiter~=1.1.0
+anthropic~=0.14.0
+beautifulsoup4~=4.12.3
+ftfy~=6.1.3
+kani~=0.7.2
+markdownify~=0.11.6
+openai~=1.10.0
+protobuf~=4.25.2
+pymediawiki~=0.7.4
+rank-bm25~=0.2.2
+sentencepiece~=0.1.99
+spacy~=3.7.2
+tiktoken~=0.5.2
+torch~=2.2.0
+tqdm~=4.66.1
+transformers~=4.37.2
+
+# scoring
+rouge-score~=0.1.2
+git+https://github.com/google-research/bleurt.git@master
diff --git a/reference/paper_benchmarks/run_closedbook.py b/reference/paper_benchmarks/run_closedbook.py
new file mode 100644
index 0000000..094a428
--- /dev/null
+++ b/reference/paper_benchmarks/run_closedbook.py
@@ -0,0 +1,27 @@
+import asyncio
+import logging
+
+from bench.runner import Runner, argstokwargs, parser
+
+
+class ClosedBookRunner(Runner):
+ def get_prompt(self, question, *args, **kwargs) -> str:
+ question = question["question"]
+ prompt = (
+ "Answer the following question, and output only your answer. If the answer is a list, output one on"
+ f" each line. Current date: 11-20-2023.\n\n[Question]: {question}"
+ )
+ return prompt
+
+
+async def main():
+ args = parser.parse_args()
+ runner = ClosedBookRunner(bench_setting="closedbook", **argstokwargs(args))
+ runner.load()
+ await runner.run()
+ await runner.close()
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.INFO)
+ asyncio.run(main())
diff --git a/reference/paper_benchmarks/run_openbook.py b/reference/paper_benchmarks/run_openbook.py
new file mode 100644
index 0000000..bee32d4
--- /dev/null
+++ b/reference/paper_benchmarks/run_openbook.py
@@ -0,0 +1,153 @@
+import asyncio
+import logging
+import re
+
+import mediawiki
+from kani import ChatMessage, ChatRole, Kani, ToolCall, ai_function
+from kani.engines.base import BaseEngine, Completion
+
+from bench.runner import GenerationResult, Runner, argstokwargs, parser
+from fanout.retrieval import Corpus
+from fanout.wiki import DatedPage, wikipedia
+
+log = logging.getLogger("openbook")
+
+
+# ===== runner =====
+class OpenBookRunner(Runner):
+ def get_prompt(self, question, *args, **kwargs) -> str:
+ question = question["question"]
+ prompt = (
+ "You have the ability to search Wikipedia for information. To do so, output a message in the format"
+ " {YOUR_SEARCH_QUERY} (e.g. `List of states and territories of the United"
+ " States`).\nAnswer the following question, and output only your answer or a search, but not"
+ " both. If the answer is a list, output one on each line. Current date: 11-20-2023.\n\n[Question]:"
+ f" {question}"
+ )
+ return prompt
+
+ async def run_one(self, prompt, question, kani_cls=Kani, **kwargs) -> GenerationResult:
+ try:
+ return await asyncio.wait_for(self.run_impl(prompt, kani_cls, question=question, **kwargs), timeout=600)
+ except Exception:
+ log.exception("Error getting response")
+ return GenerationResult(success=False, content=None, extra=None)
+
+ async def run_impl(self, prompt, kani_cls=Kani, **kwargs):
+ ai = kani_cls(self.engine, **kwargs)
+ content = None
+ msgs = []
+ async for msg in ai.full_round(prompt):
+ msgs.append(repr(msg))
+ content = msg.content or content
+ if msg.role == ChatRole.ASSISTANT:
+ print(msg.content)
+ return GenerationResult(success=True, content=content, extra=msgs)
+
+
+class OpenAIOpenBookRunner(OpenBookRunner):
+ def get_prompt(self, question, *args, **kwargs) -> str:
+ question = question["question"]
+ if self.model_name.startswith("gpt-3.5-turbo"):
+ prompt = (
+ "Answer the following question, and output only your answer. You may search before outputting your"
+ " answer. If the answer is a list, output one on each line. Current date: 11-20-2023.\n\n[Question]:"
+ f" {question}"
+ )
+ else:
+ prompt = (
+ "Answer the following question, and output only a function call or your answer. If the answer is a"
+ f" list, output one on each line. Current date: 11-20-2023.\n\n[Question]: {question}"
+ )
+ return prompt
+
+
+# ==== kani ====
+class SearchEngine(BaseEngine):
+ def __init__(self, engine: BaseEngine, strict=False):
+ self.engine = engine
+ self.strict = strict
+ self.max_context_size = engine.max_context_size
+ self.token_reserve = engine.token_reserve
+
+ def message_len(self, message):
+ if message.role == ChatRole.FUNCTION:
+ message = message.copy_with(role=ChatRole.USER)
+ return self.engine.message_len(message)
+
+ async def predict(self, messages, functions=None, **kwargs):
+ translated_messages = []
+ for m in messages:
+ if m.role == ChatRole.FUNCTION:
+ translated_messages.append(m.copy_with(role=ChatRole.USER))
+ else:
+ translated_messages.append(m)
+ resp = await self.engine.predict(translated_messages, functions, **kwargs)
+        # search for a <search>...</search> pair
+        content = resp.message.text
+        tool_calls = None
+        if self.strict:
+            match = re.match(r"<search>(.+?)</search>", content)
+        else:
+            match = re.search(r"<search>(.+?)</search>", content)
+ if match:
+ content = content[: match.end()]
+ tool_calls = [ToolCall.from_function("search", query=match[1])]
+ return Completion(message=ChatMessage.assistant(content, tool_calls=tool_calls))
+
+ def function_token_reserve(self, *args, **kwargs):
+ return 0 # wahoo we hardcode the prompt in the user message
+
+ async def close(self):
+ return await self.engine.close()
+
+
+class WikipediaKani(Kani):
+ def __init__(self, *args, question: str, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.question = question
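+        # let retrieved Wikipedia content take up at most half of the model's context window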
+ self.max_search_tokens = self.engine.max_context_size // 2
+
+ @ai_function()
+ def search(self, query: str):
+ """Search Wikipedia for an article with the given title, and get its content. If no such article is found, return similar article names."""
+ try:
+ page = DatedPage(wikipedia, title=query, redirect=True, preload=False)
+ except mediawiki.PageError:
+ similar = wikipedia.search(query)
+ similar_searches = "\n".join(f"{title}" for title in similar)
+ return f"No page with that title exists. Try one of the similar searches:\n{similar_searches}"
+ except mediawiki.DisambiguationError as e:
+ similar_searches = "\n".join(f"{title}" for title in e.options)
+ return f"That may refer to multiple pages. Select one of these pages:\n{similar_searches}"
+ else:
+ corpus = Corpus([{"title": page.title, "pageid": page.pageid}], doc_len=1024)
+ # retrieve as many fragments as fit in the context window
+ retrieved_docs = []
+ prompt = f"\n{page.title}\n{{}}"
+ for doc in corpus.best(self.question):
+ formatted = f"\n{doc.content}\n\n"
+ content = prompt.format("".join(retrieved_docs) + formatted)
+ doc_len = self.engine.message_len(ChatMessage.user(content))
+ if doc_len > self.max_search_tokens:
+ break
+ retrieved_docs.append(formatted)
+ return prompt.format("".join(retrieved_docs))
+
+
+# ==== main ====
+async def main():
+ args = parser.parse_args()
+ if args.model.startswith("gpt"):
+ runner = OpenAIOpenBookRunner(bench_setting="openbook", **argstokwargs(args))
+ else:
+ runner = OpenBookRunner(bench_setting="openbook", **argstokwargs(args))
+ runner.engine = SearchEngine(runner.engine, strict=True)
+ runner.load()
+ await runner.run(kani_cls=WikipediaKani)
+ await runner.close()
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.INFO)
+ asyncio.run(main())
diff --git a/reference/paper_benchmarks/run_wiki_provided.py b/reference/paper_benchmarks/run_wiki_provided.py
new file mode 100644
index 0000000..5cf0113
--- /dev/null
+++ b/reference/paper_benchmarks/run_wiki_provided.py
@@ -0,0 +1,56 @@
+import asyncio
+import logging
+
+from kani import ChatMessage
+
+from bench.runner import Runner, TOKEN_RESERVE, argstokwargs, parser
+from fanout.retrieval import Corpus
+
+
+class WikiProvidedRunner(Runner):
+ def load(self, max_prompt_tokens):
+ self.questions = self.get_questions()
+ self.prompts = self.get_prompts(self.questions, max_prompt_tokens)
+
+ # noinspection PyMethodOverriding
+ def get_prompt(self, question, max_prompt_tokens) -> str:
+ question_text = question["question"]
+
+ # index all of the documents, splitting by paragraph/sentence to a max of 1024 characters
+ corpus = Corpus(question["necessary_evidence"], doc_len=1024)
+
+ # build the initial prompt
+ prompt = (
+ "*** BEGIN DATA ***\n\n{}\n*** END DATA ***\n\nAnswer the following question based on the documents"
+ " above, and output only your answer. If the answer is a list, output one on each line. Current date:"
+ f" 11-20-2023.\n\n[Question]: {question_text}"
+ )
+
+ # retrieve as many fragments as fit in the context window
+        # format: <document>\n<title>{title}</title>\n{content}\n</document>
+        retrieved_docs = []
+        for doc in corpus.best(question_text):
+            formatted = f"<document>\n<title>{doc.title}</title>\n{doc.content}\n</document>\n\n"
+ content = prompt.format("".join(retrieved_docs) + formatted)
+ doc_len = self.engine.message_len(ChatMessage.user(content))
+ if doc_len > max_prompt_tokens:
+ break
+ retrieved_docs.append(formatted)
+ prompt = prompt.format("".join(retrieved_docs))
+ return prompt
+
+
+async def main():
+ args = parser.parse_args()
+ runner = WikiProvidedRunner(bench_setting="wiki-provided", **argstokwargs(args))
+ max_prompt_tokens = runner.engine.max_context_size - TOKEN_RESERVE
+ print(f"Building prompts for input length {max_prompt_tokens}...")
+ runner.load(max_prompt_tokens)
+ print("Generating...")
+ await runner.run()
+ await runner.close()
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.INFO)
+ asyncio.run(main())
diff --git a/reference/paper_benchmarks/score.py b/reference/paper_benchmarks/score.py
new file mode 100644
index 0000000..264548c
--- /dev/null
+++ b/reference/paper_benchmarks/score.py
@@ -0,0 +1,245 @@
+import asyncio
+import hashlib
+import json
+import re
+from pathlib import Path
+
+import rouge_score.scoring
+import tqdm
+from bleurt.score import BleurtScorer
+from kani import Kani
+from kani.engines.openai import OpenAIEngine
+from rouge_score.rouge_scorer import RougeScorer
+
+from bench.score import answer_in_text, str_answer
+from bench.utils import REPO_ROOT, load_questions, load_results
+from fanout.utils import batched
+
+rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
+# ===== llm =====
+LLM_CACHE_DIR = REPO_ROOT / ".llmcache"
+LLM_CACHE_DIR.mkdir(exist_ok=True)
+engine = OpenAIEngine(model="gpt-4-0613", temperature=0, seed=31415)
+factuality_system = "You are comparing a submitted answer to an expert answer on a given question."
+
+
+def factuality_prompt_old(question, reference, answer):
+ return (
+ f"[BEGIN DATA]\n************\n[Question]: {question}\n************\n[Expert]:"
+ f" {reference}\n************\n[Submission]: {answer}\n************\n[END DATA]\n\nCompare the factual content"
+ " of the submitted answer with the expert answer. Ignore any differences in style, grammar, or"
+ " punctuation.\nThe submitted answer may either be a subset or superset of the expert answer, or it may"
+ " conflict with it. Determine which case applies. First, write out in a step by step manner your reasoning"
+ " about the factual content to be sure that your conclusion is correct. Avoid simply stating the correct"
+ ' answers at the outset. Then print only the single character "A", "B", "C", "D", or "E" (without quotes'
+ " or punctuation) on its own line corresponding to the correct answer. At the end, repeat just the letter"
+ " again by itself on a new line.\n(A) The submitted answer is a subset of the expert answer and is fully"
+ " consistent with it.\n(B) The submitted answer is a superset of the expert answer and is fully consistent"
+ " with it.\n(C) The submitted answer contains all the same details as the expert answer.\n(D) There is a"
+ " disagreement between the submitted answer and the expert answer.\n(E) The answers differ, but these"
+ " differences don't matter from the perspective of factuality."
+ )
+
+
+def factuality_prompt(question, reference, answer):
+ return (
+ f"[BEGIN DATA]\n************\n[Question]: {question}\n************\n[Expert]:"
+ f" {reference}\n************\n[Submission]: {answer}\n************\n[END DATA]\n\nCompare the factual content"
+ " of the submitted answer with the expert answer. Ignore any differences in style, grammar, or"
+ " punctuation.\nThe submitted answer may either be a subset or superset of the expert answer, or it may"
+ " conflict with it. Determine which case applies. First, write out in a step by step manner your reasoning"
+ " about the factual content to be sure that your conclusion is correct. Avoid simply stating the correct"
+ ' answers at the outset. Then print only the single character "A", "B", "C", "D", "E", or "F" (without quotes'
+ " or punctuation) on its own line corresponding to the correct answer. At the end, repeat just the letter"
+ " again by itself on a new line.\n(A) The submitted answer is a subset of the expert answer and is fully"
+ " consistent with it.\n(B) The submitted answer is a superset of the expert answer and is fully consistent"
+ " with it.\n(C) The submitted answer contains all the same details as the expert answer.\n(D) There is a"
+ " disagreement between the submitted answer and the expert answer.\n(E) The answers differ, but these"
+ " differences don't matter from the perspective of factuality.\n(F) The submitted answer does not answer the"
+ " question or is otherwise invalid."
+ )
+
+
+class Scorer:
+ def __init__(self, results, key, questions=None, only_score_answered=False):
+ self.key = key
+ self.only_score_answered = only_score_answered
+ self.questions = questions or load_questions()
+ self.questions_by_id = {q["id"]: q for q in self.questions}
+ self.results = results
+ self.results_by_id = {r["id"]: r for r in self.results}
+ self.rouge = RougeScorer(rouge_types, use_stemmer=True)
+ self.bleurt = BleurtScorer("BLEURT-20")
+ # number of trials to eval
+ if self.only_score_answered:
+ self.eval_len = len(self.results)
+ else:
+ self.eval_len = len(self.questions)
+
+ @classmethod
+ def from_fp(cls, fp: Path, questions):
+ results = load_results(fp)
+ return cls(results=results, key=fp.stem.removeprefix("results-"), questions=questions)
+
+ async def score(self):
+ acc = self.score_accuracy()
+ rouge = self.score_rouge()
+ bleurt_ = self.score_bleurt()
+ gptscore = await self.score_gpt()
+ data = {"acc": acc, "rouge": rouge, "bleurt": bleurt_, "gpt": gptscore}
+ with open(REPO_ROOT / f"results/score-{self.key}.json", "w") as f:
+ json.dump(data, f, indent=2)
+
+ def get_qa_pairs(self):
+ if self.only_score_answered:
+ for a in tqdm.tqdm(self.results):
+ q = self.questions_by_id.get(a["id"])
+ yield q, a
+ else:
+ for q in tqdm.tqdm(self.questions):
+                a = self.results_by_id.get(q["id"])
+                # a is None if the model never produced an answer; downstream scorers count that as a zero
+                yield q, a
+
+ def score_accuracy(self):
+ accs = []
+ n_perfect = 0
+ for q, a in self.get_qa_pairs():
+ if a is None:
+ accs.append(0)
+ continue
+ result = answer_in_text(q["answer"], a["answer"])
+ accs.append(result.score)
+ if result.found:
+ n_perfect += 1
+
+ assert len(accs) == self.eval_len
+ avg_acc = sum(accs) / self.eval_len
+ pct_perfect = n_perfect / self.eval_len
+ print(f"AVG ACC: {avg_acc}")
+ print(f"PCT PFT: {pct_perfect}")
+ return {"acc": avg_acc, "perfect": pct_perfect}
+
+ def score_rouge(self):
+ scores = {t: [] for t in rouge_types}
+ for q, a in self.get_qa_pairs():
+ if a is None:
+ for score in scores.values():
+ score.append(rouge_score.scoring.Score(0, 0, 0))
+ continue
+ results = self.rouge.score(str_answer(q["answer"]), str_answer(a["answer"]))
+ for k, v in results.items():
+ scores[k].append(v)
+
+ assert all(len(v) == self.eval_len for v in scores.values())
+ print("=== ROUGE ===")
+ out = {}
+ for k, v in scores.items():
+ print(f"--- {k} ---")
+ avg_precision = sum(s.precision for s in v) / self.eval_len
+ avg_recall = sum(s.recall for s in v) / self.eval_len
+ avg_fscore = sum(s.fmeasure for s in v) / self.eval_len
+ print(f"precision: {avg_precision}")
+ print(f"recall: {avg_recall}")
+ print(f"fscore: {avg_fscore}")
+ out[k] = {"precision": avg_precision, "recall": avg_recall, "fscore": avg_fscore}
+ print()
+ return out
+
+ def score_bleurt(self):
+ references = []
+ candidates = []
+ for q, a in self.get_qa_pairs():
+ if a is None:
+ candidates.append("")
+ else:
+ candidates.append(str_answer(a["answer"]))
+ references.append(str_answer(q["answer"]))
+
+ scores = self.bleurt.score(references=references, candidates=candidates)
+ assert len(scores) == self.eval_len
+ avg_score = sum(scores) / self.eval_len
+ print(f"BLEURT: {avg_score}")
+ return avg_score
+
+ async def score_gpt(self):
+ accs = []
+
+ for pairs in batched(self.get_qa_pairs(), 20):
+ # eval 20 qs at a time
+ coros = []
+ for q, a in pairs:
+ if a is None:
+ accs.append(0)
+ continue
+ # sometimes we have fun neural text degeneration, just cut it off
+ ans = a["answer"][:4000]
+ coro = self.get_llm_factuality(q, ans)
+ coros.append(coro)
+
+ # and score their answers
+ # B, C, E = full score, anything else = 0
+ answers = await asyncio.gather(*coros)
+ for result in answers:
+ mc = result.strip()[-1].lower()
+ if mc in "bce":
+ accs.append(1)
+ else:
+ accs.append(0)
+
+ assert len(accs) == self.eval_len
+ avg_acc = sum(accs) / self.eval_len
+ print(f"GPT ACC: {avg_acc}")
+ return avg_acc
+
+ async def get_llm_factuality(self, question, answer):
+ # cache
+ ans_hash = hashlib.sha256(answer.encode()).hexdigest()[:8]
+ cache_filename = LLM_CACHE_DIR / f"factual-{self.key}-{question['id']}-{ans_hash}.txt"
+ if cache_filename.exists():
+ return cache_filename.read_text()
+
+ # ask the LLM if it is subjective
+ prompt = factuality_prompt(question["question"], str_answer(question["answer"]), answer)
+ ai = Kani(engine, system_prompt=factuality_system)
+ resp = await ai.chat_round_str(prompt)
+ cache_filename.write_text(resp)
+ return resp
+
+
+async def score_human():
+ questions = load_questions(REPO_ROOT / "fanout-final-test-answers.json")
+
+ # read the human responses and {id, answer} them
+ results = []
+ with open(REPO_ROOT / "human_responses.txt") as f:
+ data = f.read()
+ for segment in data.split("###SEP "):
+ segment = segment.strip()
+ if not segment:
+ continue
+ # fix some weird tokenization stuff in the human responses
+ segment = segment.replace("https://en.wikipedia.org/wiki/", "").replace("_", " ")
+
+ id_, content = segment.split("\n", 1)
+ results.append({"id": id_, "answer": content})
+
+ scorer = Scorer(results, key="human", questions=questions, only_score_answered=True)
+ await scorer.score()
+
+
+async def main(fps=None):
+ if fps is None:
+ fps = (REPO_ROOT / "results").glob("results-*.jsonl")
+ questions = load_questions(REPO_ROOT / "fanout-final-test-answers.json")
+ for result_path in fps:
+ print(result_path.stem)
+ scorer = Scorer.from_fp(result_path, questions)
+ await scorer.score()
+
+
+if __name__ == "__main__":
+ # fps = [Path(a) for a in sys.argv[1:]]
+ # asyncio.run(main(fps))
+ asyncio.run(score_human())
diff --git a/reference/paper_benchmarks/scores_to_csv.py b/reference/paper_benchmarks/scores_to_csv.py
new file mode 100644
index 0000000..7a9d414
--- /dev/null
+++ b/reference/paper_benchmarks/scores_to_csv.py
@@ -0,0 +1,35 @@
+"""Read all the scores and output them as CSV so I don't have to type them all."""
+
+import json
+
+from bench.utils import REPO_ROOT
+
+settings = ["closedbook", "openbook", "wiki-provided"]
+models = ["gpt-4", "gpt-4-turbo", "gpt-3.5-turbo", "llama-chat", "mistral-chat", "mixtral", "claude"]
+
+
+def print_one(fp):
+ with open(fp) as f:
+ scores = json.load(f)
+ acc = scores["acc"]["acc"]
+ perf = scores["acc"]["perfect"]
+ r1p = scores["rouge"]["rouge1"]["precision"]
+ r1r = scores["rouge"]["rouge1"]["recall"]
+ r1f = scores["rouge"]["rouge1"]["fscore"]
+ r2p = scores["rouge"]["rouge2"]["precision"]
+ r2r = scores["rouge"]["rouge2"]["recall"]
+ r2f = scores["rouge"]["rouge2"]["fscore"]
+ rLp = scores["rouge"]["rougeL"]["precision"]
+ rLr = scores["rouge"]["rougeL"]["recall"]
+ rLf = scores["rouge"]["rougeL"]["fscore"]
+ bleurt = scores["bleurt"]
+ gptscore = scores["gpt"]
+ print(",".join(map(str, (acc, perf, r1p, r1r, r1f, r2p, r2r, r2f, rLp, rLr, rLf, bleurt, gptscore))))
+
+
+for setting in settings:
+ print(f"==== {setting} ====")
+ for model in models:
+ result_path = REPO_ROOT / f"results/score-{setting}-{model}.json"
+ print(f"{model},", end="")
+ print_one(result_path)
diff --git a/reference/paper_benchmarks/scores_to_web_fmt.py b/reference/paper_benchmarks/scores_to_web_fmt.py
new file mode 100644
index 0000000..1cc8149
--- /dev/null
+++ b/reference/paper_benchmarks/scores_to_web_fmt.py
@@ -0,0 +1,80 @@
+"""Read all the scores and output them as CSV so I don't have to type them all."""
+
+import json
+
+from bench.utils import REPO_ROOT
+
+settings = ["closedbook", "openbook", "wiki-provided"]
+model_info = {
+ "gpt-4": {
+ "name": "GPT-4",
+ "authors": "OpenAI",
+ "url": "https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo",
+ "citation": "OpenAI, 2023",
+ "type": "FOUNDATION",
+ "context": 8192,
+ },
+ "gpt-4-turbo": {
+ "name": "GPT-4-turbo",
+ "authors": "OpenAI",
+ "url": "https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo",
+ "citation": "OpenAI, 2023",
+ "type": "FOUNDATION",
+ "context": 128000,
+ },
+ "gpt-3.5-turbo": {
+ "name": "GPT-3.5-turbo",
+ "authors": "OpenAI",
+ "url": "https://platform.openai.com/docs/models/gpt-3-5-turbo",
+ "citation": "OpenAI, 2023",
+ "type": "FOUNDATION",
+ "context": 16384,
+ },
+ "llama-chat": {
+ "name": "LLaMA 2 70B",
+ "authors": "Meta",
+ "url": "https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/",
+ "citation": "Touvron et al., 2023",
+ "type": "FOUNDATION",
+ "context": 4096,
+ },
+ "mistral-chat": {
+ "name": "Mistral-7B",
+ "authors": "Mistral AI",
+ "url": "https://mistral.ai/news/announcing-mistral-7b/",
+ "citation": "Jiang et al., 2023",
+ "type": "FOUNDATION",
+ "context": 32000,
+ },
+ "mixtral": {
+ "name": "Mixtral-8x7B",
+ "authors": "Mistral AI",
+ "url": "https://mistral.ai/news/mixtral-of-experts/",
+ "citation": "Jiang et al., 2024",
+ "type": "FOUNDATION",
+ "context": 32000,
+ },
+ "claude": {
+ "name": "Claude 2.1",
+ "authors": "Anthropic",
+ "url": "https://www.anthropic.com/news/claude-2-1",
+ "citation": "Anthropic, 2023",
+ "type": "FOUNDATION",
+ "context": 200000,
+ },
+}
+
+for setting in settings:
+ results = []
+ for model, info in model_info.items():
+ result_path = REPO_ROOT / f"results/score-{setting}-{model}.json"
+ with open(result_path) as f:
+ scores = json.load(f)
+ scores["rouge"].pop("rougeLsum", None)
+ scores["acc"]["loose"] = scores["acc"].pop("acc")
+ scores["acc"]["strict"] = scores["acc"].pop("perfect")
+ scores.update(info)
+ results.append(scores)
+
+ with open(REPO_ROOT / f"results/web-{setting}.json", "w") as f:
+ json.dump(results, f)
diff --git a/reference/paper_benchmarks/validate_results.py b/reference/paper_benchmarks/validate_results.py
new file mode 100644
index 0000000..c05d128
--- /dev/null
+++ b/reference/paper_benchmarks/validate_results.py
@@ -0,0 +1,73 @@
+"""
+Ensure that each result has all of the answers for the test set.
+Write the IDs of the missing questions to config/{setting}-{model}.txt, one per line.
+"""
+
+import json
+from pathlib import Path
+
+from bench.utils import REPO_ROOT, load_questions, load_results
+
+
+def fix_one(fp: Path):
+ fn = fp.stem
+ prompt_file = open(REPO_ROOT / f"results/prompts-{fn}.json", "w")
+ results_file = open(REPO_ROOT / f"results/results-{fn}.jsonl", "w")
+ extras_file = open(REPO_ROOT / f"results/extra-{fn}.jsonl", "w")
+
+ with open(fp) as f:
+ data = json.load(f)
+
+ prompts = []
+ for result in data:
+ smol = {"id": result["id"], "answer": result["answer"], "question": result["question"]}
+ results_file.write(json.dumps(smol))
+ results_file.write("\n")
+ # extras
+ extras_file.write(json.dumps(result))
+ extras_file.write("\n")
+ # prompts
+ prompts.append({"id": result["id"], "prompt": result["prompt"]})
+
+ json.dump(prompts, prompt_file)
+
+ prompt_file.close()
+ results_file.close()
+ extras_file.close()
+
+
+def validate(questions, results) -> list[str]:
+ """Given the questions and results, output a list of missing IDs"""
+ question_ids_list = [q["id"] for q in questions]
+ question_ids = set(q["id"] for q in questions)
+ result_ids = set(r["id"] for r in results if r["answer"])
+ return sorted(question_ids.difference(result_ids), key=lambda i: question_ids_list.index(i))
+
+
+def validate_one(questions, fp):
+ fn = fp.stem.removeprefix("results-")
+ results = load_results(fp)
+ missing = validate(questions, results)
+ with open(REPO_ROOT / f"config/{fn}.txt", "w") as f:
+ f.write("\n".join(missing))
+ print(f"{fp}: {len(questions) - len(missing)} / {len(questions)}")
+ # remove the ones that errored
+ with open(fp, "w") as f:
+ for r in results:
+ if not r["answer"]:
+ continue
+ f.write(json.dumps(r))
+ f.write("\n")
+
+
+def main():
+ questions = load_questions()
+ # for result_path in (REPO_ROOT / "results-old").glob("*.json"):
+ # fix_one(result_path)
+
+ for result_path in (REPO_ROOT / "results").glob("results-*.jsonl"):
+ validate_one(questions, result_path)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/reference/paper_benchmarks/viz-scores.ipynb b/reference/paper_benchmarks/viz-scores.ipynb
new file mode 100644
index 0000000..43e250a
--- /dev/null
+++ b/reference/paper_benchmarks/viz-scores.ipynb
@@ -0,0 +1,225 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "outputs": [],
+ "source": [
+ "# load the data\n",
+ "import json\n",
+ "\n",
+ "from bench.utils import REPO_ROOT\n",
+ "\n",
+ "settings = [\"closedbook\", \"openbook\", \"wiki-provided\"]\n",
+ "setting_names = [\"Closed Book\", \"Open Book\", \"Evidence Provided\"]\n",
+ "models = [\"llama-chat\", \"gpt-4\", \"gpt-3.5-turbo\", \"mistral-chat\", \"mixtral\", \"gpt-4-turbo\", \"claude\"]\n",
+ "model_names = [\"LLaMA 2\", \"GPT-4\", \"GPT-3.5-turbo\", \"Mistral-7B\", \"Mixtral-8x7B\", \"GPT-4-turbo\", \"Claude 2.1\"]\n",
+ "\n",
+ "results_loose = {k: [] for k in setting_names}\n",
+ "results_model = {k: [] for k in setting_names}\n",
+ "\n",
+ "# def print_one(fp):\n",
+ "# with open(fp) as f:\n",
+ "# scores = json.load(f)\n",
+ "# acc = scores[\"acc\"][\"acc\"]\n",
+ "# perf = scores[\"acc\"][\"perfect\"]\n",
+ "# r1p = scores[\"rouge\"][\"rouge1\"][\"precision\"]\n",
+ "# r1r = scores[\"rouge\"][\"rouge1\"][\"recall\"]\n",
+ "# r1f = scores[\"rouge\"][\"rouge1\"][\"fscore\"]\n",
+ "# r2p = scores[\"rouge\"][\"rouge2\"][\"precision\"]\n",
+ "# r2r = scores[\"rouge\"][\"rouge2\"][\"recall\"]\n",
+ "# r2f = scores[\"rouge\"][\"rouge2\"][\"fscore\"]\n",
+ "# rLp = scores[\"rouge\"][\"rougeL\"][\"precision\"]\n",
+ "# rLr = scores[\"rouge\"][\"rougeL\"][\"recall\"]\n",
+ "# rLf = scores[\"rouge\"][\"rougeL\"][\"fscore\"]\n",
+ "# bleurt = scores[\"bleurt\"]\n",
+ "# gptscore = scores[\"gpt\"]\n",
+ "# print(\",\".join(map(str, (acc, perf, r1p, r1r, r1f, r2p, r2r, r2f, rLp, rLr, rLf, bleurt, gptscore))))\n",
+ "\n",
+ "\n",
+ "for setting, setting_name in zip(settings, setting_names):\n",
+ " for model, model_name in zip(models, model_names):\n",
+ " result_path = REPO_ROOT / f\"results/score-{setting}-{model}.json\"\n",
+ " with open(result_path) as f:\n",
+ " scores = json.load(f)\n",
+ " loose = scores[\"acc\"][\"acc\"]\n",
+ " results_loose[setting_name].append(loose)\n",
+ " results_model[setting_name].append(scores[\"gpt\"])\n"
+ ],
+ "metadata": {
+ "collapsed": false
+ },
+ "id": "53c8b7b5d53696f"
+ },
+ {
+ "cell_type": "code",
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt"
+ ],
+ "metadata": {
+ "collapsed": false
+ },
+ "id": "bc8afd1110f9c495"
+ },
+ {
+ "cell_type": "code",
+ "outputs": [],
+ "source": [
+ "# liam code\n",
+ "def show_values(axs, orient=\"v\", space=.01, label_thresh=0):\n",
+ " def _single(ax):\n",
+ " if orient == \"v\":\n",
+ " for p in ax.patches:\n",
+ " _x = p.get_x() + p.get_width() / 2\n",
+ " _y = p.get_y() + p.get_height() + (p.get_height()*0.01)\n",
+ " value = '{:.1f}'.format(p.get_height()*100)\n",
+ " if p.get_height() >= label_thresh:\n",
+ " ax.text(_x, _y, value, ha=\"center\", fontsize=11) #, rotation=40)\n",
+ "\n",
+ " if isinstance(axs, np.ndarray):\n",
+ " for idx, ax in np.ndenumerate(axs):\n",
+ " _single(ax)\n",
+ " else:\n",
+ " _single(axs)\n",
+ "\n",
+ "# # Put the intended figsize here\n",
+ "# fig, ax = plt.subplots(figsize=(5,2.8))\n",
+ "# \n",
+ "# # Put your dataframe here \n",
+ "# sns.barplot(ax=ax, data=df, y='accuracy', x='model', hue='chat')\n",
+ "# \n",
+ "# # Can customize legend here\n",
+ "# ax.legend(loc='upper right', ncol=2, fontsize=12, columnspacing=0.5, labelspacing=0.3, handlelength=1.5, handletextpad=0.4, fancybox=False)\n",
+ "\n",
+ "\n",
+ "# # Set size of text and other things\n",
+ "# ax.xaxis.set_tick_params(labelsize=14)\n",
+ "# \n",
+ "# # Set no printing of axis label and set y limits\n",
+ "# ax.set(xlabel=None)\n",
+ "# ax.set(ylim=(0.0, 0.6))"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2024-02-15T17:11:49.221278Z",
+ "start_time": "2024-02-15T17:11:49.207232Z"
+ }
+ },
+ "id": "78b8da145cbb5620",
+ "execution_count": 36
+ },
+ {
+ "cell_type": "code",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": "