From 5a7906b3852b58e29dd4e110de2ad1275d5f4858 Mon Sep 17 00:00:00 2001 From: Fahim Dalvi Date: Mon, 11 Sep 2023 14:10:29 +0300 Subject: [PATCH] Implement download and caching utility for datasets (#214) This commit adds support for downloading and caching datasets for future use. Supports a variety of archives, include zip, tar, tar.gz and more. The utility method is not automatically called currently. * Initial auto-download implementation * Improve auto-downloader implementation * Add tests for auto downloader * Add tests for caching mechanism * Add test for metadata download url usage over environment variable * Generalize port for test server * Add tests for tar.bz2 and tar.xz files * Add test for non-existent datasets * Clean up code and add docstrings * Fix incorrect param handling in dataset_base init --- llmebench/datasets/dataset_base.py | 163 ++++++++++++++- setup.cfg | 1 + tests/datasets/archives/MockDataset.tar | Bin 0 -> 8192 bytes tests/datasets/archives/MockDataset.tar.bz2 | Bin 0 -> 869 bytes tests/datasets/archives/MockDataset.tar.gz | Bin 0 -> 677 bytes tests/datasets/archives/MockDataset.tar.xz | Bin 0 -> 660 bytes tests/datasets/archives/MockDataset.zip | Bin 0 -> 1469 bytes tests/datasets/test_download_and_caching.py | 220 ++++++++++++++++++++ tests/test_benchmark.py | 2 +- 9 files changed, 380 insertions(+), 6 deletions(-) create mode 100644 tests/datasets/archives/MockDataset.tar create mode 100644 tests/datasets/archives/MockDataset.tar.bz2 create mode 100644 tests/datasets/archives/MockDataset.tar.gz create mode 100644 tests/datasets/archives/MockDataset.tar.xz create mode 100644 tests/datasets/archives/MockDataset.zip create mode 100644 tests/datasets/test_download_and_caching.py diff --git a/llmebench/datasets/dataset_base.py b/llmebench/datasets/dataset_base.py index 0e88e0d6..a95ff52e 100644 --- a/llmebench/datasets/dataset_base.py +++ b/llmebench/datasets/dataset_base.py @@ -1,13 +1,17 @@ import json import logging +import os import random from abc import ABC, abstractmethod +from pathlib import Path from langchain.embeddings import HuggingFaceEmbeddings from langchain.prompts.example_selector import MaxMarginalRelevanceExampleSelector from langchain.vectorstores import FAISS +from pooch import Decompress, Pooch, retrieve, Untar, Unzip + class DatasetBase(ABC): """ @@ -20,7 +24,9 @@ class DatasetBase(ABC): Attributes ---------- - None + data_dir : str + Base path of data containing all datasets. Defaults to "data" in the current + working directory. Methods ------- @@ -46,12 +52,11 @@ class DatasetBase(ABC): """ - def __init__(self, **kwargs): - pass + def __init__(self, data_dir="data", **kwargs): + self.data_dir = data_dir - @staticmethod @abstractmethod - def metadata(): + def metadata(self): """ Returns the dataset's metadata @@ -284,3 +289,151 @@ def prepare_fewshots(self, target_data, train_data, n_shots, deduplicate=True): examples = [self._destringify_sample(sample) for sample in examples] yield examples + + def download_dataset(self, download_url=None): + """ + Utility method to download a dataset if not present locally on disk. + Can handle datasets of types *.zip, *.tar, *.tar.gz, *.tar.bz2, *.tar.xz. + + Arguments + --------- + download_url : str + The url to the dataset. If not provided, falls back to the `download_url` + provided by the Dataset's metadata. If missing, falls back to a default + server specified by the environment variable `DEFAULT_DOWNLOAD_URL` + + Returns + ------- + download_succeeded : bool + Returns True if the dataset is already present on disk, or if download + + extraction was successful. + """ + + def decompress(fname, action, pup): + """ + Post-processing hook to automatically detect the type of archive and + call the correct processor (UnZip, Untar, Decompress) + + Arguments + --------- + fname : str + Full path of the zipped file in local storage + action : str + One of "download" (file doesn't exist and will download), + "update" (file is outdated and will download), and + "fetch" (file exists and is updated so no download). + pup : Pooch + The instance of Pooch that called the processor function. + + Returns + ------- + fnames : list + List of all extracted files + + """ + # Default where the downloaded file is not a container/archive + fnames = [fname] + + extract_dir = self.__class__.__name__ + + if fname.endswith(".tar.xz"): + extractor = Decompress(name=fname[:-3]) + fname = extractor(fname, action, pup) + + extractor = Untar(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) + elif fname.endswith(".tar.bz2"): + extractor = Decompress(name=fname[:-4]) + fname = extractor(fname, action, pup) + + extractor = Untar(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) + elif fname.endswith(".tar.gz"): + extractor = Decompress(name=fname[:-3]) + fname = extractor(fname, action, pup) + + extractor = Untar(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) + elif fname.endswith(".xz"): + extractor = Decompress(name=fname[:-3]) + fname = extractor(fname, action, pup) + fnames = [fname] + elif fname.endswith(".bz2"): + extractor = Decompress(name=fname[:-4]) + fname = extractor(fname, action, pup) + fnames = [fname] + elif fname.endswith(".gz"): + extractor = Decompress(name=fname[:-3]) + fname = extractor(fname, action, pup) + fnames = [fname] + elif fname.endswith(".tar"): + extractor = Untar(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) + elif fname.endswith(".zip"): + extractor = Unzip(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) + + return fnames + + # Priority: + # Fn Argument + # Dataset metadata["download_url"] + # DEFAULT_DOWNLOAD_URL/Dataset_name.zip + download_urls = [] + if download_url is not None: + download_urls.append(download_url) + + metadata_url = self.metadata().get("download_url", None) + if metadata_url is not None: + download_urls.append(metadata_url) + + default_url = os.getenv("DEFAULT_DOWNLOAD_URL") + if default_url is not None: + if default_url.endswith("/"): + default_url = default_url[:-1] + default_url = f"{default_url}/{self.__class__.__name__}.zip" + download_urls.append(default_url) + + # Try downloading from available links in order of priority + for download_url in download_urls: + extension = ".zip" + supported_extensions = [ + ".tar.xz", + ".tar.bz2", + ".tar.gz", + ".xz", + ".bz2", + ".gz", + ".tar", + ".zip", + ] + + for ext in supported_extensions: + if download_url.endswith(ext): + extension = ext + break + try: + logging.info(f"Trying {download_url}") + retrieve( + download_url, + known_hash=None, + fname=f"{self.__class__.__name__}{extension}", + path=self.data_dir, + progressbar=True, + processor=decompress, + ) + # If it was a *.tar.* file, we can safely delete the + # intermediate *.tar file + if extension in supported_extensions[:3]: + tar_file_path = ( + Path(self.data_dir) / f"{self.__class__.__name__}.tar" + ) + tar_file_path.unlink() + return True + except Exception as e: + logging.warning(f"Failed to download: {e}") + continue + + logging.warning(f"Failed to download dataset") + + return False diff --git a/setup.cfg b/setup.cfg index 6df97ac4..2e33d157 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,6 +23,7 @@ install_requires = nltk==3.8.1 openai==0.27.7 pandas==2.0.2 + pooch==1.7.0 python-dotenv==1.0.0 scikit-learn==1.2.2 tenacity==8.2.2 diff --git a/tests/datasets/archives/MockDataset.tar b/tests/datasets/archives/MockDataset.tar new file mode 100644 index 0000000000000000000000000000000000000000..18f6f23c81bae5fc818f2a284b5e52e506e8aff8 GIT binary patch literal 8192 zcmeHM&uSV$7@xF+(!Hcm+CxtZf!^HA?9S}Eh=+Bx`9oG+-J~&>Fx8cqs4?pVjr0|I z%Aqe%=o926dMotU7bx^l`UIWTfMGGFN>-9~2IQN6-(TnZ_?IT}A_LTAdy?TGtOEbAGK} zkL+Dv4n2V*=OCvKWh}pZ0v-Y2(5hT0%X{+G2=~BK{!sx~N^{EBODBKll;wG)(aj;p zcRs&}a6abiaQ_(quc{|4g?2hkUun`VJL~$j8fE@!v3$kzlhX|VR>t>QwAvgVE@4hm z0K6EVou1KvwppX?gRdWc)xvAfKYdArbQyenefe4HTR$#W)-S%j|9u?p&OT1pAWWi^ zj_WX4tJYKHj)ix5Ak;^1Z8$wgVCf0~PyRNu7inPDs)0!Ww z#a9P-YD~+2gfNWcKj9Ss$o;sT59R;o`L98_#TqR?jdcD(bu(#b3N{HL*wnFvP*UFB z%vdFRV{yCzmAi2rMp#gVQswCZ<=4_u)I?f3jO(kO9 zWNcJGwB$;QNMWI4{s_cZ^~W~dQvbsaFjD_Dq^qL;8%Hp{+3@?1PrQ(P&m*85Fu%)U zg?}OLo7VqL!hb3uBL5fibOipJIn4hk{)?>++>-wj@E`vzNEi5T=7%{4gTQ|gOe^r8 tz<+b((L-g4zxxBX_J0%bAOC+q^nZ(ZQv&}jBFaLCh=7QIh`{}ez#oY&Vsiih literal 0 HcmV?d00001 diff --git a/tests/datasets/archives/MockDataset.tar.bz2 b/tests/datasets/archives/MockDataset.tar.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..a1e0ff88f73afe5221c0bc66886a72c2db53602b GIT binary patch literal 869 zcmV-r1DgCoT4*^jL0KkKSp`d?FaQBHfB*lu%aBULm+9O00STcATj_P01XC)gFptE28KY>KmZ0n20&y03Y3Uxpp6qI z6H^*wj}+d5PejS5lo}c|#M+t+38n-#r!?)I*Us!No@EZ=cay;Ho+SPoCgYW^l+ZF$x8m#3+%D$;5X++^rn<_-Oq=@J!=_o?XTpSuT-Dl{D{3T|R zP)(A9kwEFFW~e@O;!(7zTGdTeZ(6ohXK9+q!sa{AVPo1+8R{gwoYbOPOw{l$pF&=C znH)`56ZTfT!roZGbCcG4jU((c3k=6Cya;QP3K@6lN>lD@6j4W=d8lF&n8wIf6B#(t ziQeyizpp>Vc`uXlZM&t!%}mXsB5XXY0mM+q!l^eCwCpIhOOmHftPE*W9jlF_BO@gj z63UXyC!t#H#>3vOO|rzkd0Orkh74+w#J;tKFeRUPbuXplsIe-#cp476!m1~6fn~-j zFg}I?wwpostz2ZIG{b>)IE%?@>jQZ$X_|{9wJ^z^TCBvOp0!blVR4M8si0;Wj0RJn zX}<+;nMFz}WH%TJhU>=4tu)iV^}B~uwDdn@+G`&xWX?{KX*`a*IgGTm!r@Jmpj78O z?;7*;HQQv(&zbXQzoTr@cH!*66pI?n4U! vEltN6MCiT;!1_)yi2n*a=PCahrN`+1ij5P@|&fr literal 0 HcmV?d00001 diff --git a/tests/datasets/archives/MockDataset.tar.gz b/tests/datasets/archives/MockDataset.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..cf1c3bf2ccb47cd077776b0fa9969c6d935e07a7 GIT binary patch literal 677 zcmV;W0$TkaiwFSW>;7Z_1MQd1ZrVT)$2YC2s=lOB)gF4XRH-+dU3+(pg*X^PLm~_q zoHXPTZNVrcK#180g8B+Q<c7^NP{p@oBc`4+TLMz z20WwrYkNFojkd=797F|#kYcP3XOoiA4E6E-2trI1VQiWhVx%J#A_F4>s4y~*{@|p` zStwkov;9VEhMh30*JrGcLs5Uv{|XPDJp}gva9LF?l+_(|!tfb*B!&*aN>a??e&y)z zwzPe}G`<>r%gwA(floVM% z8?agsu0>Y|uE~FdFr3PNA}Rn-`wREv67G2Z8&GZWMk`1nD&A0P#Z5!Q79j*%6e|eD z)!atPE;-qa?TprEoQK+Fxt4U(Hg^54o?ox@3hSkV-2?73quJZ_yCk2k^xQ!+<2uO6 zA8u!IzIW;ks@(Hyue=5%-oWk^d>0mcZ;O@ys360b71OY`BvYQV zvl+f_JMH}xeX~_8k>XapxEb0DLP3yR730W(MU^UrTC;+kD!55lF?zdW7yXWV#3$Yo z>ndlbH<&KD%HrNsNfboi{axSxi4JhO|3g%l{ogVe{r|V^gyh!H_IS|c(H4L?^q)?M z%>Shz_1}&8PxW79b>QauPh3-}|CWK&e^UQR{daeXzWbX?|A}?k|1ATl|0EKLuvs%7-H=CJnG%n-EWb z(NaQRaL(WN&!r!NdvG!$*L#lDXQB6)!;fY%Yk5U>?pkWC*LMxFZE zgJx$u;gH>xIQFxb8%dD}?T`!j02KT=!fAZZYnrW|4T6lC9K>Bn5Vn78BmlqkY09M4 z<51wirM=WxBX&`;Xn8>kgk`KuKK!egR=}jtZ(e`I*9ScgwNvf_c5+`V;XRPWCiME- zKB3~P!RGk5vs#e*8`V|sAv@KTju+MZcTNbyO`Z2=HIO?_?8#dRz0mah6ZLs2vI;KZ zDhfVVcwa&kVAqU%{v{n$X<4Z0zLHWw3u1H0=&HNLy~0YIZL~wxkK5ss69QzDWm>g> zsS7lN(FsrXP10f24a6;2EFPK(p>}-kCp$8H6AO11j?jBtbCJ_F0OJT?-%=3bm;5R? zZ5$KhTNJmIYQJ@pQka;;FT5s|K4rSWt+ef`s>##sJ}a;{aSnJDlOh zza3?pT%x9Pb)Fg7S&(>#lzcjN1iD6%$TpW1-@}Wkag*rw&1K0}-B6UycAklW?ET=< zkb1-YZ|IMnYXrd!GUHYJiytzy(=3P^(;#4Wo(e>?ELpG;#1v~&>}oUBY>-vAh3AUj zo`27t+TYh`&5l6QNO^({16b8I&{&NOlW@fVl(qm<5kU?-{;7pUr3c?K?eahz`*pDd u-GB1@z&Zc`000073l#DJ+aGBF0qq2UKmY)`LCgWM#Ao{g000001X)_-N;dic literal 0 HcmV?d00001 diff --git a/tests/datasets/archives/MockDataset.zip b/tests/datasets/archives/MockDataset.zip new file mode 100644 index 0000000000000000000000000000000000000000..d660da1c3a5cdef68146056294e614bbc92ed09c GIT binary patch literal 1469 zcmWIWW@Zs#U|`^2coeN2&Qn)^gAK@I0gEUwl%y7y=#^BIgof}kuzOtim*NS;AiA`I zn}Lz#D4+W zXmWZ&Qo;vcpRf=7VI2VtOcNxS)di#t4GeBEWh%@!oZqPNKtf1}C9$#e4&$Sf8D29M zK4RkG2|D?6XH_RFtL4fn%g&>!OOLM9)a+DUd5ZC1z}EbR8_kW4@(cCM4K3M|_HjVmY6F&3 z4zfXx8G8f@GDv_e02&Pp@7GKW*pde;Bzd4E4SdE5pcre-jASf4mEba!6;v)WumPbg L&^cW|y$lQh&@9O~ literal 0 HcmV?d00001 diff --git a/tests/datasets/test_download_and_caching.py b/tests/datasets/test_download_and_caching.py new file mode 100644 index 00000000..3a9e471e --- /dev/null +++ b/tests/datasets/test_download_and_caching.py @@ -0,0 +1,220 @@ +import http.server +import threading +import unittest + +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest.mock import patch + +from llmebench.datasets.dataset_base import DatasetBase + + +class ArchiveHandler(http.server.SimpleHTTPRequestHandler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, directory="tests/datasets/archives") + + +class SignalingHTTPServer(http.server.HTTPServer): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.ready_event = threading.Event() + + def service_actions(self): + self.ready_event.set() + + +class MockDataset(DatasetBase): + def metadata(self): + return {} + + def get_data_sample(self): + return {"input": "input", "label": "label"} + + def load_data(self, data_path): + return [self.get_data_sample() for _ in range(100)] + + +class MockDatasetWithDownloadURL(MockDataset): + def __init__(self, port, filename="MockDataset.zip", **kwargs): + self.port = port + self.filename = filename + super(MockDatasetWithDownloadURL, self).__init__(**kwargs) + + def metadata(self): + return {"download_url": f"http://localhost:{self.port}/{self.filename}"} + + +class TestDatasetAutoDownload(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.httpd = SignalingHTTPServer(("", 0), ArchiveHandler) + cls.port = cls.httpd.server_address[1] + + cls.test_server = threading.Thread(target=cls.httpd.serve_forever, daemon=True) + cls.test_server.start() + cls.httpd.ready_event.wait() + + @classmethod + def tearDownClass(cls): + if cls.httpd: + cls.httpd.shutdown() + cls.httpd.server_close() + cls.test_server.join() + + def check_downloaded(self, data_dir, dataset_name, extension): + downloaded_files = list(data_dir.iterdir()) + downloaded_filenames = [f.name for f in downloaded_files if f.is_file()] + self.assertEqual(len(downloaded_files), 2) + self.assertIn(f"{dataset_name}.{extension}", downloaded_filenames) + + extracted_directories = [d for d in downloaded_files if d.is_dir()] + extracted_directory_names = [d.name for d in extracted_directories] + self.assertIn(f"{dataset_name}", extracted_directory_names) + self.assertEqual(len(extracted_directory_names), 1) + + dataset_files = [f.name for f in extracted_directories[0].iterdir()] + self.assertIn("train.txt", dataset_files) + self.assertIn("test.txt", dataset_files) + + def test_auto_download_zip(self): + "Test automatic downloading and extraction of *.zip datasets" + + data_dir = TemporaryDirectory() + + dataset = MockDataset(data_dir=data_dir.name) + self.assertTrue( + dataset.download_dataset( + download_url=f"http://localhost:{self.port}/MockDataset.zip" + ) + ) + + self.check_downloaded(Path(data_dir.name), "MockDataset", "zip") + + def test_auto_download_tar(self): + "Test automatic downloading and extraction of *.tar datasets" + + data_dir = TemporaryDirectory() + + dataset = MockDataset(data_dir=data_dir.name) + self.assertTrue( + dataset.download_dataset( + download_url=f"http://localhost:{self.port}/MockDataset.tar" + ) + ) + + self.check_downloaded(Path(data_dir.name), "MockDataset", "tar") + + def test_auto_download_tar_gz(self): + "Test automatic downloading and extraction of *.tar.gz datasets" + + data_dir = TemporaryDirectory() + + dataset = MockDataset(data_dir=data_dir.name) + self.assertTrue( + dataset.download_dataset( + download_url=f"http://localhost:{self.port}/MockDataset.tar.gz" + ) + ) + + self.check_downloaded(Path(data_dir.name), "MockDataset", "tar.gz") + + def test_auto_download_tar_bz2(self): + "Test automatic downloading and extraction of *.tar.bz2 datasets" + + data_dir = TemporaryDirectory() + + dataset = MockDataset(data_dir=data_dir.name) + self.assertTrue( + dataset.download_dataset( + download_url=f"http://localhost:{self.port}/MockDataset.tar.bz2" + ) + ) + + self.check_downloaded(Path(data_dir.name), "MockDataset", "tar.bz2") + + def test_auto_download_tar_xz(self): + "Test automatic downloading and extraction of *.tar.xz datasets" + + data_dir = TemporaryDirectory() + + dataset = MockDataset(data_dir=data_dir.name) + self.assertTrue( + dataset.download_dataset( + download_url=f"http://localhost:{self.port}/MockDataset.tar.xz" + ) + ) + + self.check_downloaded(Path(data_dir.name), "MockDataset", "tar.xz") + + def test_auto_download_default_url(self): + "Test automatic downloading when download url is not provided" + + data_dir = TemporaryDirectory() + + dataset = MockDataset(data_dir=data_dir.name) + with patch.dict( + "os.environ", + { + "DEFAULT_DOWNLOAD_URL": f"http://localhost:{self.port}/", + }, + ): + self.assertTrue(dataset.download_dataset()) + + self.check_downloaded(Path(data_dir.name), "MockDataset", "zip") + + @patch.dict( + "os.environ", + { + "DEFAULT_DOWNLOAD_URL": "http://invalid.llmebench-server.org", + }, + ) + def test_auto_download_metadata_url(self): + "Test automatic downloading when download url is provided in metadata" + + data_dir = TemporaryDirectory() + + dataset = MockDatasetWithDownloadURL(data_dir=data_dir.name, port=self.port) + self.assertTrue(dataset.download_dataset()) + + self.check_downloaded(Path(data_dir.name), "MockDatasetWithDownloadURL", "zip") + + @patch.dict( + "os.environ", + { + "DEFAULT_DOWNLOAD_URL": "http://invalid.llmebench-server.org", + }, + ) + def test_auto_download_non_existent(self): + "Test automatic downloading when dataset is not actually available" + + data_dir = TemporaryDirectory() + + dataset = MockDatasetWithDownloadURL( + data_dir=data_dir.name, port=self.port, filename="InvalidDataset.zip" + ) + self.assertFalse( + dataset.download_dataset( + download_url="http://invalid.llmebench-server.org/Dataset.zip" + ) + ) + + +class TestDatasetCaching(unittest.TestCase): + def test_cache_existing_file(self): + "Test if an existing file _does not_ trigger a download" + + data_dir = TemporaryDirectory() + + # Copy a archive to the download location + archive_file = Path("tests/datasets/archives/MockDataset.zip") + copy_archive_file = Path(data_dir.name) / "MockDataset.zip" + copy_archive_file.write_bytes(archive_file.read_bytes()) + + # download_dataset should not reach out to the invalid server, + # since file is present locally + dataset = MockDataset(data_dir=data_dir.name) + self.assertTrue( + dataset.download_dataset( + download_url="http://invalid.llmebench-server.org/ExistingData.zip" + ) + ) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 7d86fc4f..800866c5 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -17,7 +17,7 @@ class MockDataset(DatasetBase): - def metadata(): + def metadata(self): return {} def get_data_sample(self):