Skip to content

Commit

Permalink
Do now follow symlinks in download_url()
Browse files Browse the repository at this point in the history
  • Loading branch information
hagenw committed Jan 18, 2024
1 parent dc375de commit fd52eee
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 6 deletions.
2 changes: 1 addition & 1 deletion audeer/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def download_url(
'favicon.png'
""" # noqa: E501
destination = safe_path(destination)
destination = safe_path(destination, follow_symlink=False)
if os.path.isdir(destination):
destination = os.path.join(destination, os.path.basename(url))
if os.path.exists(destination) and not force_download:
Expand Down
49 changes: 44 additions & 5 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,12 +424,51 @@ def test_common_directory(dirs, expected):
assert common == expected


def test_download_url(tmpdir):
@pytest.mark.parametrize(
"destination, symlink",
[
("folder/", False),
("folder/", True),
("favicon.png", False),
("favicon.png", True),
],
)
def test_download_url(tmpdir, destination, symlink):
url = "https://audeering.github.io/audeer/_static/favicon.png"
audeer.download_url(url, tmpdir)
audeer.download_url(url, tmpdir)
dst = audeer.download_url(url, tmpdir, force_download=True)
assert dst == os.path.join(tmpdir, os.path.basename(url))
file = os.path.basename(url)

if destination.endswith("/"): # folder
path = audeer.mkdir(tmpdir, destination)
link = os.path.join(tmpdir, "link")
else:
path = os.path.join(tmpdir, file)
link = os.path.join(tmpdir, "link.png")

if symlink:
os.symlink(path, link)
dst = link
else:
dst = path

# 1. Download
# 2. Download (just return dst link)
# 3. Download with overwrite
#
# If destination is a symlink,
# a symlink should be returned
for n in range(3):
force_download = False
if n == 2:
force_download = True
downloaded_file = audeer.download_url(url, dst, force_download=force_download)
if destination.endswith("/"):
assert downloaded_file == os.path.join(dst, file)
if symlink:
assert audeer.path(downloaded_file) == os.path.join(path, file)
else:
assert downloaded_file == dst
if symlink:
assert audeer.path(downloaded_file) == path


@pytest.mark.parametrize(
Expand Down

0 comments on commit fd52eee

Please sign in to comment.