diff --git a/audeer/core/io.py b/audeer/core/io.py index b4a7efa..abcc37b 100644 --- a/audeer/core/io.py +++ b/audeer/core/io.py @@ -256,7 +256,7 @@ def download_url( 'favicon.png' """ # noqa: E501 - destination = safe_path(destination) + destination = safe_path(destination, follow_symlink=False) if os.path.isdir(destination): destination = os.path.join(destination, os.path.basename(url)) if os.path.exists(destination) and not force_download: diff --git a/tests/test_io.py b/tests/test_io.py index 072eb28..761baab 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -424,12 +424,51 @@ def test_common_directory(dirs, expected): assert common == expected -def test_download_url(tmpdir): +@pytest.mark.parametrize( + "destination, symlink", + [ + ("folder/", False), + ("folder/", True), + ("favicon.png", False), + ("favicon.png", True), + ], +) +def test_download_url(tmpdir, destination, symlink): url = "https://audeering.github.io/audeer/_static/favicon.png" - audeer.download_url(url, tmpdir) - audeer.download_url(url, tmpdir) - dst = audeer.download_url(url, tmpdir, force_download=True) - assert dst == os.path.join(tmpdir, os.path.basename(url)) + file = os.path.basename(url) + + if destination.endswith("/"): # folder + path = audeer.mkdir(tmpdir, destination) + link = os.path.join(tmpdir, "link") + else: + path = os.path.join(tmpdir, file) + link = os.path.join(tmpdir, "link.png") + + if symlink: + os.symlink(path, link) + dst = link + else: + dst = path + + # 1. Download + # 2. Download (just return dst link) + # 3. Download with overwrite + # + # If destination is a symlink, + # a symlink should be returned + for n in range(3): + force_download = False + if n == 2: + force_download = True + downloaded_file = audeer.download_url(url, dst, force_download=force_download) + if destination.endswith("/"): + assert downloaded_file == os.path.join(dst, file) + if symlink: + assert audeer.path(downloaded_file) == os.path.join(path, file) + else: + assert downloaded_file == dst + if symlink: + assert audeer.path(downloaded_file) == path @pytest.mark.parametrize(