Skip to content

Commit

Permalink
Clean up code and add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
fdalvi committed Sep 11, 2023
1 parent 192de4d commit f46908f
Showing 1 changed file with 148 additions and 130 deletions.
278 changes: 148 additions & 130 deletions llmebench/datasets/dataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,136 +121,6 @@ def load_data(self, data_path, no_labels=False):
"""
pass

def download_dataset(self, download_url=None):
    """
    Download this dataset's archive into `self.data_dir` and unpack it.

    Candidate URLs are tried in the following order until one succeeds:
    the `download_url` argument, the `download_url` entry of
    `self.metadata()`, and finally `<DEFAULT_DOWNLOAD_URL>/<ClassName>.zip`,
    where `DEFAULT_DOWNLOAD_URL` is read from the environment.

    Parameters
    ----------
    download_url : str, optional
        Preferred URL of the dataset archive. URLs that do not end in one of
        the supported extensions are treated as `.zip`.

    Returns
    -------
    download_succeeded : bool
        True as soon as one candidate URL has been retrieved and
        post-processed without raising; False if every candidate failed or
        no candidate URL was available.
    """

    def decompress(fname, action, pup):
        """
        Post-processing hook to automatically detect the type of archive and
        call the correct processor (UnZip, Untar, Decompress)
        Parameters
        ----------
        fname : str
            Full path of the zipped file in local storage
        action : str
            One of "download" (file doesn't exist and will download),
            "update" (file is outdated and will download), and
            "fetch" (file exists and is updated so no download).
        pup : Pooch
            The instance of Pooch that called the processor function.
        Returns
        -------
        fnames : list
            List of all extracted files
        """
        extract_dir = self.__class__.__name__

        if fname.endswith((".tar.xz", ".tar.bz2", ".tar.gz")):
            # Compressed tarball: drop the compression suffix to name the
            # intermediate *.tar, then unpack that tarball
            extractor = Decompress(name=fname[: fname.rindex(".")])
            fname = extractor(fname, action, pup)

            extractor = Untar(extract_dir=extract_dir)
            return extractor(fname, action, pup)

        if fname.endswith((".xz", ".bz2", ".gz")):
            # Single compressed file: the result is the one decompressed file
            extractor = Decompress(name=fname[: fname.rindex(".")])
            return [extractor(fname, action, pup)]

        if fname.endswith(".tar"):
            extractor = Untar(extract_dir=extract_dir)
            return extractor(fname, action, pup)

        if fname.endswith(".zip"):
            extractor = Unzip(extract_dir=extract_dir)
            return extractor(fname, action, pup)

        # Default where the downloaded file is not a container/archive
        return [fname]

    # Ordered most-specific first so "*.tar.gz" is not mistaken for "*.gz".
    # The first three entries are the compressed-tarball extensions.
    supported_extensions = [
        ".tar.xz",
        ".tar.bz2",
        ".tar.gz",
        ".xz",
        ".bz2",
        ".gz",
        ".tar",
        ".zip",
    ]

    # Priority:
    #   Fn Argument
    #   Dataset metadata["download_url"]
    #   DEFAULT_DOWNLOAD_URL/Dataset_name.zip
    download_urls = []
    if download_url is not None:
        download_urls.append(download_url)

    metadata_url = self.metadata().get("download_url", None)
    if metadata_url is not None:
        download_urls.append(metadata_url)

    default_url = os.getenv("DEFAULT_DOWNLOAD_URL")
    if default_url is not None:
        if default_url.endswith("/"):
            default_url = default_url[:-1]
        default_url = f"{default_url}/{self.__class__.__name__}.zip"
        download_urls.append(default_url)

    # Try downloading from available links in order of priority
    for candidate_url in download_urls:
        extension = ".zip"
        for ext in supported_extensions:
            if candidate_url.endswith(ext):
                extension = ext
                break

        try:
            logging.info(f"Trying {candidate_url}")
            retrieve(
                candidate_url,
                known_hash=None,
                fname=f"{self.__class__.__name__}{extension}",
                path=self.data_dir,
                progressbar=True,
                processor=decompress,
            )
        except Exception as e:
            logging.warning(f"Failed to download: {e}")
            continue

        # If it was a *.tar.* file, we can safely delete the intermediate
        # *.tar file. This is best-effort: the dataset is already downloaded
        # and extracted, so a cleanup problem must not be reported as a
        # failed download nor trigger a fallback to the next URL.
        if extension in supported_extensions[:3]:
            try:
                tar_file_path = (
                    Path(self.data_dir) / f"{self.__class__.__name__}.tar"
                )
                tar_file_path.unlink()
            except Exception as e:
                logging.warning(f"Failed to remove intermediate tar file: {e}")
        return True

    logging.warning("Failed to download dataset")

    return False

def _deduplicate_train_test(self, train_data, test_data):
"""
Filter train data to avoid overlap with test data
Expand Down Expand Up @@ -417,3 +287,151 @@ def prepare_fewshots(self, target_data, train_data, n_shots, deduplicate=True):
examples = [self._destringify_sample(sample) for sample in examples]

yield examples

def download_dataset(self, download_url=None):
    """
    Utility method to download a dataset if not present locally on disk.
    Can handle datasets of types *.zip, *.tar, *.tar.gz, *.tar.bz2, *.tar.xz,
    as well as single compressed files of types *.gz, *.bz2, *.xz.
    Arguments
    ---------
    download_url : str
        The url to the dataset. If not provided, falls back to the `download_url`
        provided by the Dataset's metadata. If missing, falls back to a default
        server specified by the environment variable `DEFAULT_DOWNLOAD_URL`.
        A url that does not end in a supported extension is treated as *.zip.
    Returns
    -------
    download_succeeded : bool
        Returns True if retrieval + extraction succeeded for one of the
        candidate urls (a file already on disk is presumably reused by
        `retrieve` without a new download -- confirm against pooch).
        Returns False if no url was available or all of them failed.
    """
    dataset_name = self.__class__.__name__

    # Most-specific first, so "*.tar.gz" is matched before "*.gz"
    tarball_extensions = (".tar.xz", ".tar.bz2", ".tar.gz")
    supported_extensions = tarball_extensions + (
        ".xz",
        ".bz2",
        ".gz",
        ".tar",
        ".zip",
    )

    def decompress(fname, action, pup):
        """
        Post-processing hook to automatically detect the type of archive and
        call the correct processor (UnZip, Untar, Decompress)
        Arguments
        ---------
        fname : str
            Full path of the zipped file in local storage
        action : str
            One of "download" (file doesn't exist and will download),
            "update" (file is outdated and will download), and
            "fetch" (file exists and is updated so no download).
        pup : Pooch
            The instance of Pooch that called the processor function.
        Returns
        -------
        fnames : list
            List of all extracted files
        """
        for suffix in (".xz", ".bz2", ".gz"):
            if not fname.endswith(suffix):
                continue
            # Undo the compression; the output keeps the original name minus
            # the compression suffix
            inner = Decompress(name=fname[: -len(suffix)])(fname, action, pup)
            if fname.endswith(".tar" + suffix):
                # Compressed tarball: unpack the intermediate *.tar as well
                return Untar(extract_dir=dataset_name)(inner, action, pup)
            return [inner]

        if fname.endswith(".tar"):
            return Untar(extract_dir=dataset_name)(fname, action, pup)
        if fname.endswith(".zip"):
            return Unzip(extract_dir=dataset_name)(fname, action, pup)

        # Default where the downloaded file is not a container/archive
        return [fname]

    # Priority:
    #   Fn Argument
    #   Dataset metadata["download_url"]
    #   DEFAULT_DOWNLOAD_URL/Dataset_name.zip
    candidate_urls = []
    if download_url is not None:
        candidate_urls.append(download_url)

    metadata_url = self.metadata().get("download_url", None)
    if metadata_url is not None:
        candidate_urls.append(metadata_url)

    default_url = os.getenv("DEFAULT_DOWNLOAD_URL")
    if default_url is not None:
        if default_url.endswith("/"):
            default_url = default_url[:-1]
        candidate_urls.append(f"{default_url}/{dataset_name}.zip")

    # Try downloading from available links in order of priority
    for url in candidate_urls:
        extension = next(
            (ext for ext in supported_extensions if url.endswith(ext)), ".zip"
        )

        try:
            logging.info(f"Trying {url}")
            retrieve(
                url,
                known_hash=None,
                fname=f"{dataset_name}{extension}",
                path=self.data_dir,
                progressbar=True,
                processor=decompress,
            )
        except Exception as e:
            logging.warning(f"Failed to download: {e}")
            continue

        if extension in tarball_extensions:
            # If it was a *.tar.* file, we can safely delete the intermediate
            # *.tar file. Kept outside the download `try` on purpose: the data
            # is already in place, so a cleanup error must neither be logged
            # as a download failure nor cause the next url to be tried.
            try:
                (Path(self.data_dir) / f"{dataset_name}.tar").unlink()
            except Exception as e:
                logging.warning(f"Failed to remove intermediate tar file: {e}")

        return True

    logging.warning("Failed to download dataset")

    return False

0 comments on commit f46908f

Please sign in to comment.