diff --git a/README.md b/README.md index b2f6732..62683e5 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,7 @@ class DevQuestion: question: str # the top-level question to answer decomposition: list[DevSubquestion] # human-written decomposition of the question answer: dict[str, Primitive] | list[Primitive] | Primitive + necessary_evidence: list[Evidence] categories: list[str] @@ -115,10 +116,12 @@ are two main functions to interface with Wikipedia: To save on time waiting for requests and computation power (both locally and on Wikipedia's end), this package aggressively caches retrieved Wikipedia pages. By default, this cache is located in `~/.cache/fanoutqa/wikicache`. -We provide many cached pages you can prepopulate this cache with, by using the following commands: +We provide many cached pages (~9GB) you can prepopulate this cache with, by using the following commands: ```shell -mkdir -p ~/.cache/fanoutqa/wikicache +mkdir -p ~/.cache/fanoutqa +wget -O ~/.cache/fanoutqa/wikicache.tar.gz https://datasets.mechanus.zhu.codes/fanoutqa/wikicache.tar.gz +tar -xzf ~/.cache/fanoutqa/wikicache.tar.gz ``` ## Evaluation diff --git a/fanoutqa/models.py b/fanoutqa/models.py index f98966b..7d64fdf 100644 --- a/fanoutqa/models.py +++ b/fanoutqa/models.py @@ -2,6 +2,7 @@ from typing import Optional, Union Primitive = Union[bool, int, float, str] +AnswerType = Union[dict[str, Primitive], list[Primitive], Primitive] @dataclass @@ -32,7 +33,7 @@ class DevSubquestion: id: str question: str decomposition: list["DevSubquestion"] - answer: Union[dict[str, Primitive], list[Primitive], Primitive] + answer: AnswerType """the answer to this subquestion""" depends_on: list[str] @@ -64,7 +65,7 @@ class DevQuestion: """the top-level question to answer""" decomposition: list[DevSubquestion] """human-written decomposition of the question""" - answer: Union[dict[str, Primitive], list[Primitive], Primitive] + answer: AnswerType categories: list[str] @classmethod @@ -78,6 +79,18 @@ def from_dict(cls, d): categories=d["categories"], ) + @property + def necessary_evidence(self) -> list[Evidence]: + """A list of all the evidence used by human annotators to answer the question.""" + + def walk_evidences(subqs): + for subq in subqs: + if subq.evidence: + yield subq.evidence + yield from walk_evidences(subq.decomposition) + + return list(walk_evidences(self.decomposition)) + @dataclass class TestQuestion: diff --git a/fanoutqa/utils.py b/fanoutqa/utils.py index d559fee..0477e8c 100644 --- a/fanoutqa/utils.py +++ b/fanoutqa/utils.py @@ -10,6 +10,8 @@ AnyPath: TypeAlias = Union[str, bytes, os.PathLike] PKG_ROOT = Path(__file__).parent +CACHE_DIR = Path("~/.cache/fanoutqa") +CACHE_DIR.mkdir(exist_ok=True, parents=True) DATASET_EPOCH = datetime.datetime(year=2023, month=11, day=20, tzinfo=datetime.timezone.utc) """The day before which to get revisions from Wikipedia, to ensure that the contents of pages don't change over time.""" diff --git a/fanoutqa/wiki.py b/fanoutqa/wiki.py index bdd7197..8b20605 100644 --- a/fanoutqa/wiki.py +++ b/fanoutqa/wiki.py @@ -3,16 +3,15 @@ import functools import logging import urllib.parse -from pathlib import Path import httpx from .models import Evidence -from .utils import DATASET_EPOCH, markdownify +from .utils import CACHE_DIR, DATASET_EPOCH, markdownify USER_AGENT = "fanoutqa/1.0.0 (andrz@seas.upenn.edu)" -CACHE_DIR = Path("~/.cache/fanoutqa/wikicache") -CACHE_DIR.mkdir(exist_ok=True, parents=True) +WIKI_CACHE_DIR = CACHE_DIR / "wikicache" +WIKI_CACHE_DIR.mkdir(exist_ok=True, parents=True) log = logging.getLogger(__name__) wikipedia = httpx.Client(base_url="https://en.wikipedia.org/w/api.php", headers={"User-Agent": USER_AGENT}) @@ -70,7 +69,7 @@ def wiki_search(query: str, results=10) -> list[Evidence]: def wiki_content(doc: Evidence) -> str: """Get the page content in markdown, including tables and infoboxes, appropriate for displaying to an LLM.""" # get the cached content, if available - cache_filename = CACHE_DIR / f"{doc.pageid}-dated.md" + cache_filename = WIKI_CACHE_DIR / f"{doc.pageid}-dated.md" if cache_filename.exists(): return cache_filename.read_text()