chore: upload paper reference impl
zhudotexe committed Feb 28, 2024
1 parent 0206b97 commit fd62fe2
Showing 25 changed files with 1,924 additions and 0 deletions.
7 changes: 7 additions & 0 deletions reference/README.md
@@ -0,0 +1,7 @@
# Reference Implementations

This directory contains reference implementations of the experiments run in the FanOutQA paper. These implementations
are provided "as-is", with no guarantees of usability or support. Note that the code quality of these implementations
may be lower than that of the library code, as they are earlier explorations, run (a) on a deadline and (b) to develop
the interfaces that were eventually exposed in this library.

161 changes: 161 additions & 0 deletions reference/paper_benchmarks/.gitignore
@@ -0,0 +1,161 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintainted in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

**.DS_Store
results/extra-*
results/prompts-*
results-old/
.wikicache/
tmp*/
BLEURT-20/
.llmcache/
28 changes: 28 additions & 0 deletions reference/paper_benchmarks/README.md
@@ -0,0 +1,28 @@
# benchmarks

make sure to run `python -m spacy download en_core_web_sm` first

closedbook: just the question and a 0-shot prompt asking the model to output short text

openbook: like closedbook, but the model can also search Wikipedia
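
roughly how that search tool can be wired up with kani function calling (the `@ai_function` decorator is kani's documented API; the Wikipedia opensearch backend via `httpx` is just an assumption here, not necessarily what the paper used):

```python
from typing import Annotated

import httpx
from kani import AIParam, Kani, ai_function


class OpenBookKani(Kani):
    @ai_function()
    async def search(self, query: Annotated[str, AIParam(desc="The search query.")]) -> str:
        """Search Wikipedia and return the titles of matching pages."""
        # opensearch returns [query, titles, descriptions, urls]
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                "https://en.wikipedia.org/w/api.php",
                params={"action": "opensearch", "search": query, "format": "json"},
            )
        return "\n".join(resp.json()[1])
```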

wiki provided: chunk the known evidence 1024 characters at a time (preferring to split at paragraph and sentence
boundaries), then use BM25 to retrieve as many passages as fit into the context window
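
a minimal sketch of that chunk-and-retrieve scheme, assuming the `rank-bm25` package and plain whitespace tokenization (the real implementation's splitting and budgeting are more careful):

```python
from rank_bm25 import BM25Okapi


def chunk_text(text: str, max_len: int = 1024) -> list[str]:
    """Split text into <=max_len chunks, preferring paragraph boundaries."""
    chunks, current = [], ""
    for para in text.split("\n\n"):
        if len(current) + len(para) + 2 <= max_len:
            current = f"{current}\n\n{para}" if current else para
        else:
            if current:
                chunks.append(current)
            # the real implementation also falls back to sentence boundaries here
            current = para[:max_len]
    if current:
        chunks.append(current)
    return chunks


def retrieve(question: str, evidence: list[str], budget: int) -> list[str]:
    """Rank all evidence chunks with BM25, then greedily pack the char budget."""
    chunks = [c for doc in evidence for c in chunk_text(doc)]
    bm25 = BM25Okapi([c.lower().split() for c in chunks])
    scores = bm25.get_scores(question.lower().split())
    packed, used = [], 0
    for _, chunk in sorted(zip(scores, chunks), key=lambda t: -t[0]):
        if used + len(chunk) > budget:
            break
        packed.append(chunk)
        used += len(chunk)
    return packed
```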

run:

```shell
WIKI_CACHE_DIR=/nlp/data/andrz/fanoutqa-bench/.wikicache python run_openbook.py -m mistral-chat -n 3
for i in slurm/*; do sbatch $i; done
```

# eval

install:

```shell
wget https://storage.googleapis.com/bleurt-oss-21/BLEURT-20.zip
unzip BLEURT-20.zip
rm BLEURT-20.zip
python -m spacy download en_core_web_sm
```
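
once the checkpoint is unpacked, scoring looks roughly like this (the `BleurtScorer` API is from google-research's `bleurt` package; the strings are placeholders):

```python
from bleurt import score

# point the scorer at the unzipped BLEURT-20 checkpoint directory
scorer = score.BleurtScorer("BLEURT-20")
scores = scorer.score(
    references=["a gold answer string"],
    candidates=["a model answer string"],
)
print(scores)  # one float per (reference, candidate) pair
```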
Empty file.
111 changes: 111 additions & 0 deletions reference/paper_benchmarks/bench/engines.py
@@ -0,0 +1,111 @@
import os

import aiolimiter
import torch
from kani import AIFunction, ChatMessage
from kani.engines.anthropic import AnthropicEngine
from kani.engines.base import BaseEngine
from kani.engines.huggingface.llama2 import LlamaEngine
from kani.engines.openai import OpenAIEngine
from kani.engines.openai.models import ChatCompletion

# openai
ft_org_id = "org-dxkPAvuWroRmYw7z0VhqMC5S"
lab_org_id = "org-bgAXfs8WdU5942SLngg0OGpd"

# hf: model load kwargs shared by all local HuggingFace models
common_hf_model_args = dict(
    device_map="auto",
    torch_dtype=torch.float16,
)

if mcd := os.getenv("MODEL_CACHE_DIR"):
common_hf_model_args["cache_dir"] = mcd


def can_parallel(engine):
    """Whether requests to this engine can safely be made in parallel."""
    return isinstance(engine, OpenAIEngine)


def get_engine(name: str, ctx_size=None) -> BaseEngine:
    """Construct the engine for the given model name, optionally overriding its maximum context size."""
if name == "gpt-4":
return RatelimitedOpenAIEngine(
model="gpt-4-0613",
organization=lab_org_id,
max_context_size=ctx_size or 8192,
tpm_limiter=aiolimiter.AsyncLimiter(300000),
temperature=0,
)
elif name == "gpt-4-turbo":
return RatelimitedOpenAIEngine(
model="gpt-4-0125-preview",
organization=ft_org_id,
max_context_size=ctx_size or 128000,
tpm_limiter=aiolimiter.AsyncLimiter(400000),
temperature=0,
)
elif name == "gpt-3.5-turbo":
return RatelimitedOpenAIEngine(
model="gpt-3.5-turbo-1106",
organization=ft_org_id,
max_context_size=ctx_size or 16000,
tpm_limiter=aiolimiter.AsyncLimiter(2000000),
temperature=0,
)
elif name == "llama-chat":
return LlamaEngine(
"meta-llama/Llama-2-70b-chat-hf",
max_context_size=ctx_size or 4096,
model_load_kwargs=common_hf_model_args,
use_auth_token=True,
do_sample=False,
)
elif name == "mistral-chat":
return LlamaEngine(
"mistralai/Mistral-7B-Instruct-v0.2",
max_context_size=ctx_size or 32768,
model_load_kwargs=common_hf_model_args,
do_sample=False,
)
elif name == "mixtral":
return LlamaEngine(
"mistralai/Mixtral-8x7B-Instruct-v0.1",
max_context_size=ctx_size or 32768,
model_load_kwargs=common_hf_model_args,
do_sample=False,
)
elif name == "longllama":
return LlamaEngine(
"syzymon/long_llama_3b_instruct",
max_context_size=ctx_size or 512000,
tokenizer_kwargs={"trust_remote_code": True},
model_load_kwargs={"trust_remote_code": True, **common_hf_model_args},
do_sample=False,
)
elif name == "claude":
return AnthropicEngine(
model="claude-2.1",
max_context_size=ctx_size or 200000,
temperature=0,
)
else:
        raise ValueError(f"Invalid model name: {name!r}")


class RatelimitedOpenAIEngine(OpenAIEngine):
    """An OpenAIEngine that waits on optional requests-per-minute and tokens-per-minute limiters before each request."""

    def __init__(
        self, *args, rpm_limiter: aiolimiter.AsyncLimiter = None, tpm_limiter: aiolimiter.AsyncLimiter = None, **kwargs
    ):
super().__init__(*args, **kwargs)
self.rpm_limiter = rpm_limiter
self.tpm_limiter = tpm_limiter

async def predict(
self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
) -> ChatCompletion:
if self.rpm_limiter:
await self.rpm_limiter.acquire()
        if self.tpm_limiter:
            # estimate the prompt's token usage and reserve it from the TPM budget
            n_toks = self.function_token_reserve(functions) + sum(self.message_len(m) for m in messages)
            await self.tpm_limiter.acquire(n_toks)
return await super().predict(messages, functions, **hyperparams)
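
# usage sketch (an assumption, for illustration only; not from the commit): any
# engine from get_engine() drops into kani's standard chat loop. The model name
# and prompt below are placeholders.
if __name__ == "__main__":
    import asyncio

    from kani import Kani

    async def main():
        engine = get_engine("gpt-3.5-turbo")
        ai = Kani(engine)
        # one question in, one short answer out
        print(await ai.chat_round_str("Answer concisely: what year did Apollo 11 land on the Moon?"))
        await engine.close()

    asyncio.run(main())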