From ce1650e78b86bdbd055a2c8b19875f94bfca80e0 Mon Sep 17 00:00:00 2001 From: Fahim Imaduddin Dalvi Date: Sun, 10 Sep 2023 10:48:22 +0300 Subject: [PATCH 01/10] Initial auto-download implementation --- llmebench/datasets/dataset_base.py | 116 ++++++++++++++++++++++++++++- setup.cfg | 1 + 2 files changed, 116 insertions(+), 1 deletion(-) diff --git a/llmebench/datasets/dataset_base.py b/llmebench/datasets/dataset_base.py index 0e88e0d6..97d61b5a 100644 --- a/llmebench/datasets/dataset_base.py +++ b/llmebench/datasets/dataset_base.py @@ -8,6 +8,8 @@ from langchain.prompts.example_selector import MaxMarginalRelevanceExampleSelector from langchain.vectorstores import FAISS +from pooch import Decompress, Pooch, retrieve, UnTar, Unzip + class DatasetBase(ABC): """ @@ -47,7 +49,7 @@ class DatasetBase(ABC): """ def __init__(self, **kwargs): - pass + self.data_dir = kwargs.get("data_dir", "data") @staticmethod @abstractmethod @@ -118,6 +120,118 @@ def load_data(self, data_path, no_labels=False): """ pass + def download_dataset(self, download_url=None): + def decompress(fname, action, pup): + """ + Post-processing hook to automatically detect the type of archive and + call the correct processor (UnZip, UnTar, Decompress) + + Parameters + ---------- + fname : str + Full path of the zipped file in local storage + action : str + One of "download" (file doesn't exist and will download), + "update" (file is outdated and will download), and + "fetch" (file exists and is updated so no download). + pup : Pooch + The instance of Pooch that called the processor function. + + Returns + ------- + fnames : list + List of all extracted files + + """ + if fname.endswith(".tar.xz"): + extractor = Decompress(name=fname[:-3]) + fname = extractor(fname, action, pup) + + extractor = UnTar(extract_dir=self.__name__) + fnames = extractor(fname[:-3], action, pup) + elif fname.endswith(".tar.bz2"): + extractor = Decompress(name=fname[:-4]) + fname = extractor(fname, action, pup) + + extractor = UnTar(extract_dir=self.__name__) + fnames = extractor(fname[:-4], action, pup) + elif fname.endswith(".tar.gz"): + extractor = Decompress(name=fname[:-3]) + fname = extractor(fname, action, pup) + + extractor = UnTar(extract_dir=self.__name__) + fnames = extractor(fname[:-3], action, pup) + elif fname.endswith(".xz"): + extractor = Decompress(name=fname[:-3]) + fname = extractor(fname, action, pup) + fnames = [fname] + elif fname.endswith(".bz2"): + extractor = Decompress(name=fname[:-4]) + fname = extractor(fname, action, pup) + fnames = [fname] + elif fname.endswith(".gz"): + extractor = Decompress(name=fname[:-3]) + fname = extractor(fname, action, pup) + fnames = [fname] + elif fname.endswith(".tar"): + extractor = UnTar(extract_dir=self.__name__) + fnames = extractor(fname, action, pup) + elif fname.endswith(".zip"): + extractor = Unzip(extract_dir=self.__name__) + fnames = extractor(fname, action, pup) + + return fnames + + # Priority: + # Fn Argument + # Dataset metadata["download_url"] + # BASE_DOWNLOAD_URL/Dataset_name + download_urls = [] + if download_url is not None: + download_urls.append(download_url) + + metadata_url = self.metadata().get("download_url", None) + if metadata_url is not None: + download_urls.append(metadata_url) + + default_url = os.getenv("DEFAULT_DOWNLOAD_URL") + if default_url is not None: + if default_url.endswith("/"): + default_url = default_url[:-1] + default_url = f"{default_url}/{self.__name__}.zip" + download_urls.append(default_url) + + # Try downloading from available links in order of 
priority + for download_url in download_urls: + if not Pooch.is_available(download_url): + continue + + extension = ".zip" + supported_extensions = [ + ".tar.xz", + ".tar.bz2", + ".tar.gz", + ".xz", + ".bz2", + ".gz", + ".tar", + ".zip", + ] + + for ext in supported_extensions: + if download_urls.endswith(ext): + extension = ext + break + + pooch.retrieve( + download_url, + known_hash=None, + fname=f"{self.__name__}.{extension}", + path=self.data_dir, + progressbar=True, + processor=decompress, + ) + def _deduplicate_train_test(self, train_data, test_data): """ Filter train data to avoid overlap with test data diff --git a/setup.cfg b/setup.cfg index 6df97ac4..2e33d157 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,6 +23,7 @@ install_requires = nltk==3.8.1 openai==0.27.7 pandas==2.0.2 + pooch==1.7.0 python-dotenv==1.0.0 scikit-learn==1.2.2 tenacity==8.2.2 From e06ef39cf12bd78a284a6cfbce25285cb2f7bc9a Mon Sep 17 00:00:00 2001 From: Fahim Imaduddin Dalvi Date: Mon, 11 Sep 2023 11:50:39 +0300 Subject: [PATCH 02/10] Improve auto-downloader implementation --- llmebench/datasets/dataset_base.py | 64 ++++++++++++++++++------------ 1 file changed, 39 insertions(+), 25 deletions(-) diff --git a/llmebench/datasets/dataset_base.py b/llmebench/datasets/dataset_base.py index 97d61b5a..85f1cc3e 100644 --- a/llmebench/datasets/dataset_base.py +++ b/llmebench/datasets/dataset_base.py @@ -1,14 +1,16 @@ import json import logging +import os import random from abc import ABC, abstractmethod +from pathlib import Path from langchain.embeddings import HuggingFaceEmbeddings from langchain.prompts.example_selector import MaxMarginalRelevanceExampleSelector from langchain.vectorstores import FAISS -from pooch import Decompress, Pooch, retrieve, UnTar, Unzip +from pooch import Decompress, Pooch, retrieve, Untar, Unzip class DatasetBase(ABC): @@ -51,9 +53,8 @@ class DatasetBase(ABC): def __init__(self, **kwargs): self.data_dir = kwargs.get("data_dir", "data") - @staticmethod @abstractmethod - def metadata(): + def metadata(self): """ Returns the dataset's metadata @@ -124,7 +125,7 @@ def download_dataset(self, download_url=None): def decompress(fname, action, pup): """ Post-processing hook to automatically detect the type of archive and - call the correct processor (UnZip, UnTar, Decompress) + call the correct processor (UnZip, Untar, Decompress) Parameters ---------- @@ -143,24 +144,26 @@ def decompress(fname, action, pup): List of all extracted files """ + # Remove intermediate tar file + extract_dir = self.__class__.__name__ if fname.endswith(".tar.xz"): extractor = Decompress(name=fname[:-3]) fname = extractor(fname, action, pup) - extractor = UnTar(extract_dir=self.__name__) - fnames = extractor(fname[:-3], action, pup) + extractor = Untar(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) elif fname.endswith(".tar.bz2"): extractor = Decompress(name=fname[:-4]) fname = extractor(fname, action, pup) - extractor = UnTar(extract_dir=self.__name__) - fnames = extractor(fname[:-4], action, pup) + extractor = Untar(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) elif fname.endswith(".tar.gz"): extractor = Decompress(name=fname[:-3]) fname = extractor(fname, action, pup) - extractor = UnTar(extract_dir=self.__name__) - fnames = extractor(fname[:-3], action, pup) + extractor = Untar(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) elif fname.endswith(".xz"): extractor = Decompress(name=fname[:-3]) fname = extractor(fname, action, pup) @@ -174,10 +177,10 @@ def 
decompress(fname, action, pup): fname = extractor(fname, action, pup) fnames = [fname] elif fname.endswith(".tar"): - extractor = UnTar(extract_dir=self.__name__) + extractor = Untar(extract_dir=extract_dir) fnames = extractor(fname, action, pup) elif fname.endswith(".zip"): - extractor = Unzip(extract_dir=self.__name__) + extractor = Unzip(extract_dir=extract_dir) fnames = extractor(fname, action, pup) return fnames @@ -198,14 +201,11 @@ def decompress(fname, action, pup): if default_url is not None: if default_url.endswith("/"): default_url = default_url[:-1] - default_url = f"{default_url}/{self.__name__}.zip" + default_url = f"{default_url}/{self.__class__.__name__}.zip" download_urls.append(default_url) # Try downloading from available links in order of priority for download_url in download_urls: - if not Pooch.is_available(download_url): - continue - extension = ".zip" supported_extensions = [ ".tar.xz", @@ -219,18 +219,32 @@ def decompress(fname, action, pup): ] for ext in supported_extensions: - if download_urls.endswith(ext): + if download_url.endswith(ext): extension = ext break + try: + print(f"trying {download_url}") + retrieve( + download_url, + known_hash=None, + fname=f"{self.__class__.__name__}{extension}", + path=self.data_dir, + progressbar=True, + processor=decompress, + ) + # If it was a *.tar.* file, we can safely delete the + # intermediate *.tar file + if extension in supported_extensions[:3]: + tar_file_path = ( + Path(self.data_dir) / f"{self.__class__.__name__}.tar" + ) + tar_file_path.unlink() + print(f"succeeded") + break + except Exception as e: + print(f"issue {e}") - pooch.retrieve( - download_url, - known_hash=None, - fname=f"{self.__name__}.{extension}", - path=self.data_dir, - progressbar=True, - processor=decompress, - ) + continue def _deduplicate_train_test(self, train_data, test_data): """ From 5c6920ed375b61fc452b058389f4980c5be38c06 Mon Sep 17 00:00:00 2001 From: Fahim Imaduddin Dalvi Date: Mon, 11 Sep 2023 11:51:36 +0300 Subject: [PATCH 03/10] Add tests for auto downloader --- tests/datasets/archives/MockDataset.tar | Bin 0 -> 8192 bytes tests/datasets/archives/MockDataset.tar.gz | Bin 0 -> 677 bytes tests/datasets/archives/MockDataset.zip | Bin 0 -> 1469 bytes tests/datasets/test_caching.py | 146 +++++++++++++++++++++ tests/test_benchmark.py | 2 +- 5 files changed, 147 insertions(+), 1 deletion(-) create mode 100644 tests/datasets/archives/MockDataset.tar create mode 100644 tests/datasets/archives/MockDataset.tar.gz create mode 100644 tests/datasets/archives/MockDataset.zip create mode 100644 tests/datasets/test_caching.py diff --git a/tests/datasets/archives/MockDataset.tar b/tests/datasets/archives/MockDataset.tar new file mode 100644 index 0000000000000000000000000000000000000000..18f6f23c81bae5fc818f2a284b5e52e506e8aff8 GIT binary patch literal 8192 zcmeHM&uSV$7@xF+(!Hcm+CxtZf!^HA?9S}Eh=+Bx`9oG+-J~&>Fx8cqs4?pVjr0|I z%Aqe%=o926dMotU7bx^l`UIWTfMGGFN>-9~2IQN6-(TnZ_?IT}A_LTAdy?TGtOEbAGK} zkL+Dv4n2V*=OCvKWh}pZ0v-Y2(5hT0%X{+G2=~BK{!sx~N^{EBODBKll;wG)(aj;p zcRs&}a6abiaQ_(quc{|4g?2hkUun`VJL~$j8fE@!v3$kzlhX|VR>t>QwAvgVE@4hm z0K6EVou1KvwppX?gRdWc)xvAfKYdArbQyenefe4HTR$#W)-S%j|9u?p&OT1pAWWi^ zj_WX4tJYKHj)ix5Ak;^1Z8$wgVCf0~PyRNu7inPDs)0!Ww z#a9P-YD~+2gfNWcKj9Ss$o;sT59R;o`L98_#TqR?jdcD(bu(#b3N{HL*wnFvP*UFB z%vdFRV{yCzmAi2rMp#gVQswCZ<=4_u)I?f3jO(kO9 zWNcJGwB$;QNMWI4{s_cZ^~W~dQvbsaFjD_Dq^qL;8%Hp{+3@?1PrQ(P&m*85Fu%)U zg?}OLo7VqL!hb3uBL5fibOipJIn4hk{)?>++>-wj@E`vzNEi5T=7%{4gTQ|gOe^r8 tz<+b((L-g4zxxBX_J0%bAOC+q^nZ(ZQv&}jBFaLCh=7QIh`{}ez#oY&Vsiih literal 0 
HcmV?d00001 diff --git a/tests/datasets/archives/MockDataset.tar.gz b/tests/datasets/archives/MockDataset.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..be9e3687fbe7dd64f685c0b3960e6d3eb8bfcfc7 GIT binary patch literal 677 zcmV;W0$TkaiwFQ))Ba=t1MQd1ZrVT)$2YC2s=lOB)gF4XRH-+dU3+(pg*X^PLm~_q zoHXPTZNVrcK#180g8B+Q<c7^NP{p@oBc`4+TLMz z20WwrYkNFojkd=797F|#kYcP3XOoiA4E6E-2trI1VQiWhVx%J#A_F4>s4y~*{@|p` zStwkov;9VEhMh30*JrGcLs5Uv{|XPDJp}gva9LF?l+_(|!tfb*B!&*aN>a??e&y)z zwzPe}G`<>r%gwA(floVM% z8?agsu0>Y|uE~FdFr3PNA}Rn-`wREv67G2Z8&GZWMk`1nD&A0P#Z5!Q79j*%6e|eD z)!atPE;-qa?TprEoQK+Fxt4U(Hg^54o?ox@3hSkV-2?73quJZ_yCk2k^xQ!+<2uO6 zA8u!IzIW;ks@(Hyue=5%-oWk^d>0mcZ;O@ys360b71OY`BvYQV zvl+f_JMH}xeX~_8k>XapxEb0DLP3yR730W(MU^UrTC;+kD!55lF?zdW7yXWV#3$Yo z>ndlbH<&KD%HrNsNfboi{axSxi4JhO|3g%l{ogVe{r|V^gyh!H_IS|c(H4L?^q)?M z%>Shz_1}&8PxW79b>QauPh3-}|CWK&e^UQR{daeXzWbX?|A}?k|1ATl|0EKL4+W zXmWZ&Qo;vcpRf=7VI2VtOcNxS)di#t4GeBEWh%@!oZqPNKtf1}C9$#e4&$Sf8D29M zK4RkG2|D?6XH_RFtL4fn%g&>!OOLM9)a+DUd5ZC1z}EbR8_kW4@(cCM4K3M|_HjVmY6F&3 z4zfXx8G8f@GDv_e02&Pp@7GKW*pde;Bzd4E4SdE5pcre-jASf4mEba!6;v)WumPbg L&^cW|y$lQh&@9O~ literal 0 HcmV?d00001 diff --git a/tests/datasets/test_caching.py b/tests/datasets/test_caching.py new file mode 100644 index 00000000..41dabfd3 --- /dev/null +++ b/tests/datasets/test_caching.py @@ -0,0 +1,146 @@ +import http.server +import threading +import unittest + +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest.mock import patch + +from llmebench.datasets.dataset_base import DatasetBase + + +class ArchiveHandler(http.server.SimpleHTTPRequestHandler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, directory="tests/datasets/archives") + + +class SignalingHTTPServer(http.server.HTTPServer): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.ready_event = threading.Event() + + def service_actions(self): + self.ready_event.set() + + +class MockDataset(DatasetBase): + def metadata(self): + return {} + + def get_data_sample(self): + return {"input": "input", "label": "label"} + + def load_data(self, data_path): + return [self.get_data_sample() for _ in range(100)] + + +class TestDatasetAutoDownload(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.httpd = SignalingHTTPServer(("", 8076), ArchiveHandler) + cls.test_server = threading.Thread(target=cls.httpd.serve_forever, daemon=True) + cls.test_server.start() + cls.httpd.ready_event.wait() + + @classmethod + def tearDownClass(cls): + if cls.httpd: + cls.httpd.shutdown() + cls.httpd.server_close() + cls.test_server.join() + + def test_auto_download_zip(self): + "Test automatic downloading and extraction of *.zip datasets" + + data_dir = TemporaryDirectory() + + dataset = MockDataset(data_dir=data_dir.name) + dataset.download_dataset(download_url="http://localhost:8076/MockDataset.zip") + + downloaded_files = list(Path(data_dir.name).iterdir()) + downloaded_filenames = [f.name for f in downloaded_files if f.is_file()] + self.assertEqual(len(downloaded_files), 2) + self.assertIn("MockDataset.zip", downloaded_filenames) + + extracted_directories = [d for d in downloaded_files if d.is_dir()] + extracted_directory_names = [d.name for d in extracted_directories] + self.assertIn("MockDataset", extracted_directory_names) + self.assertEqual(len(extracted_directory_names), 1) + + dataset_files = [f.name for f in extracted_directories[0].iterdir()] + self.assertIn("train.txt", dataset_files) + self.assertIn("test.txt", dataset_files) + + 
def test_auto_download_tar(self): + "Test automatic downloading and extraction of *.tar datasets" + + data_dir = TemporaryDirectory() + + dataset = MockDataset(data_dir=data_dir.name) + dataset.download_dataset(download_url="http://localhost:8076/MockDataset.tar") + + downloaded_files = list(Path(data_dir.name).iterdir()) + downloaded_filenames = [f.name for f in downloaded_files if f.is_file()] + self.assertEqual(len(downloaded_files), 2) + self.assertIn("MockDataset.tar", downloaded_filenames) + + extracted_directories = [d for d in downloaded_files if d.is_dir()] + extracted_directory_names = [d.name for d in extracted_directories] + self.assertIn("MockDataset", extracted_directory_names) + self.assertEqual(len(extracted_directory_names), 1) + + dataset_files = [f.name for f in extracted_directories[0].iterdir()] + self.assertIn("train.txt", dataset_files) + self.assertIn("test.txt", dataset_files) + + def test_auto_download_tar_gz(self): + "Test automatic downloading and extraction of *.tar.gz datasets" + + data_dir = TemporaryDirectory() + + dataset = MockDataset(data_dir=data_dir.name) + dataset.download_dataset( + download_url="http://localhost:8076/MockDataset.tar.gz" + ) + + downloaded_files = list(Path(data_dir.name).iterdir()) + self.assertEqual(len(downloaded_files), 2) + downloaded_filenames = [f.name for f in downloaded_files if f.is_file()] + self.assertIn("MockDataset.tar.gz", downloaded_filenames) + + extracted_directories = [d for d in downloaded_files if d.is_dir()] + extracted_directory_names = [d.name for d in extracted_directories] + self.assertIn("MockDataset", extracted_directory_names) + self.assertEqual(len(extracted_directory_names), 1) + + dataset_files = [f.name for f in extracted_directories[0].iterdir()] + self.assertIn("train.txt", dataset_files) + self.assertIn("test.txt", dataset_files) + + @patch.dict( + "os.environ", + { + "DEFAULT_DOWNLOAD_URL": "http://localhost:8076/", + }, + ) + def test_auto_download_default_url(self): + "Test automatic downloading when download url is not provided" + + data_dir = TemporaryDirectory() + + dataset = MockDataset(data_dir=data_dir.name) + dataset.download_dataset() + + downloaded_files = list(Path(data_dir.name).iterdir()) + downloaded_filenames = [f.name for f in downloaded_files if f.is_file()] + self.assertEqual(len(downloaded_files), 2) + self.assertIn("MockDataset.zip", downloaded_filenames) + + extracted_directories = [d for d in downloaded_files if d.is_dir()] + extracted_directory_names = [d.name for d in extracted_directories] + self.assertIn("MockDataset", extracted_directory_names) + self.assertEqual(len(extracted_directory_names), 1) + + dataset_files = [f.name for f in extracted_directories[0].iterdir()] + self.assertIn("train.txt", dataset_files) + self.assertIn("test.txt", dataset_files) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 7d86fc4f..800866c5 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -17,7 +17,7 @@ class MockDataset(DatasetBase): - def metadata(): + def metadata(self): return {} def get_data_sample(self): From 34d293807bd0a2e81fc1f25cccd34d00c6414e94 Mon Sep 17 00:00:00 2001 From: Fahim Imaduddin Dalvi Date: Mon, 11 Sep 2023 13:10:31 +0300 Subject: [PATCH 04/10] Add tests for caching mechanism --- llmebench/datasets/dataset_base.py | 10 +++-- ...aching.py => test_download_and_caching.py} | 37 ++++++++++++++++--- 2 files changed, 39 insertions(+), 8 deletions(-) rename tests/datasets/{test_caching.py => test_download_and_caching.py} (82%) diff 
--git a/llmebench/datasets/dataset_base.py b/llmebench/datasets/dataset_base.py index 85f1cc3e..38b8b169 100644 --- a/llmebench/datasets/dataset_base.py +++ b/llmebench/datasets/dataset_base.py @@ -144,8 +144,11 @@ def decompress(fname, action, pup): List of all extracted files """ - # Remove intermediate tar file + # Default where the downloaded file is not a container/archive + fnames = [fname] + extract_dir = self.__class__.__name__ + if fname.endswith(".tar.xz"): extractor = Decompress(name=fname[:-3]) fname = extractor(fname, action, pup) @@ -239,13 +242,14 @@ def decompress(fname, action, pup): Path(self.data_dir) / f"{self.__class__.__name__}.tar" ) tar_file_path.unlink() - print(f"succeeded") - break + return True except Exception as e: print(f"issue {e}") continue + return False + def _deduplicate_train_test(self, train_data, test_data): """ Filter train data to avoid overlap with test data diff --git a/tests/datasets/test_caching.py b/tests/datasets/test_download_and_caching.py similarity index 82% rename from tests/datasets/test_caching.py rename to tests/datasets/test_download_and_caching.py index 41dabfd3..8d661d17 100644 --- a/tests/datasets/test_caching.py +++ b/tests/datasets/test_download_and_caching.py @@ -55,7 +55,11 @@ def test_auto_download_zip(self): data_dir = TemporaryDirectory() dataset = MockDataset(data_dir=data_dir.name) - dataset.download_dataset(download_url="http://localhost:8076/MockDataset.zip") + self.assertTrue( + dataset.download_dataset( + download_url="http://localhost:8076/MockDataset.zip" + ) + ) downloaded_files = list(Path(data_dir.name).iterdir()) downloaded_filenames = [f.name for f in downloaded_files if f.is_file()] @@ -77,7 +81,11 @@ def test_auto_download_tar(self): data_dir = TemporaryDirectory() dataset = MockDataset(data_dir=data_dir.name) - dataset.download_dataset(download_url="http://localhost:8076/MockDataset.tar") + self.assertTrue( + dataset.download_dataset( + download_url="http://localhost:8076/MockDataset.tar" + ) + ) downloaded_files = list(Path(data_dir.name).iterdir()) downloaded_filenames = [f.name for f in downloaded_files if f.is_file()] @@ -99,8 +107,10 @@ def test_auto_download_tar_gz(self): data_dir = TemporaryDirectory() dataset = MockDataset(data_dir=data_dir.name) - dataset.download_dataset( - download_url="http://localhost:8076/MockDataset.tar.gz" + self.assertTrue( + dataset.download_dataset( + download_url="http://localhost:8076/MockDataset.tar.gz" + ) ) downloaded_files = list(Path(data_dir.name).iterdir()) @@ -129,7 +139,7 @@ def test_auto_download_default_url(self): data_dir = TemporaryDirectory() dataset = MockDataset(data_dir=data_dir.name) - dataset.download_dataset() + self.assertTrue(dataset.download_dataset()) downloaded_files = list(Path(data_dir.name).iterdir()) downloaded_filenames = [f.name for f in downloaded_files if f.is_file()] @@ -144,3 +154,20 @@ def test_auto_download_default_url(self): dataset_files = [f.name for f in extracted_directories[0].iterdir()] self.assertIn("train.txt", dataset_files) self.assertIn("test.txt", dataset_files) + + +class TestDatasetCaching(unittest.TestCase): + def test_cache_existing_file(self): + "Test if an existing file _does not_ trigger a download" + + data_dir = TemporaryDirectory() + archive_file = Path("tests/datasets/archives/MockDataset.zip") + copy_archive_file = Path(data_dir.name) / "MockDataset.zip" + copy_archive_file.write_bytes(archive_file.read_bytes()) + + dataset = MockDataset(data_dir=data_dir.name) + self.assertTrue( + dataset.download_dataset( + 
download_url="http://localhost:8076/ExistingData.zip" + ) + ) From 4077229279592d15a8544015d9268c64cea35573 Mon Sep 17 00:00:00 2001 From: Fahim Imaduddin Dalvi Date: Mon, 11 Sep 2023 13:13:30 +0300 Subject: [PATCH 05/10] Add test for metadata download url usage over environment variable --- tests/datasets/test_download_and_caching.py | 33 +++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/datasets/test_download_and_caching.py b/tests/datasets/test_download_and_caching.py index 8d661d17..ae9f0a5b 100644 --- a/tests/datasets/test_download_and_caching.py +++ b/tests/datasets/test_download_and_caching.py @@ -34,6 +34,11 @@ def load_data(self, data_path): return [self.get_data_sample() for _ in range(100)] +class MockDatasetWithDownloadURL(MockDataset): + def metadata(self): + return {"download_url": "http://localhost:8076/MockDataset.zip"} + + class TestDatasetAutoDownload(unittest.TestCase): @classmethod def setUpClass(cls): @@ -155,6 +160,34 @@ def test_auto_download_default_url(self): self.assertIn("train.txt", dataset_files) self.assertIn("test.txt", dataset_files) + @patch.dict( + "os.environ", + { + "DEFAULT_DOWNLOAD_URL": "http://invalid.llmebench-server.com", + }, + ) + def test_auto_download_metadata_url(self): + "Test automatic downloading when download url is provided in metadata" + + data_dir = TemporaryDirectory() + + dataset = MockDatasetWithDownloadURL(data_dir=data_dir.name) + self.assertTrue(dataset.download_dataset()) + + downloaded_files = list(Path(data_dir.name).iterdir()) + downloaded_filenames = [f.name for f in downloaded_files if f.is_file()] + self.assertEqual(len(downloaded_files), 2) + self.assertIn("MockDatasetWithDownloadURL.zip", downloaded_filenames) + + extracted_directories = [d for d in downloaded_files if d.is_dir()] + extracted_directory_names = [d.name for d in extracted_directories] + self.assertIn("MockDatasetWithDownloadURL", extracted_directory_names) + self.assertEqual(len(extracted_directory_names), 1) + + dataset_files = [f.name for f in extracted_directories[0].iterdir()] + self.assertIn("train.txt", dataset_files) + self.assertIn("test.txt", dataset_files) + class TestDatasetCaching(unittest.TestCase): def test_cache_existing_file(self): From cf25f3b71e5bbaf70b28bf840882a2eac334c9d3 Mon Sep 17 00:00:00 2001 From: Fahim Imaduddin Dalvi Date: Mon, 11 Sep 2023 13:25:17 +0300 Subject: [PATCH 06/10] Generalize port for test server --- tests/datasets/test_download_and_caching.py | 40 +++++++++++++-------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/tests/datasets/test_download_and_caching.py b/tests/datasets/test_download_and_caching.py index ae9f0a5b..2150b2fe 100644 --- a/tests/datasets/test_download_and_caching.py +++ b/tests/datasets/test_download_and_caching.py @@ -35,14 +35,20 @@ def load_data(self, data_path): class MockDatasetWithDownloadURL(MockDataset): + def __init__(self, port, **kwargs): + self.port = port + super(MockDatasetWithDownloadURL, self).__init__(**kwargs) + def metadata(self): - return {"download_url": "http://localhost:8076/MockDataset.zip"} + return {"download_url": f"http://localhost:{self.port}/MockDataset.zip"} class TestDatasetAutoDownload(unittest.TestCase): @classmethod def setUpClass(cls): - cls.httpd = SignalingHTTPServer(("", 8076), ArchiveHandler) + cls.httpd = SignalingHTTPServer(("", 0), ArchiveHandler) + cls.port = cls.httpd.server_address[1] + cls.test_server = threading.Thread(target=cls.httpd.serve_forever, daemon=True) cls.test_server.start() 
cls.httpd.ready_event.wait() @@ -62,7 +68,7 @@ def test_auto_download_zip(self): dataset = MockDataset(data_dir=data_dir.name) self.assertTrue( dataset.download_dataset( - download_url="http://localhost:8076/MockDataset.zip" + download_url=f"http://localhost:{self.port}/MockDataset.zip" ) ) @@ -88,7 +94,7 @@ def test_auto_download_tar(self): dataset = MockDataset(data_dir=data_dir.name) self.assertTrue( dataset.download_dataset( - download_url="http://localhost:8076/MockDataset.tar" + download_url=f"http://localhost:{self.port}/MockDataset.tar" ) ) @@ -114,7 +120,7 @@ def test_auto_download_tar_gz(self): dataset = MockDataset(data_dir=data_dir.name) self.assertTrue( dataset.download_dataset( - download_url="http://localhost:8076/MockDataset.tar.gz" + download_url=f"http://localhost:{self.port}/MockDataset.tar.gz" ) ) @@ -132,19 +138,19 @@ def test_auto_download_tar_gz(self): self.assertIn("train.txt", dataset_files) self.assertIn("test.txt", dataset_files) - @patch.dict( - "os.environ", - { - "DEFAULT_DOWNLOAD_URL": "http://localhost:8076/", - }, - ) def test_auto_download_default_url(self): "Test automatic downloading when download url is not provided" data_dir = TemporaryDirectory() dataset = MockDataset(data_dir=data_dir.name) - self.assertTrue(dataset.download_dataset()) + with patch.dict( + "os.environ", + { + "DEFAULT_DOWNLOAD_URL": f"http://localhost:{self.port}/", + }, + ): + self.assertTrue(dataset.download_dataset()) downloaded_files = list(Path(data_dir.name).iterdir()) downloaded_filenames = [f.name for f in downloaded_files if f.is_file()] @@ -163,7 +169,7 @@ def test_auto_download_default_url(self): @patch.dict( "os.environ", { - "DEFAULT_DOWNLOAD_URL": "http://invalid.llmebench-server.com", + "DEFAULT_DOWNLOAD_URL": "http://invalid.llmebench-server.org", }, ) def test_auto_download_metadata_url(self): @@ -171,7 +177,7 @@ def test_auto_download_metadata_url(self): data_dir = TemporaryDirectory() - dataset = MockDatasetWithDownloadURL(data_dir=data_dir.name) + dataset = MockDatasetWithDownloadURL(data_dir=data_dir.name, port=self.port) self.assertTrue(dataset.download_dataset()) downloaded_files = list(Path(data_dir.name).iterdir()) @@ -194,13 +200,17 @@ def test_cache_existing_file(self): "Test if an existing file _does not_ trigger a download" data_dir = TemporaryDirectory() + + # Copy a archive to the download location archive_file = Path("tests/datasets/archives/MockDataset.zip") copy_archive_file = Path(data_dir.name) / "MockDataset.zip" copy_archive_file.write_bytes(archive_file.read_bytes()) + # download_dataset should not reach out to the invalid server, + # since file is present locally dataset = MockDataset(data_dir=data_dir.name) self.assertTrue( dataset.download_dataset( - download_url="http://localhost:8076/ExistingData.zip" + download_url="http://invalid.llmebench-server.org/ExistingData.zip" ) ) From 3120394d41eb8cab3114d541d1ce4263fdae13d3 Mon Sep 17 00:00:00 2001 From: Fahim Imaduddin Dalvi Date: Mon, 11 Sep 2023 13:39:16 +0300 Subject: [PATCH 07/10] Add tests for tar.bz2 and tar.xz files --- tests/datasets/archives/MockDataset.tar.bz2 | Bin 0 -> 869 bytes tests/datasets/archives/MockDataset.tar.gz | Bin 677 -> 677 bytes tests/datasets/archives/MockDataset.tar.xz | Bin 0 -> 660 bytes tests/datasets/test_download_and_caching.py | 109 +++++++++----------- 4 files changed, 46 insertions(+), 63 deletions(-) create mode 100644 tests/datasets/archives/MockDataset.tar.bz2 create mode 100644 tests/datasets/archives/MockDataset.tar.xz diff --git 
a/tests/datasets/archives/MockDataset.tar.bz2 b/tests/datasets/archives/MockDataset.tar.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..a1e0ff88f73afe5221c0bc66886a72c2db53602b GIT binary patch literal 869 zcmV-r1DgCoT4*^jL0KkKSp`d?FaQBHfB*lu%aBULm+9O00STcATj_P01XC)gFptE28KY>KmZ0n20&y03Y3Uxpp6qI z6H^*wj}+d5PejS5lo}c|#M+t+38n-#r!?)I*Us!No@EZ=cay;Ho+SPoCgYW^l+ZF$x8m#3+%D$;5X++^rn<_-Oq=@J!=_o?XTpSuT-Dl{D{3T|R zP)(A9kwEFFW~e@O;!(7zTGdTeZ(6ohXK9+q!sa{AVPo1+8R{gwoYbOPOw{l$pF&=C znH)`56ZTfT!roZGbCcG4jU((c3k=6Cya;QP3K@6lN>lD@6j4W=d8lF&n8wIf6B#(t ziQeyizpp>Vc`uXlZM&t!%}mXsB5XXY0mM+q!l^eCwCpIhOOmHftPE*W9jlF_BO@gj z63UXyC!t#H#>3vOO|rzkd0Orkh74+w#J;tKFeRUPbuXplsIe-#cp476!m1~6fn~-j zFg}I?wwpostz2ZIG{b>)IE%?@>jQZ$X_|{9wJ^z^TCBvOp0!blVR4M8si0;Wj0RJn zX}<+;nMFz}WH%TJhU>=4tu)iV^}B~uwDdn@+G`&xWX?{KX*`a*IgGTm!r@Jmpj78O z?;7*;HQQv(&zbXQzoTr@cH!*66pI?n4U! vEltN6MCiT;!1_)yi2n*a=PCahrN`+1ij5P@|&fr literal 0 HcmV?d00001 diff --git a/tests/datasets/archives/MockDataset.tar.gz b/tests/datasets/archives/MockDataset.tar.gz index be9e3687fbe7dd64f685c0b3960e6d3eb8bfcfc7..cf1c3bf2ccb47cd077776b0fa9969c6d935e07a7 100644 GIT binary patch delta 15 WcmZ3=x|EeozMF&L;p>fT^O*oAD+MtC delta 15 WcmZ3=x|EeozMF#~`tnA$`Ah&Kj|7td diff --git a/tests/datasets/archives/MockDataset.tar.xz b/tests/datasets/archives/MockDataset.tar.xz new file mode 100644 index 0000000000000000000000000000000000000000..9f82e862cbf6bedbc07e51eb78f9774fe8acef90 GIT binary patch literal 660 zcmV;F0&D&KH+ooF000E$*0e?f03iVu0001VFXf})AO8YTT>uvs%7-H=CJnG%n-EWb z(NaQRaL(WN&!r!NdvG!$*L#lDXQB6)!;fY%Yk5U>?pkWC*LMxFZE zgJx$u;gH>xIQFxb8%dD}?T`!j02KT=!fAZZYnrW|4T6lC9K>Bn5Vn78BmlqkY09M4 z<51wirM=WxBX&`;Xn8>kgk`KuKK!egR=}jtZ(e`I*9ScgwNvf_c5+`V;XRPWCiME- zKB3~P!RGk5vs#e*8`V|sAv@KTju+MZcTNbyO`Z2=HIO?_?8#dRz0mah6ZLs2vI;KZ zDhfVVcwa&kVAqU%{v{n$X<4Z0zLHWw3u1H0=&HNLy~0YIZL~wxkK5ss69QzDWm>g> zsS7lN(FsrXP10f24a6;2EFPK(p>}-kCp$8H6AO11j?jBtbCJ_F0OJT?-%=3bm;5R? 
zZ5$KhTNJmIYQJ@pQka;;FT5s|K4rSWt+ef`s>##sJ}a;{aSnJDlOh zza3?pT%x9Pb)Fg7S&(>#lzcjN1iD6%$TpW1-@}Wkag*rw&1K0}-B6UycAklW?ET=< zkb1-YZ|IMnYXrd!GUHYJiytzy(=3P^(;#4Wo(e>?ELpG;#1v~&>}oUBY>-vAh3AUj zo`27t+TYh`&5l6QNO^({16b8I&{&NOlW@fVl(qm<5kU?-{;7pUr3c?K?eahz`*pDd u-GB1@z&Zc`000073l#DJ+aGBF0qq2UKmY)`LCgWM#Ao{g000001X)_-N;dic literal 0 HcmV?d00001 diff --git a/tests/datasets/test_download_and_caching.py b/tests/datasets/test_download_and_caching.py index 2150b2fe..dde20f19 100644 --- a/tests/datasets/test_download_and_caching.py +++ b/tests/datasets/test_download_and_caching.py @@ -60,6 +60,21 @@ def tearDownClass(cls): cls.httpd.server_close() cls.test_server.join() + def check_downloaded(self, data_dir, dataset_name, extension): + downloaded_files = list(data_dir.iterdir()) + downloaded_filenames = [f.name for f in downloaded_files if f.is_file()] + self.assertEqual(len(downloaded_files), 2) + self.assertIn(f"{dataset_name}.{extension}", downloaded_filenames) + + extracted_directories = [d for d in downloaded_files if d.is_dir()] + extracted_directory_names = [d.name for d in extracted_directories] + self.assertIn(f"{dataset_name}", extracted_directory_names) + self.assertEqual(len(extracted_directory_names), 1) + + dataset_files = [f.name for f in extracted_directories[0].iterdir()] + self.assertIn("train.txt", dataset_files) + self.assertIn("test.txt", dataset_files) + def test_auto_download_zip(self): "Test automatic downloading and extraction of *.zip datasets" @@ -72,19 +87,7 @@ def test_auto_download_zip(self): ) ) - downloaded_files = list(Path(data_dir.name).iterdir()) - downloaded_filenames = [f.name for f in downloaded_files if f.is_file()] - self.assertEqual(len(downloaded_files), 2) - self.assertIn("MockDataset.zip", downloaded_filenames) - - extracted_directories = [d for d in downloaded_files if d.is_dir()] - extracted_directory_names = [d.name for d in extracted_directories] - self.assertIn("MockDataset", extracted_directory_names) - self.assertEqual(len(extracted_directory_names), 1) - - dataset_files = [f.name for f in extracted_directories[0].iterdir()] - self.assertIn("train.txt", dataset_files) - self.assertIn("test.txt", dataset_files) + self.check_downloaded(Path(data_dir.name), "MockDataset", "zip") def test_auto_download_tar(self): "Test automatic downloading and extraction of *.tar datasets" @@ -98,19 +101,7 @@ def test_auto_download_tar(self): ) ) - downloaded_files = list(Path(data_dir.name).iterdir()) - downloaded_filenames = [f.name for f in downloaded_files if f.is_file()] - self.assertEqual(len(downloaded_files), 2) - self.assertIn("MockDataset.tar", downloaded_filenames) - - extracted_directories = [d for d in downloaded_files if d.is_dir()] - extracted_directory_names = [d.name for d in extracted_directories] - self.assertIn("MockDataset", extracted_directory_names) - self.assertEqual(len(extracted_directory_names), 1) - - dataset_files = [f.name for f in extracted_directories[0].iterdir()] - self.assertIn("train.txt", dataset_files) - self.assertIn("test.txt", dataset_files) + self.check_downloaded(Path(data_dir.name), "MockDataset", "tar") def test_auto_download_tar_gz(self): "Test automatic downloading and extraction of *.tar.gz datasets" @@ -124,19 +115,35 @@ def test_auto_download_tar_gz(self): ) ) - downloaded_files = list(Path(data_dir.name).iterdir()) - self.assertEqual(len(downloaded_files), 2) - downloaded_filenames = [f.name for f in downloaded_files if f.is_file()] - self.assertIn("MockDataset.tar.gz", downloaded_filenames) + 
self.check_downloaded(Path(data_dir.name), "MockDataset", "tar.gz") - extracted_directories = [d for d in downloaded_files if d.is_dir()] - extracted_directory_names = [d.name for d in extracted_directories] - self.assertIn("MockDataset", extracted_directory_names) - self.assertEqual(len(extracted_directory_names), 1) + def test_auto_download_tar_bz2(self): + "Test automatic downloading and extraction of *.tar.bz2 datasets" - dataset_files = [f.name for f in extracted_directories[0].iterdir()] - self.assertIn("train.txt", dataset_files) - self.assertIn("test.txt", dataset_files) + data_dir = TemporaryDirectory() + + dataset = MockDataset(data_dir=data_dir.name) + self.assertTrue( + dataset.download_dataset( + download_url=f"http://localhost:{self.port}/MockDataset.tar.bz2" + ) + ) + + self.check_downloaded(Path(data_dir.name), "MockDataset", "tar.bz2") + + def test_auto_download_tar_xz(self): + "Test automatic downloading and extraction of *.tar.xz datasets" + + data_dir = TemporaryDirectory() + + dataset = MockDataset(data_dir=data_dir.name) + self.assertTrue( + dataset.download_dataset( + download_url=f"http://localhost:{self.port}/MockDataset.tar.xz" + ) + ) + + self.check_downloaded(Path(data_dir.name), "MockDataset", "tar.xz") def test_auto_download_default_url(self): "Test automatic downloading when download url is not provided" @@ -152,19 +159,7 @@ def test_auto_download_default_url(self): ): self.assertTrue(dataset.download_dataset()) - downloaded_files = list(Path(data_dir.name).iterdir()) - downloaded_filenames = [f.name for f in downloaded_files if f.is_file()] - self.assertEqual(len(downloaded_files), 2) - self.assertIn("MockDataset.zip", downloaded_filenames) - - extracted_directories = [d for d in downloaded_files if d.is_dir()] - extracted_directory_names = [d.name for d in extracted_directories] - self.assertIn("MockDataset", extracted_directory_names) - self.assertEqual(len(extracted_directory_names), 1) - - dataset_files = [f.name for f in extracted_directories[0].iterdir()] - self.assertIn("train.txt", dataset_files) - self.assertIn("test.txt", dataset_files) + self.check_downloaded(Path(data_dir.name), "MockDataset", "zip") @patch.dict( "os.environ", @@ -180,19 +175,7 @@ def test_auto_download_metadata_url(self): dataset = MockDatasetWithDownloadURL(data_dir=data_dir.name, port=self.port) self.assertTrue(dataset.download_dataset()) - downloaded_files = list(Path(data_dir.name).iterdir()) - downloaded_filenames = [f.name for f in downloaded_files if f.is_file()] - self.assertEqual(len(downloaded_files), 2) - self.assertIn("MockDatasetWithDownloadURL.zip", downloaded_filenames) - - extracted_directories = [d for d in downloaded_files if d.is_dir()] - extracted_directory_names = [d.name for d in extracted_directories] - self.assertIn("MockDatasetWithDownloadURL", extracted_directory_names) - self.assertEqual(len(extracted_directory_names), 1) - - dataset_files = [f.name for f in extracted_directories[0].iterdir()] - self.assertIn("train.txt", dataset_files) - self.assertIn("test.txt", dataset_files) + self.check_downloaded(Path(data_dir.name), "MockDatasetWithDownloadURL", "zip") class TestDatasetCaching(unittest.TestCase): From 192de4dd5af95a1beeaf7c6def8b034319822c50 Mon Sep 17 00:00:00 2001 From: Fahim Imaduddin Dalvi Date: Mon, 11 Sep 2023 13:48:18 +0300 Subject: [PATCH 08/10] Add test for non-existent datasets --- llmebench/datasets/dataset_base.py | 7 +++--- tests/datasets/test_download_and_caching.py | 25 +++++++++++++++++++-- 2 files changed, 27 insertions(+), 5 
deletions(-) diff --git a/llmebench/datasets/dataset_base.py b/llmebench/datasets/dataset_base.py index 38b8b169..c24f16ca 100644 --- a/llmebench/datasets/dataset_base.py +++ b/llmebench/datasets/dataset_base.py @@ -226,7 +226,7 @@ def decompress(fname, action, pup): extension = ext break try: - print(f"trying {download_url}") + logging.info(f"Trying {download_url}") retrieve( download_url, known_hash=None, @@ -244,10 +244,11 @@ def decompress(fname, action, pup): tar_file_path.unlink() return True except Exception as e: - print(f"issue {e}") - + logging.warning(f"Failed to download: {e}") continue + logging.warning(f"Failed to download dataset") + return False def _deduplicate_train_test(self, train_data, test_data): diff --git a/tests/datasets/test_download_and_caching.py b/tests/datasets/test_download_and_caching.py index dde20f19..3a9e471e 100644 --- a/tests/datasets/test_download_and_caching.py +++ b/tests/datasets/test_download_and_caching.py @@ -35,12 +35,13 @@ def load_data(self, data_path): class MockDatasetWithDownloadURL(MockDataset): - def __init__(self, port, **kwargs): + def __init__(self, port, filename="MockDataset.zip", **kwargs): self.port = port + self.filename = filename super(MockDatasetWithDownloadURL, self).__init__(**kwargs) def metadata(self): - return {"download_url": f"http://localhost:{self.port}/MockDataset.zip"} + return {"download_url": f"http://localhost:{self.port}/{self.filename}"} class TestDatasetAutoDownload(unittest.TestCase): @@ -177,6 +178,26 @@ def test_auto_download_metadata_url(self): self.check_downloaded(Path(data_dir.name), "MockDatasetWithDownloadURL", "zip") + @patch.dict( + "os.environ", + { + "DEFAULT_DOWNLOAD_URL": "http://invalid.llmebench-server.org", + }, + ) + def test_auto_download_non_existent(self): + "Test automatic downloading when dataset is not actually available" + + data_dir = TemporaryDirectory() + + dataset = MockDatasetWithDownloadURL( + data_dir=data_dir.name, port=self.port, filename="InvalidDataset.zip" + ) + self.assertFalse( + dataset.download_dataset( + download_url="http://invalid.llmebench-server.org/Dataset.zip" + ) + ) + class TestDatasetCaching(unittest.TestCase): def test_cache_existing_file(self): From f46908fd54f366d1440c8754b2be57671063820b Mon Sep 17 00:00:00 2001 From: Fahim Imaduddin Dalvi Date: Mon, 11 Sep 2023 13:56:14 +0300 Subject: [PATCH 09/10] Clean up code and add docstrings --- llmebench/datasets/dataset_base.py | 278 +++++++++++++++-------------- 1 file changed, 148 insertions(+), 130 deletions(-) diff --git a/llmebench/datasets/dataset_base.py b/llmebench/datasets/dataset_base.py index c24f16ca..adc92b19 100644 --- a/llmebench/datasets/dataset_base.py +++ b/llmebench/datasets/dataset_base.py @@ -121,136 +121,6 @@ def load_data(self, data_path, no_labels=False): """ pass - def download_dataset(self, download_url=None): - def decompress(fname, action, pup): - """ - Post-processing hook to automatically detect the type of archive and - call the correct processor (UnZip, Untar, Decompress) - - Parameters - ---------- - fname : str - Full path of the zipped file in local storage - action : str - One of "download" (file doesn't exist and will download), - "update" (file is outdated and will download), and - "fetch" (file exists and is updated so no download). - pup : Pooch - The instance of Pooch that called the processor function. 
- - Returns - ------- - fnames : list - List of all extracted files - - """ - # Default where the downloaded file is not a container/archive - fnames = [fname] - - extract_dir = self.__class__.__name__ - - if fname.endswith(".tar.xz"): - extractor = Decompress(name=fname[:-3]) - fname = extractor(fname, action, pup) - - extractor = Untar(extract_dir=extract_dir) - fnames = extractor(fname, action, pup) - elif fname.endswith(".tar.bz2"): - extractor = Decompress(name=fname[:-4]) - fname = extractor(fname, action, pup) - - extractor = Untar(extract_dir=extract_dir) - fnames = extractor(fname, action, pup) - elif fname.endswith(".tar.gz"): - extractor = Decompress(name=fname[:-3]) - fname = extractor(fname, action, pup) - - extractor = Untar(extract_dir=extract_dir) - fnames = extractor(fname, action, pup) - elif fname.endswith(".xz"): - extractor = Decompress(name=fname[:-3]) - fname = extractor(fname, action, pup) - fnames = [fname] - elif fname.endswith(".bz2"): - extractor = Decompress(name=fname[:-4]) - fname = extractor(fname, action, pup) - fnames = [fname] - elif fname.endswith(".gz"): - extractor = Decompress(name=fname[:-3]) - fname = extractor(fname, action, pup) - fnames = [fname] - elif fname.endswith(".tar"): - extractor = Untar(extract_dir=extract_dir) - fnames = extractor(fname, action, pup) - elif fname.endswith(".zip"): - extractor = Unzip(extract_dir=extract_dir) - fnames = extractor(fname, action, pup) - - return fnames - - # Priority: - # Fn Argument - # Dataset metadata["download_url"] - # BASE_DOWNLOAD_URL/Dataset_name - download_urls = [] - if download_url is not None: - download_urls.append(download_url) - - metadata_url = self.metadata().get("download_url", None) - if metadata_url is not None: - download_urls.append(metadata_url) - - default_url = os.getenv("DEFAULT_DOWNLOAD_URL") - if default_url is not None: - if default_url.endswith("/"): - default_url = default_url[:-1] - default_url = f"{default_url}/{self.__class__.__name__}.zip" - download_urls.append(default_url) - - # Try downloading from available links in order of priority - for download_url in download_urls: - extension = ".zip" - supported_extensions = [ - ".tar.xz", - ".tar.bz2", - ".tar.gz", - ".xz", - ".bz2", - ".gz", - ".tar", - ".zip", - ] - - for ext in supported_extensions: - if download_url.endswith(ext): - extension = ext - break - try: - logging.info(f"Trying {download_url}") - retrieve( - download_url, - known_hash=None, - fname=f"{self.__class__.__name__}{extension}", - path=self.data_dir, - progressbar=True, - processor=decompress, - ) - # If it was a *.tar.* file, we can safely delete the - # intermediate *.tar file - if extension in supported_extensions[:3]: - tar_file_path = ( - Path(self.data_dir) / f"{self.__class__.__name__}.tar" - ) - tar_file_path.unlink() - return True - except Exception as e: - logging.warning(f"Failed to download: {e}") - continue - - logging.warning(f"Failed to download dataset") - - return False - def _deduplicate_train_test(self, train_data, test_data): """ Filter train data to avoid overlap with test data @@ -417,3 +287,151 @@ def prepare_fewshots(self, target_data, train_data, n_shots, deduplicate=True): examples = [self._destringify_sample(sample) for sample in examples] yield examples + + def download_dataset(self, download_url=None): + """ + Utility method to download a dataset if not present locally on disk. + Can handle datasets of types *.zip, *.tar, *.tar.gz, *.tar.bz2, *.tar.xz. + + Arguments + --------- + download_url : str + The url to the dataset. 
If not provided, falls back to the `download_url` + provided by the Dataset's metadata. If missing, falls back to a default + server specified by the environment variable `DEFAULT_DOWNLOAD_URL` + + Returns + ------- + download_succeeded : bool + Returns True if the dataset is already present on disk, or if download + + extraction was successful. + """ + + def decompress(fname, action, pup): + """ + Post-processing hook to automatically detect the type of archive and + call the correct processor (UnZip, Untar, Decompress) + + Arguments + --------- + fname : str + Full path of the zipped file in local storage + action : str + One of "download" (file doesn't exist and will download), + "update" (file is outdated and will download), and + "fetch" (file exists and is updated so no download). + pup : Pooch + The instance of Pooch that called the processor function. + + Returns + ------- + fnames : list + List of all extracted files + + """ + # Default where the downloaded file is not a container/archive + fnames = [fname] + + extract_dir = self.__class__.__name__ + + if fname.endswith(".tar.xz"): + extractor = Decompress(name=fname[:-3]) + fname = extractor(fname, action, pup) + + extractor = Untar(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) + elif fname.endswith(".tar.bz2"): + extractor = Decompress(name=fname[:-4]) + fname = extractor(fname, action, pup) + + extractor = Untar(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) + elif fname.endswith(".tar.gz"): + extractor = Decompress(name=fname[:-3]) + fname = extractor(fname, action, pup) + + extractor = Untar(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) + elif fname.endswith(".xz"): + extractor = Decompress(name=fname[:-3]) + fname = extractor(fname, action, pup) + fnames = [fname] + elif fname.endswith(".bz2"): + extractor = Decompress(name=fname[:-4]) + fname = extractor(fname, action, pup) + fnames = [fname] + elif fname.endswith(".gz"): + extractor = Decompress(name=fname[:-3]) + fname = extractor(fname, action, pup) + fnames = [fname] + elif fname.endswith(".tar"): + extractor = Untar(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) + elif fname.endswith(".zip"): + extractor = Unzip(extract_dir=extract_dir) + fnames = extractor(fname, action, pup) + + return fnames + + # Priority: + # Fn Argument + # Dataset metadata["download_url"] + # DEFAULT_DOWNLOAD_URL/Dataset_name.zip + download_urls = [] + if download_url is not None: + download_urls.append(download_url) + + metadata_url = self.metadata().get("download_url", None) + if metadata_url is not None: + download_urls.append(metadata_url) + + default_url = os.getenv("DEFAULT_DOWNLOAD_URL") + if default_url is not None: + if default_url.endswith("/"): + default_url = default_url[:-1] + default_url = f"{default_url}/{self.__class__.__name__}.zip" + download_urls.append(default_url) + + # Try downloading from available links in order of priority + for download_url in download_urls: + extension = ".zip" + supported_extensions = [ + ".tar.xz", + ".tar.bz2", + ".tar.gz", + ".xz", + ".bz2", + ".gz", + ".tar", + ".zip", + ] + + for ext in supported_extensions: + if download_url.endswith(ext): + extension = ext + break + try: + logging.info(f"Trying {download_url}") + retrieve( + download_url, + known_hash=None, + fname=f"{self.__class__.__name__}{extension}", + path=self.data_dir, + progressbar=True, + processor=decompress, + ) + # If it was a *.tar.* file, we can safely delete the + # intermediate *.tar file + if 
extension in supported_extensions[:3]: + tar_file_path = ( + Path(self.data_dir) / f"{self.__class__.__name__}.tar" + ) + tar_file_path.unlink() + return True + except Exception as e: + logging.warning(f"Failed to download: {e}") + continue + + logging.warning(f"Failed to download dataset") + + return False From c96fbe3afc521b3772408b0727b8d8ae36e01d2c Mon Sep 17 00:00:00 2001 From: Fahim Imaduddin Dalvi Date: Mon, 11 Sep 2023 14:05:07 +0300 Subject: [PATCH 10/10] Fix incorrect param handling in dataset_base init --- llmebench/datasets/dataset_base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/llmebench/datasets/dataset_base.py b/llmebench/datasets/dataset_base.py index adc92b19..a95ff52e 100644 --- a/llmebench/datasets/dataset_base.py +++ b/llmebench/datasets/dataset_base.py @@ -24,7 +24,9 @@ class DatasetBase(ABC): Attributes ---------- - None + data_dir : str + Base path of data containing all datasets. Defaults to "data" in the current + working directory. Methods ------- @@ -50,8 +52,8 @@ class DatasetBase(ABC): """ - def __init__(self, **kwargs): - self.data_dir = kwargs.get("data_dir", "data") + def __init__(self, data_dir="data", **kwargs): + self.data_dir = data_dir @abstractmethod def metadata(self):
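
Usage note: with the full series applied, a concrete dataset picks up auto-download support without any extra code of its own. Below is a minimal sketch of the resulting call flow; SentimentDataset, the example.org URL, and the train/test file names are illustrative assumptions and not part of the series itself.

    import os
    from pathlib import Path

    from llmebench.datasets.dataset_base import DatasetBase


    class SentimentDataset(DatasetBase):
        """Hypothetical dataset, defined only to illustrate auto-download."""

        def metadata(self):
            # Consulted after an explicit download_dataset(download_url=...)
            # argument; DEFAULT_DOWNLOAD_URL/SentimentDataset.zip is the
            # final fallback.
            return {"download_url": "http://example.org/SentimentDataset.zip"}

        def get_data_sample(self):
            return {"input": "some text", "label": "positive"}

        def load_data(self, data_path, no_labels=False):
            with open(data_path) as fp:
                return [{"input": line.strip(), "label": None} for line in fp]


    if __name__ == "__main__":
        dataset = SentimentDataset(data_dir="data")

        # Returns True if the archive is already cached under data/ (pooch
        # skips the network, as exercised by test_cache_existing_file) or if
        # any candidate URL succeeds; archives are extracted into
        # data/SentimentDataset/
        if dataset.download_dataset():
            data_path = Path(dataset.data_dir) / "SentimentDataset" / "test.txt"
            samples = dataset.load_data(data_path)

The candidate URLs are tried strictly in priority order (explicit argument, then metadata["download_url"], then DEFAULT_DOWNLOAD_URL/<ClassName>.zip), so a dataset that ships a metadata URL only falls through to the default server if that URL fails.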