Skip to content

Commit

Permalink
Change default server url behavior - function argument instead of env…
Browse files Browse the repository at this point in the history
… var
  • Loading branch information
fdalvi committed Oct 25, 2023
1 parent bc44725 commit 3879d52
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
16 changes: 13 additions & 3 deletions llmebench/datasets/dataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,9 @@ def download_dataset(cls, data_dir, download_url=None, default_url=None):
download_url : str
The url to the dataset. If not provided, falls back to the `download_url`
provided by the Dataset's metadata. If missing, falls back to a default
server specified by the environment variable `DEFAULT_DOWNLOAD_URL`
server specified by the `default_url` argument
default_url : str
Default server url to fall back to in case of missing download_urls
Returns
-------
Expand Down Expand Up @@ -411,14 +413,18 @@ def decompress(fname, action, pup):
# Priority:
# Fn Argument
# Dataset metadata["download_url"]
# DEFAULT_DOWNLOAD_URL/Dataset_name.zip
# default_url/Dataset_name.zip
download_urls = []
if download_url is not None:
download_urls.append(download_url)

metadata_url = cls.metadata().get("download_url", None)
if metadata_url is not None:
download_urls.append(metadata_url)
else:
logging.warning(
f"No default download url specified for {dataset_name}, will try to download from LLMeBench servers."
)

if default_url is not None:
if default_url.endswith("/"):
Expand All @@ -445,7 +451,10 @@ def decompress(fname, action, pup):
extension = ext
break
try:
logging.info(f"Trying {download_url}")
logging.info(f"Trying to fetch from {download_url}")
if (Path(data_dir) / f"{dataset_name}{extension}").exists():
logging.info(f"Cached dataset found")
return True
retrieve(
download_url,
known_hash=None,
Expand All @@ -462,6 +471,7 @@ def decompress(fname, action, pup):
if extension in supported_extensions[:3]:
tar_file_path = Path(data_dir) / f"{dataset_name}.tar"
tar_file_path.unlink()
logging.info(f"Fetch successful")
return True
except Exception as e:
logging.warning(f"Failed to download: {e}")
Expand Down
12 changes: 5 additions & 7 deletions tests/datasets/test_download_and_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,11 @@ def test_auto_download_default_url(self):
data_dir_path = Path(data_dir.name)

dataset = MockDataset(data_dir=data_dir_path)
with patch.dict(
"os.environ",
{
"DEFAULT_DOWNLOAD_URL": f"http://localhost:{self.port}/",
},
):
self.assertTrue(dataset.download_dataset(data_dir=data_dir.name))
self.assertTrue(
dataset.download_dataset(
data_dir=data_dir.name, default_url=f"http://localhost:{self.port}/"
)
)

self.check_downloaded(data_dir_path, "Mock", "zip")

Expand Down

0 comments on commit 3879d52

Please sign in to comment.