Implement download and caching utility for datasets (#214)
This commit adds support for downloading and caching datasets for future use. It supports a variety of archive formats, including zip, tar, tar.gz, and more. The utility method is not yet called automatically.
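
As a rough usage sketch (the subclass name, URL, and metadata value below are invented for illustration, and a real subclass would also implement the remaining abstract methods):

from llmebench.datasets.dataset_base import DatasetBase

class MyDataset(DatasetBase):  # hypothetical; real dataset classes implement more methods
    def metadata(self):
        return {"download_url": "https://example.com/archives/MyDataset.tar.gz"}

dataset = MyDataset(data_dir="data")

# Tries the metadata URL first, then DEFAULT_DOWNLOAD_URL/MyDataset.zip as a fallback
if dataset.download_dataset():
    print("Dataset extracted under data/MyDataset/")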

* Initial auto-download implementation

* Improve auto-downloader implementation

* Add tests for auto downloader

* Add tests for caching mechanism

* Add test for metadata download url usage over environment variable

* Generalize port for test server

* Add tests for tar.bz2 and tar.xz files

* Add test for non-existent datasets

* Clean up code and add docstrings

* Fix incorrect param handling in dataset_base init
fdalvi authored Sep 11, 2023
1 parent e4393d0 commit 5a7906b
Showing 9 changed files with 380 additions and 6 deletions.
163 changes: 158 additions & 5 deletions llmebench/datasets/dataset_base.py
@@ -1,13 +1,17 @@
import json
import logging
import os
import random

from abc import ABC, abstractmethod
from pathlib import Path

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts.example_selector import MaxMarginalRelevanceExampleSelector
from langchain.vectorstores import FAISS

from pooch import Decompress, Pooch, retrieve, Untar, Unzip


class DatasetBase(ABC):
"""
@@ -20,7 +24,9 @@ class DatasetBase(ABC):
    Attributes
    ----------
-    None
+    data_dir : str
+        Base path of data containing all datasets. Defaults to "data" in the current
+        working directory.

    Methods
    -------
@@ -46,12 +52,11 @@ class DatasetBase(ABC):
"""

-    def __init__(self, **kwargs):
-        pass
+    def __init__(self, data_dir="data", **kwargs):
+        self.data_dir = data_dir

-    @staticmethod
    @abstractmethod
-    def metadata():
+    def metadata(self):
        """
        Returns the dataset's metadata
@@ -284,3 +289,151 @@ def prepare_fewshots(self, target_data, train_data, n_shots, deduplicate=True):
            examples = [self._destringify_sample(sample) for sample in examples]

            yield examples

    def download_dataset(self, download_url=None):
        """
        Utility method to download a dataset if it is not present locally on disk.
        Can handle archives of type *.zip, *.tar, *.tar.gz, *.tar.bz2 and *.tar.xz,
        as well as singly compressed files of type *.gz, *.bz2 and *.xz.

        Arguments
        ---------
        download_url : str
            The URL to the dataset. If not provided, falls back to the `download_url`
            provided by the Dataset's metadata. If that is missing as well, falls
            back to a default server specified by the environment variable
            `DEFAULT_DOWNLOAD_URL`.

        Returns
        -------
        download_succeeded : bool
            Returns True if the dataset is already present on disk, or if download +
            extraction was successful.
        """

        def decompress(fname, action, pup):
            """
            Post-processing hook to automatically detect the type of archive and
            call the correct processor (Unzip, Untar, Decompress).

            Arguments
            ---------
            fname : str
                Full path of the downloaded file in local storage
            action : str
                One of "download" (file doesn't exist and will download),
                "update" (file is outdated and will download), and
                "fetch" (file exists and is updated so no download).
            pup : Pooch
                The instance of Pooch that called the processor function.

            Returns
            -------
            fnames : list
                List of all extracted files
            """
            # Default case, where the downloaded file is not a container/archive
            fnames = [fname]

            extract_dir = self.__class__.__name__

            if fname.endswith(".tar.xz"):
                extractor = Decompress(name=fname[:-3])
                fname = extractor(fname, action, pup)

                extractor = Untar(extract_dir=extract_dir)
                fnames = extractor(fname, action, pup)
            elif fname.endswith(".tar.bz2"):
                extractor = Decompress(name=fname[:-4])
                fname = extractor(fname, action, pup)

                extractor = Untar(extract_dir=extract_dir)
                fnames = extractor(fname, action, pup)
            elif fname.endswith(".tar.gz"):
                extractor = Decompress(name=fname[:-3])
                fname = extractor(fname, action, pup)

                extractor = Untar(extract_dir=extract_dir)
                fnames = extractor(fname, action, pup)
            elif fname.endswith(".xz"):
                extractor = Decompress(name=fname[:-3])
                fname = extractor(fname, action, pup)
                fnames = [fname]
            elif fname.endswith(".bz2"):
                extractor = Decompress(name=fname[:-4])
                fname = extractor(fname, action, pup)
                fnames = [fname]
            elif fname.endswith(".gz"):
                extractor = Decompress(name=fname[:-3])
                fname = extractor(fname, action, pup)
                fnames = [fname]
            elif fname.endswith(".tar"):
                extractor = Untar(extract_dir=extract_dir)
                fnames = extractor(fname, action, pup)
            elif fname.endswith(".zip"):
                extractor = Unzip(extract_dir=extract_dir)
                fnames = extractor(fname, action, pup)

            return fnames

        # Priority of candidate URLs:
        #   1. `download_url` function argument
        #   2. Dataset metadata["download_url"]
        #   3. DEFAULT_DOWNLOAD_URL/<DatasetName>.zip
        download_urls = []
        if download_url is not None:
            download_urls.append(download_url)

        metadata_url = self.metadata().get("download_url", None)
        if metadata_url is not None:
            download_urls.append(metadata_url)

        default_url = os.getenv("DEFAULT_DOWNLOAD_URL")
        if default_url is not None:
            if default_url.endswith("/"):
                default_url = default_url[:-1]
            default_url = f"{default_url}/{self.__class__.__name__}.zip"
            download_urls.append(default_url)

        # Try downloading from the available links in order of priority
        for download_url in download_urls:
            extension = ".zip"
            supported_extensions = [
                ".tar.xz",
                ".tar.bz2",
                ".tar.gz",
                ".xz",
                ".bz2",
                ".gz",
                ".tar",
                ".zip",
            ]

            for ext in supported_extensions:
                if download_url.endswith(ext):
                    extension = ext
                    break

            try:
                logging.info(f"Trying {download_url}")
                retrieve(
                    download_url,
                    known_hash=None,
                    fname=f"{self.__class__.__name__}{extension}",
                    path=self.data_dir,
                    progressbar=True,
                    processor=decompress,
                )
                # If it was a *.tar.* file, we can safely delete the
                # intermediate *.tar file
                if extension in supported_extensions[:3]:
                    tar_file_path = (
                        Path(self.data_dir) / f"{self.__class__.__name__}.tar"
                    )
                    tar_file_path.unlink()
                return True
            except Exception as e:
                logging.warning(f"Failed to download: {e}")
                continue

        logging.warning("Failed to download dataset")

        return False
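
A note on the order of `supported_extensions` above: since the lookup uses a plain `endswith()` check, the compound suffixes must be tested before their single-extension counterparts. A small self-contained illustration (the URL is made up):

# Compound suffixes come first, otherwise ".xz" would shadow ".tar.xz".
supported_extensions = [".tar.xz", ".tar.bz2", ".tar.gz", ".xz", ".bz2", ".gz", ".tar", ".zip"]
url = "https://example.com/MyDataset.tar.xz"
extension = next((ext for ext in supported_extensions if url.endswith(ext)), ".zip")
assert extension == ".tar.xz"  # with ".xz" listed first, this would wrongly be ".xz"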
1 change: 1 addition & 0 deletions setup.cfg
@@ -23,6 +23,7 @@ install_requires =
    nltk==3.8.1
    openai==0.27.7
    pandas==2.0.2
    pooch==1.7.0
    python-dotenv==1.0.0
    scikit-learn==1.2.2
    tenacity==8.2.2
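
The new `pooch` dependency provides the download-and-cache machinery used in `dataset_base.py`. A minimal standalone sketch of the core call (the URL is illustrative):

from pooch import retrieve

# Downloads to data/MyDataset.zip if not already cached, and returns the local path;
# known_hash=None skips checksum verification, matching the usage in dataset_base.py.
local_path = retrieve(
    "https://example.com/MyDataset.zip",
    known_hash=None,
    fname="MyDataset.zip",
    path="data",
    progressbar=True,
)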
Binary file added tests/datasets/archives/MockDataset.tar
Binary file added tests/datasets/archives/MockDataset.tar.bz2
Binary file added tests/datasets/archives/MockDataset.tar.gz
Binary file added tests/datasets/archives/MockDataset.tar.xz
Binary file added tests/datasets/archives/MockDataset.zip
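
The commit messages above mention generalizing the port of the test server used to serve these archives. One way such a fixture might look (a sketch of the idea, not the actual test code from this commit):

import http.server
import socketserver
import threading
from functools import partial

def serve_archives(directory="tests/datasets/archives"):
    # Port 0 lets the OS pick any free port, so tests never hard-code one.
    handler = partial(http.server.SimpleHTTPRequestHandler, directory=directory)
    server = socketserver.TCPServer(("localhost", 0), handler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    port = server.server_address[1]
    return server, f"http://localhost:{port}"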