diff --git a/llmebench/datasets/dataset_base.py b/llmebench/datasets/dataset_base.py index 0e88e0d6..a95ff52e 100644 --- a/llmebench/datasets/dataset_base.py +++ b/llmebench/datasets/dataset_base.py @@ -1,13 +1,17 @@ import json import logging +import os import random from abc import ABC, abstractmethod +from pathlib import Path from langchain.embeddings import HuggingFaceEmbeddings from langchain.prompts.example_selector import MaxMarginalRelevanceExampleSelector from langchain.vectorstores import FAISS +from pooch import Decompress, Pooch, retrieve, Untar, Unzip + class DatasetBase(ABC): """ @@ -20,7 +24,9 @@ class DatasetBase(ABC): Attributes ---------- - None + data_dir : str + Base path of data containing all datasets. Defaults to "data" in the current + working directory. Methods ------- @@ -46,12 +52,11 @@ class DatasetBase(ABC): """ - def __init__(self, **kwargs): - pass + def __init__(self, data_dir="data", **kwargs): + self.data_dir = data_dir - @staticmethod @abstractmethod - def metadata(): + def metadata(self): """ Returns the dataset's metadata @@ -284,3 +289,151 @@ def prepare_fewshots(self, target_data, train_data, n_shots, deduplicate=True): examples = [self._destringify_sample(sample) for sample in examples] yield examples + + def download_dataset(self, download_url=None): + """ + Utility method to download a dataset if not present locally on disk. + Can handle datasets of types *.zip, *.tar, *.tar.gz, *.tar.bz2, *.tar.xz. + + Arguments + --------- + download_url : str + The url to the dataset. If not provided, falls back to the `download_url` + provided by the Dataset's metadata. If missing, falls back to a default + server specified by the environment variable `DEFAULT_DOWNLOAD_URL` + + Returns + ------- + download_succeeded : bool + Returns True if the dataset is already present on disk, or if download + + extraction was successful. 
+ """ + + def decompress(fname, action, pup): + """ + Post-processing hook to automatically detect the type of archive and + call the correct processor (UnZip, Untar, Decompress) + + Arguments + --------- + fname : str + Full path of the zipped file in local storage + action : str + One of "download" (file doesn't exist and will download), + "update" (file is outdated and will download), and + "fetch" (file exists and is updated so no download). + pup : Pooch + The instance of Pooch that called the processor function. + + Returns + ------- + fnames : list + List of all extracted files + + """ + # Default where the downloaded file is not a container/archive + fnames = [fname] + + extract_dir = self.__class__.__name__ + + if fname.endswith(".tar.xz"): + extractor = Decompress(name=fname[:-3]) + fname = extractor(fname, action, pup) + + extractor = Untar(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) + elif fname.endswith(".tar.bz2"): + extractor = Decompress(name=fname[:-4]) + fname = extractor(fname, action, pup) + + extractor = Untar(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) + elif fname.endswith(".tar.gz"): + extractor = Decompress(name=fname[:-3]) + fname = extractor(fname, action, pup) + + extractor = Untar(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) + elif fname.endswith(".xz"): + extractor = Decompress(name=fname[:-3]) + fname = extractor(fname, action, pup) + fnames = [fname] + elif fname.endswith(".bz2"): + extractor = Decompress(name=fname[:-4]) + fname = extractor(fname, action, pup) + fnames = [fname] + elif fname.endswith(".gz"): + extractor = Decompress(name=fname[:-3]) + fname = extractor(fname, action, pup) + fnames = [fname] + elif fname.endswith(".tar"): + extractor = Untar(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) + elif fname.endswith(".zip"): + extractor = Unzip(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) + + return fnames + + # 
Priority: + # Fn Argument + # Dataset metadata["download_url"] + # DEFAULT_DOWNLOAD_URL/Dataset_name.zip + download_urls = [] + if download_url is not None: + download_urls.append(download_url) + + metadata_url = self.metadata().get("download_url", None) + if metadata_url is not None: + download_urls.append(metadata_url) + + default_url = os.getenv("DEFAULT_DOWNLOAD_URL") + if default_url is not None: + if default_url.endswith("/"): + default_url = default_url[:-1] + default_url = f"{default_url}/{self.__class__.__name__}.zip" + download_urls.append(default_url) + + # Try downloading from available links in order of priority + for download_url in download_urls: + extension = ".zip" + supported_extensions = [ + ".tar.xz", + ".tar.bz2", + ".tar.gz", + ".xz", + ".bz2", + ".gz", + ".tar", + ".zip", + ] + + for ext in supported_extensions: + if download_url.endswith(ext): + extension = ext + break + try: + logging.info(f"Trying {download_url}") + retrieve( + download_url, + known_hash=None, + fname=f"{self.__class__.__name__}{extension}", + path=self.data_dir, + progressbar=True, + processor=decompress, + ) + # If it was a *.tar.* file, we can safely delete the + # intermediate *.tar file + if extension in supported_extensions[:3]: + tar_file_path = ( + Path(self.data_dir) / f"{self.__class__.__name__}.tar" + ) + tar_file_path.unlink() + return True + except Exception as e: + logging.warning(f"Failed to download: {e}") + continue + + logging.warning(f"Failed to download dataset") + + return False diff --git a/setup.cfg b/setup.cfg index 6df97ac4..2e33d157 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,6 +23,7 @@ install_requires = nltk==3.8.1 openai==0.27.7 pandas==2.0.2 + pooch==1.7.0 python-dotenv==1.0.0 scikit-learn==1.2.2 tenacity==8.2.2 diff --git a/tests/datasets/archives/MockDataset.tar b/tests/datasets/archives/MockDataset.tar new file mode 100644 index 00000000..18f6f23c Binary files /dev/null and b/tests/datasets/archives/MockDataset.tar differ diff --git 
class ArchiveHandler(http.server.SimpleHTTPRequestHandler):
    """Request handler serving files from the test archives directory."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, directory="tests/datasets/archives", **kwargs)


class SignalingHTTPServer(http.server.HTTPServer):
    """HTTPServer that sets an event once its serve loop is running."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.ready_event = threading.Event()

    def service_actions(self):
        # Invoked periodically from serve_forever(); signals readiness
        self.ready_event.set()


class MockDataset(DatasetBase):
    """Minimal DatasetBase implementation used by the download tests."""

    def metadata(self):
        return {}

    def get_data_sample(self):
        return {"input": "input", "label": "label"}

    def load_data(self, data_path):
        samples = []
        for _ in range(100):
            samples.append(self.get_data_sample())
        return samples
class MockDatasetWithDownloadURL(MockDataset):
    """MockDataset whose metadata points at a local test server."""

    def __init__(self, port, filename="MockDataset.zip", **kwargs):
        self.port = port
        self.filename = filename
        super().__init__(**kwargs)

    def metadata(self):
        return {"download_url": f"http://localhost:{self.port}/{self.filename}"}


class TestDatasetAutoDownload(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Serve the archives directory on an OS-assigned free port
        cls.httpd = SignalingHTTPServer(("", 0), ArchiveHandler)
        cls.port = cls.httpd.server_address[1]

        cls.test_server = threading.Thread(target=cls.httpd.serve_forever, daemon=True)
        cls.test_server.start()
        # Block until the server loop is actually running
        cls.httpd.ready_event.wait()

    @classmethod
    def tearDownClass(cls):
        if cls.httpd:
            cls.httpd.shutdown()
            cls.httpd.server_close()
            cls.test_server.join()

    def check_downloaded(self, data_dir, dataset_name, extension):
        # The data directory must contain exactly the downloaded archive
        # and the extracted dataset directory
        contents = list(data_dir.iterdir())
        self.assertEqual(len(contents), 2)

        file_names = [entry.name for entry in contents if entry.is_file()]
        self.assertIn(f"{dataset_name}.{extension}", file_names)

        directories = [entry for entry in contents if entry.is_dir()]
        directory_names = [entry.name for entry in directories]
        self.assertIn(f"{dataset_name}", directory_names)
        self.assertEqual(len(directory_names), 1)

        extracted_files = [entry.name for entry in directories[0].iterdir()]
        self.assertIn("train.txt", extracted_files)
        self.assertIn("test.txt", extracted_files)

    def _check_archive_download(self, extension):
        # Common driver: download MockDataset.<extension> and verify extraction
        data_dir = TemporaryDirectory()

        dataset = MockDataset(data_dir=data_dir.name)
        self.assertTrue(
            dataset.download_dataset(
                download_url=f"http://localhost:{self.port}/MockDataset.{extension}"
            )
        )

        self.check_downloaded(Path(data_dir.name), "MockDataset", extension)

    def test_auto_download_zip(self):
        "Test automatic downloading and extraction of *.zip datasets"
        self._check_archive_download("zip")

    def test_auto_download_tar(self):
        "Test automatic downloading and extraction of *.tar datasets"
        self._check_archive_download("tar")

    def test_auto_download_tar_gz(self):
        "Test automatic downloading and extraction of *.tar.gz datasets"
        self._check_archive_download("tar.gz")

    def test_auto_download_tar_bz2(self):
        "Test automatic downloading and extraction of *.tar.bz2 datasets"
        self._check_archive_download("tar.bz2")

    def test_auto_download_tar_xz(self):
        "Test automatic downloading and extraction of *.tar.xz datasets"
        self._check_archive_download("tar.xz")

    def test_auto_download_default_url(self):
        "Test automatic downloading when download url is not provided"

        data_dir = TemporaryDirectory()

        dataset = MockDataset(data_dir=data_dir.name)
        with patch.dict(
            "os.environ",
            {
                "DEFAULT_DOWNLOAD_URL": f"http://localhost:{self.port}/",
            },
        ):
            self.assertTrue(dataset.download_dataset())

        self.check_downloaded(Path(data_dir.name), "MockDataset", "zip")

    @patch.dict(
        "os.environ",
        {
            "DEFAULT_DOWNLOAD_URL": "http://invalid.llmebench-server.org",
        },
    )
    def test_auto_download_metadata_url(self):
        "Test automatic downloading when download url is provided in metadata"

        data_dir = TemporaryDirectory()

        dataset = MockDatasetWithDownloadURL(data_dir=data_dir.name, port=self.port)
        self.assertTrue(dataset.download_dataset())

        self.check_downloaded(Path(data_dir.name), "MockDatasetWithDownloadURL", "zip")

    @patch.dict(
        "os.environ",
        {
            "DEFAULT_DOWNLOAD_URL": "http://invalid.llmebench-server.org",
        },
    )
    def test_auto_download_non_existent(self):
        "Test automatic downloading when dataset is not actually available"

        data_dir = TemporaryDirectory()

        dataset = MockDatasetWithDownloadURL(
            data_dir=data_dir.name, port=self.port, filename="InvalidDataset.zip"
        )
        self.assertFalse(
            dataset.download_dataset(
                download_url="http://invalid.llmebench-server.org/Dataset.zip"
            )
        )


class TestDatasetCaching(unittest.TestCase):
    def test_cache_existing_file(self):
        "Test if an existing file _does not_ trigger a download"

        data_dir = TemporaryDirectory()

        # Copy a archive to the download location
        source_archive = Path("tests/datasets/archives/MockDataset.zip")
        cached_archive = Path(data_dir.name) / "MockDataset.zip"
        cached_archive.write_bytes(source_archive.read_bytes())

        # download_dataset should not reach out to the invalid server,
        # since file is present locally
        dataset = MockDataset(data_dir=data_dir.name)
        self.assertTrue(
            dataset.download_dataset(
                download_url="http://invalid.llmebench-server.org/ExistingData.zip"
            )
        )