diff --git a/conda_lock/lookup.py b/conda_lock/lookup.py
index 1f0e86787..81746f162 100644
--- a/conda_lock/lookup.py
+++ b/conda_lock/lookup.py
@@ -1,4 +1,3 @@
-import hashlib
 import json
 import logging
 import time
@@ -7,12 +6,10 @@
 from pathlib import Path
 from typing import Dict, TypedDict
 
-import requests
-
-from filelock import FileLock, Timeout
-from packaging.utils import NormalizedName, canonicalize_name
+from packaging.utils import NormalizedName
 from packaging.utils import canonicalize_name as canonicalize_pypi_name
-from platformdirs import user_cache_path
+
+from conda_lock.lookup_cache import cached_download_file
 
 
 logger = logging.getLogger(__name__)
@@ -31,7 +28,7 @@ class MappingEntry(TypedDict):
 def _get_pypi_lookup(mapping_url: str) -> Dict[NormalizedName, MappingEntry]:
     url = mapping_url
     if url.startswith("http://") or url.startswith("https://"):
-        content = cached_download_file(url)
+        content = cached_download_file(url, cache_subdir_name="pypi-mapping")
     else:
         if url.startswith("file://"):
             path = url[len("file://") :]
@@ -51,9 +48,9 @@ def _get_pypi_lookup(mapping_url: str) -> Dict[NormalizedName, MappingEntry]:
     logger.debug(f"Loaded {len(lookup)} entries in {load_duration:.2f}s")
     # lowercase and kebabcase the pypi names
     assert lookup is not None
-    lookup = {canonicalize_name(k): v for k, v in lookup.items()}
+    lookup = {canonicalize_pypi_name(k): v for k, v in lookup.items()}
     for v in lookup.values():
-        v["pypi_name"] = canonicalize_name(v["pypi_name"])
+        v["pypi_name"] = canonicalize_pypi_name(v["pypi_name"])
     return lookup
 
 
@@ -68,18 +65,15 @@ def pypi_name_to_conda_name(name: str, mapping_url: str) -> str:
     'zpfqzvrj'
     """
     cname = canonicalize_pypi_name(name)
-    if cname in _get_pypi_lookup(mapping_url):
-        lookup = _get_pypi_lookup(mapping_url)[cname]
-        res = lookup.get("conda_name") or lookup.get("conda_forge")
+    lookup = _get_pypi_lookup(mapping_url)
+    if cname in lookup:
+        entry = lookup[cname]
+        res = entry.get("conda_name") or entry.get("conda_forge")
         if res is not None:
             return res
-        else:
-            logging.warning(
-                f"Could not find conda name for {cname}. Assuming identity."
-            )
-            return cname
-    else:
-        return cname
+
+    logger.warning(f"Could not find conda name for {cname}. Assuming identity.")
+    return cname
 
 
 @lru_cache(maxsize=None)
@@ -96,90 +90,5 @@ def _get_conda_lookup(mapping_url: str) -> Dict[str, MappingEntry]:
 def conda_name_to_pypi_name(name: str, mapping_url: str) -> NormalizedName:
     """return the pypi name for a conda package"""
     lookup = _get_conda_lookup(mapping_url=mapping_url)
-    cname = canonicalize_name(name)
+    cname = canonicalize_pypi_name(name)
     return lookup.get(cname, {"pypi_name": cname})["pypi_name"]
-
-
-def cached_download_file(url: str) -> bytes:
-    """Download a file and cache it in the user cache directory.
-
-    If the file is already cached, return the cached contents.
-    If the file is not cached, download it and cache the contents
-    and the ETag.
-
-    Protect against multiple processes downloading the same file.
-    """
-    CLEAR_CACHE_AFTER_SECONDS = 60 * 60 * 24 * 2  # 2 days
-    DONT_CHECK_IF_NEWER_THAN_SECONDS = 60 * 5  # 5 minutes
-    current_time = time.time()
-    cache = user_cache_path("conda-lock", appauthor=False)
-    cache.mkdir(parents=True, exist_ok=True)
-
-    # clear out old cache files
-    for file in cache.iterdir():
-        if file.name.startswith("pypi-mapping-"):
-            mtime = file.stat().st_mtime
-            age = current_time - mtime
-            if age < 0 or age > CLEAR_CACHE_AFTER_SECONDS:
-                logger.debug("Removing old cache file %s", file)
-                file.unlink()
-
-    url_hash = hashlib.sha256(url.encode()).hexdigest()[:4]
-    destination_mapping = cache / f"pypi-mapping-{url_hash}.yaml"
-    destination_etag = destination_mapping.with_suffix(".etag")
-    destination_lock = destination_mapping.with_suffix(".lock")
-
-    # Wait for any other process to finish downloading the file.
-    # Use the ETag to avoid downloading the file if it hasn't changed.
-    # Otherwise, download the file and cache the contents and ETag.
-    while True:
-        try:
-            with FileLock(destination_lock, timeout=5):
-                # Return the contents immediately if the file is fresh
-                try:
-                    mtime = destination_mapping.stat().st_mtime
-                    age = current_time - mtime
-                    if age < DONT_CHECK_IF_NEWER_THAN_SECONDS:
-                        contents = destination_mapping.read_bytes()
-                        logger.debug(
-                            f"Using cached mapping {destination_mapping} without "
-                            f"checking for updates"
-                        )
-                        return contents
-                except FileNotFoundError:
-                    pass
-                # Get the ETag from the last download, if it exists
-                if destination_mapping.exists() and destination_etag.exists():
-                    logger.debug(f"Old ETag found at {destination_etag}")
-                    try:
-                        old_etag = destination_etag.read_text().strip()
-                        headers = {"If-None-Match": old_etag}
-                    except FileNotFoundError:
-                        logger.warning("Failed to read ETag")
-                        headers = {}
-                else:
-                    headers = {}
-                # Download the file and cache the result.
-                logger.debug(f"Requesting {url}")
-                res = requests.get(url, headers=headers)
-                if res.status_code == 304:
-                    logger.debug(
-                        f"{url} has not changed since last download, "
-                        f"using {destination_mapping}"
-                    )
-                else:
-                    res.raise_for_status()
-                    time.sleep(10)
-                    destination_mapping.write_bytes(res.content)
-                    if "ETag" in res.headers:
-                        destination_etag.write_text(res.headers["ETag"])
-                    else:
-                        logger.warning("No ETag in response headers")
-                logger.debug(f"Downloaded {url} to {destination_mapping}")
-                return destination_mapping.read_bytes()
-
-        except Timeout:
-            logger.warning(
-                f"Failed to acquire lock on {destination_lock}, it is likely "
-                f"being downloaded by another process. Retrying..."
-            )
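With the download and caching machinery moved out to conda_lock.lookup_cache, the public lookup API is unchanged. A minimal usage sketch (the mapping URL below is hypothetical; in practice conda-lock's DEFAULT_MAPPING_URL is used):

    from conda_lock.lookup import pypi_name_to_conda_name

    # Hypothetical mapping URL, for illustration only.
    name = pypi_name_to_conda_name(
        "build", mapping_url="https://example.com/grayskull_pypi_mapping.json"
    )
    # Returns the mapped conda name, or "build" unchanged (plus a warning)
    # if the mapping has no entry for it.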
- """ - CLEAR_CACHE_AFTER_SECONDS = 60 * 60 * 24 * 2 # 2 days - DONT_CHECK_IF_NEWER_THAN_SECONDS = 60 * 5 # 5 minutes - current_time = time.time() - cache = user_cache_path("conda-lock", appauthor=False) - cache.mkdir(parents=True, exist_ok=True) - - # clear out old cache files - for file in cache.iterdir(): - if file.name.startswith("pypi-mapping-"): - mtime = file.stat().st_mtime - age = current_time - mtime - if age < 0 or age > CLEAR_CACHE_AFTER_SECONDS: - logger.debug("Removing old cache file %s", file) - file.unlink() - - url_hash = hashlib.sha256(url.encode()).hexdigest()[:4] - destination_mapping = cache / f"pypi-mapping-{url_hash}.yaml" - destination_etag = destination_mapping.with_suffix(".etag") - destination_lock = destination_mapping.with_suffix(".lock") - - # Wait for any other process to finish downloading the file. - # Use the ETag to avoid downloading the file if it hasn't changed. - # Otherwise, download the file and cache the contents and ETag. - while True: - try: - with FileLock(destination_lock, timeout=5): - # Return the contents immediately if the file is fresh - try: - mtime = destination_mapping.stat().st_mtime - age = current_time - mtime - if age < DONT_CHECK_IF_NEWER_THAN_SECONDS: - contents = destination_mapping.read_bytes() - logger.debug( - f"Using cached mapping {destination_mapping} without " - f"checking for updates" - ) - return contents - except FileNotFoundError: - pass - # Get the ETag from the last download, if it exists - if destination_mapping.exists() and destination_etag.exists(): - logger.debug(f"Old ETag found at {destination_etag}") - try: - old_etag = destination_etag.read_text().strip() - headers = {"If-None-Match": old_etag} - except FileNotFoundError: - logger.warning("Failed to read ETag") - headers = {} - else: - headers = {} - # Download the file and cache the result. - logger.debug(f"Requesting {url}") - res = requests.get(url, headers=headers) - if res.status_code == 304: - logger.debug( - f"{url} has not changed since last download, " - f"using {destination_mapping}" - ) - else: - res.raise_for_status() - time.sleep(10) - destination_mapping.write_bytes(res.content) - if "ETag" in res.headers: - destination_etag.write_text(res.headers["ETag"]) - else: - logger.warning("No ETag in response headers") - logger.debug(f"Downloaded {url} to {destination_mapping}") - return destination_mapping.read_bytes() - - except Timeout: - logger.warning( - f"Failed to acquire lock on {destination_lock}, it is likely " - f"being downloaded by another process. Retrying..." - ) diff --git a/conda_lock/lookup_cache.py b/conda_lock/lookup_cache.py new file mode 100644 index 000000000..a5ee6fda4 --- /dev/null +++ b/conda_lock/lookup_cache.py @@ -0,0 +1,179 @@ +import hashlib +import logging +import platform +import re + +from datetime import datetime +from pathlib import Path +from typing import Optional + +import requests + +from filelock import FileLock, Timeout +from platformdirs import user_cache_path + + +logger = logging.getLogger(__name__) + + +CLEAR_CACHE_AFTER_SECONDS = 60 * 60 * 24 * 2 # 2 days +"""Cached files older than this will be deleted.""" + +DONT_CHECK_IF_NEWER_THAN_SECONDS = 60 * 5 # 5 minutes +"""If the cached file is newer than this, just use it without checking for updates.""" + +WINDOWS_TIME_EPSILON = 0.005 +"""Windows has issues with file timestamps, so we add this small offset +to ensure that newly created files have a positive age. 
+""" + + +def uncached_download_file(url: str) -> bytes: + """The simple equivalent to cached_download_file.""" + res = requests.get(url, headers={"User-Agent": "conda-lock"}) + res.raise_for_status() + return res.content + + +def cached_download_file( + url: str, + *, + cache_subdir_name: str, + cache_root: Optional[Path] = None, + max_age_seconds: float = CLEAR_CACHE_AFTER_SECONDS, + dont_check_if_newer_than_seconds: float = DONT_CHECK_IF_NEWER_THAN_SECONDS, +) -> bytes: + """Download a file and cache it in the user cache directory. + + If the file is already cached, return the cached contents. + If the file is not cached, download it and cache the contents + and the ETag. + + Protect against multiple processes downloading the same file. + """ + if cache_root is None: + cache_root = user_cache_path("conda-lock", appauthor=False) + cache = cache_root / "cache" / cache_subdir_name + cache.mkdir(parents=True, exist_ok=True) + clear_old_files_from_cache(cache, max_age_seconds=max_age_seconds) + + destination_lock = (cache / cached_filename_for_url(url)).with_suffix(".lock") + + # Wait for any other process to finish downloading the file. + # This way we can use the result from the current download without + # spawning multiple concurrent downloads. + while True: + try: + with FileLock(str(destination_lock), timeout=5): + return _download_to_or_read_from_cache( + url, + cache=cache, + dont_check_if_newer_than_seconds=dont_check_if_newer_than_seconds, + ) + except Timeout: + logger.warning( + f"Failed to acquire lock on {destination_lock}, it is likely " + f"being downloaded by another process. Retrying..." + ) + + +def _download_to_or_read_from_cache( + url: str, *, cache: Path, dont_check_if_newer_than_seconds: float +) -> bytes: + """Download a file to the cache directory and return the contents. + + This function is designed to be called from within a FileLock context to avoid + multiple processes downloading the same file. + + If the file is newer than `dont_check_if_newer_than_seconds`, then immediately + return the cached contents. Otherwise we pass the ETag from the last download + in the headers to avoid downloading the file if it hasn't changed remotely. + """ + destination = cache / cached_filename_for_url(url) + destination_etag = destination.with_suffix(".etag") + request_headers = {"User-Agent": "conda-lock"} + # Return the contents immediately if the file is fresh + if destination.is_file(): + age_seconds = get_age_seconds(destination) + if 0 <= age_seconds < dont_check_if_newer_than_seconds: + logger.debug( + f"Using cached mapping {destination} of age {age_seconds}s " + f"without checking for updates" + ) + return destination.read_bytes() + # Add the ETag from the last download, if it exists, to the headers. + # The ETag is used to avoid downloading the file if it hasn't changed remotely. + # Otherwise, download the file and cache the contents and ETag. + if destination_etag.is_file(): + old_etag = destination_etag.read_text().strip() + request_headers["If-None-Match"] = old_etag + # Download the file and cache the result. 
+ logger.debug(f"Requesting {url}") + res = requests.get(url, headers=request_headers) + if res.status_code == 304: + logger.debug(f"{url} has not changed since last download, using {destination}") + else: + res.raise_for_status() + destination.write_bytes(res.content) + if "ETag" in res.headers: + destination_etag.write_text(res.headers["ETag"]) + else: + logger.warning("No ETag in response headers") + logger.debug(f"Downloaded {url} to {destination}") + return destination.read_bytes() + + +def cached_filename_for_url(url: str) -> str: + """Return a filename for a URL that is probably unique to the URL. + + The filename is a 4-character hash of the URL, followed by the extension. + If the extension is not alphanumeric or too long, it is omitted. + + >>> cached_filename_for_url("https://example.com/foo.json") + 'a5d7.json' + >>> cached_filename_for_url("https://example.com/foo") + '5ea6' + >>> cached_filename_for_url("https://example.com/foo.bär") + '2191' + >>> cached_filename_for_url("https://example.com/foo.baaaaaar") + '1861' + """ + url_hash = hashlib.sha256(url.encode()).hexdigest()[:4] + extension = url.split(".")[-1] + if len(extension) <= 6 and re.match("^[a-zA-Z0-9]+$", extension): + return f"{url_hash}.{extension}" + else: + return f"{url_hash}" + + +def clear_old_files_from_cache(cache: Path, *, max_age_seconds: float) -> None: + """Remove files in the cache directory older than `max_age_seconds`. + + Also removes any files that somehow have a modification time in the future. + + For safety, this raises an error if `cache` is not a subdirectory of + a directory named `"cache"`. + """ + if not cache.parent.name == "cache": + raise ValueError( + f"Expected cache directory, got {cache}. Parent should be 'cache' ", + f"not '{cache.parent.name}'", + ) + for file in cache.iterdir(): + age_seconds = get_age_seconds(file) + if age_seconds < 0 or age_seconds >= max_age_seconds: + logger.debug(f"Removing old cache file {file} of age {age_seconds}s") + file.unlink() + + +def get_age_seconds(path: Path) -> float: + """Return the age of a file in seconds. + + On Windows, the age of a new file is sometimes slightly negative, so we add a small + offset to ensure that the age is positive. 
+ """ + raw_age = datetime.now().timestamp() - path.stat().st_mtime + if platform.system() == "Windows": + return raw_age + WINDOWS_TIME_EPSILON + else: + return raw_age diff --git a/tests/test_lookup_cache.py b/tests/test_lookup_cache.py new file mode 100644 index 000000000..004aaf801 --- /dev/null +++ b/tests/test_lookup_cache.py @@ -0,0 +1,322 @@ +import os +import queue +import threading +import time + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from conda_lock.lookup import DEFAULT_MAPPING_URL +from conda_lock.lookup_cache import ( + cached_download_file, + clear_old_files_from_cache, + uncached_download_file, +) + + +@pytest.fixture +def mock_cache_dir(tmp_path): + cache_dir = tmp_path / "cache" / "test_cache" + cache_dir.mkdir(parents=True) + return cache_dir + + +@pytest.mark.parametrize("use_caching_function", [True, False]) +def test_download_file_uncached(tmp_path, use_caching_function): + with patch("requests.get") as mock_get: + mock_response = MagicMock() + mock_response.content = b"test content" + mock_get.return_value = mock_response + + if use_caching_function: + result = cached_download_file( + "https://example.com/test", + cache_subdir_name="test_cache", + cache_root=tmp_path, + ) + else: + result = uncached_download_file("https://example.com/test") + + assert result == b"test content" + mock_get.assert_called_once_with( + "https://example.com/test", headers={"User-Agent": "conda-lock"} + ) + mock_response.raise_for_status.assert_called_once() + + +def test_clear_old_files_from_cache(mock_cache_dir): + """Verify that files older than the max age are removed.""" + old_file = mock_cache_dir / "old_file.txt" + recent_file = mock_cache_dir / "recent_file.txt" + future_file = mock_cache_dir / "future_file.txt" + + old_file.touch() + recent_file.touch() + future_file.touch() + + # Set the modification and access times of each file + t = time.time() + os.utime(old_file, (t - 100, t - 100)) + os.utime(recent_file, (t - 20, t - 20)) + os.utime(future_file, (t + 100, t + 100)) + + clear_old_files_from_cache(mock_cache_dir, max_age_seconds=22) + + # Only the recent file is in the correct time range + assert not old_file.exists() + assert recent_file.exists() + assert not future_file.exists() + + # Immediately rerunning it again should not change anything + clear_old_files_from_cache(mock_cache_dir, max_age_seconds=22) + assert recent_file.exists() + + # Lowering the max age should remove the file + clear_old_files_from_cache(mock_cache_dir, max_age_seconds=20) + assert not recent_file.exists() + + +def test_clear_old_files_from_cache_invalid_directory(tmp_path): + """Verify that only paths within a 'cache' directory are accepted. + + This is a safety measure to prevent accidental deletion of files + outside of a cache directory. + """ + valid_cache_dir = tmp_path / "cache" / "valid" + invalid_cache_dir = tmp_path / "not-cache" / "invalid" + + valid_cache_dir.mkdir(parents=True) + invalid_cache_dir.mkdir(parents=True) + clear_old_files_from_cache(valid_cache_dir, max_age_seconds=10) + with pytest.raises(ValueError): + clear_old_files_from_cache(Path(invalid_cache_dir), max_age_seconds=10) + + +def test_cached_download_file(tmp_path): + """Simulate an interaction with a remote server to test the cache. 
+
+    * Download the file for the first time
+    * Retrieve the file again immediately (should be cached without sending a request)
+    * Retrieve the file again twice more, forcing a revalidation request each time
+      (the remote file is unchanged, so we should get 304 Not Modified and return
+      the cached version)
+    * Retrieve the file again after the remote file has been updated
+      (should get 200 OK and return the updated version)
+    * Retrieve the file again immediately (should be cached without sending a request)
+    """
+    url = "https://example.com/test.json"
+    with patch("requests.get") as mock_get:
+        mock_response = MagicMock()
+        mock_response.content = b"previous content"
+        mock_response.status_code = 200
+        mock_response.headers = {"ETag": "previous-etag"}
+        mock_get.return_value = mock_response
+
+        # Warm the cache
+        result = cached_download_file(
+            url, cache_subdir_name="test_cache", cache_root=tmp_path
+        )
+        assert result == b"previous content"
+        assert mock_get.call_count == 1
+        # No ETag should have been sent because we downloaded for the first time
+        assert mock_get.call_args[1]["headers"].get("If-None-Match") is None
+
+        # Calling again immediately should directly return the cached result
+        # without sending a new request
+        result = cached_download_file(
+            url, cache_subdir_name="test_cache", cache_root=tmp_path
+        )
+        assert result == b"previous content"
+        assert mock_get.call_count == 1
+
+    # Now we test HTTP 304 Not Modified
+    # We trigger a request by setting dont_check_if_newer_than_seconds to 0
+    with patch("requests.get") as mock_get:
+        mock_response = MagicMock()
+        mock_response.content = b"Should be ignored"
+        mock_response.status_code = 304
+        mock_response.headers = {"ETag": "Should be ignored"}
+        mock_get.return_value = mock_response
+
+        for call_count in range(1, 2 + 1):
+            # This time we should send the ETag and get a 304
+            result = cached_download_file(
+                url,
+                cache_subdir_name="test_cache",
+                cache_root=tmp_path,
+                dont_check_if_newer_than_seconds=0,
+            )
+            assert result == b"previous content"
+            assert mock_get.call_count == call_count
+            assert (
+                mock_get.call_args[1]["headers"].get("If-None-Match") == "previous-etag"
+            )
+
+    # Now we test HTTP 200 OK with a new ETag to simulate the remote file being updated
+    with patch("requests.get") as mock_get:
+        mock_response = MagicMock()
+        mock_response.content = b"new content"
+        mock_response.status_code = 200
+        mock_response.headers = {"ETag": "new-etag"}
+        mock_get.return_value = mock_response
+
+        result = cached_download_file(
+            url,
+            cache_subdir_name="test_cache",
+            cache_root=tmp_path,
+            dont_check_if_newer_than_seconds=0,
+        )
+        assert result == b"new content"
+        assert mock_get.call_count == 1
+        assert mock_get.call_args[1]["headers"].get("If-None-Match") == "previous-etag"
+
+    # Verify that we picked up the new content and sent the new ETag
+    with patch("requests.get") as mock_get:
+        mock_response = MagicMock()
+        mock_response.content = b"Should be ignored"
+        mock_response.status_code = 304
+        mock_response.headers = {"ETag": "Should be ignored"}
+        mock_get.return_value = mock_response
+
+        result = cached_download_file(
+            url,
+            cache_subdir_name="test_cache",
+            cache_root=tmp_path,
+            dont_check_if_newer_than_seconds=0,
+        )
+        assert result == b"new content"
+        assert mock_get.call_count == 1
+        assert mock_get.call_args[1]["headers"].get("If-None-Match") == "new-etag"
+
+        # Verify that we return the updated content without sending a new request
+        result = cached_download_file(
+            url,
+            cache_subdir_name="test_cache",
+            cache_root=tmp_path,
+        )
+ assert result == b"new content" + assert mock_get.call_count == 1 + + +def test_download_mapping_file(tmp_path): + """Verify that we can download the actual mapping file and that it is cached.""" + url = DEFAULT_MAPPING_URL + from requests import get as requests_get + + responses: list[requests.Response] = [] + + def wrapped_get(*args, **kwargs): + """Wrap requests.get to capture the response.""" + response = requests_get(*args, **kwargs) + responses.append(response) + return response + + # Initial download and cache + with patch("requests.get", wraps=wrapped_get) as mock_get: + result = cached_download_file( + url, cache_subdir_name="test_cache", cache_root=tmp_path + ) + # Ensure the response is valid and content is as expected + assert len(responses) == 1 + response = responses[0] + assert response.status_code == 200 + assert len(response.content) > 10000 + assert response.content == result + + # Verify that the file is retrieved from cache + with patch("requests.get", wraps=wrapped_get) as mock_get: + result2 = cached_download_file( + url, cache_subdir_name="test_cache", cache_root=tmp_path + ) + mock_get.assert_not_called() + assert result == result2 + + # Force cache refresh and verify ETag handling + with patch("requests.get", wraps=wrapped_get) as mock_get: + result3 = cached_download_file( + url, + cache_subdir_name="test_cache", + cache_root=tmp_path, + dont_check_if_newer_than_seconds=0, + ) + # Ensure the request is made and the response is 304 Not Modified + assert len(responses) == 2 + response = responses[1] + assert response is not None + mock_get.assert_called_once() + assert response.status_code == 304 + assert len(response.content) == 0 + assert result == result2 == result3 + + +def test_concurrent_cached_download_file(tmp_path): + """Test concurrent access to cached_download_file with 5 threads.""" + url = "https://example.com/test.json" + results: queue.Queue[bytes] = queue.Queue() + thread_names_emitting_lock_warnings: queue.Queue[str] = queue.Queue() + thread_names_calling_requests_get: queue.Queue[str] = queue.Queue() + + def mock_get(*args, **kwargs): + time.sleep(5.2) + response = MagicMock() + response.content = b"content" + response.status_code = 200 + thread_name = threading.current_thread().name + thread_names_calling_requests_get.put(thread_name) + return response + + def download_file(result_queue): + """Download the file in a thread and store the result in a queue.""" + import random + + # Randomize which thread calls cached_download_file first + time.sleep(random.uniform(0, 0.1)) + result = cached_download_file( + url, cache_subdir_name="test_cache", cache_root=tmp_path + ) + result_queue.put(result) + + with patch("requests.get", side_effect=mock_get) as mock_get, patch( + "conda_lock.lookup_cache.logger" + ) as mock_logger: + # Set up the logger to record which threads emit warnings + def mock_warning(msg, *args, **kwargs): + if "Failed to acquire lock" in msg: + thread_names_emitting_lock_warnings.put(threading.current_thread().name) + + mock_logger.warning.side_effect = mock_warning + + # Create and start 5 threads + thread_names = [f"CachedDownloadFileThread-{i}" for i in range(5)] + threads = [ + threading.Thread(target=download_file, args=(results,), name=thread_name) + for thread_name in thread_names + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # Collect results from the queue + assert results.qsize() == len(threads) + assert all(result == b"content" for result in results.queue) + + # We expect one 
+    # to have emitted warnings.
+    assert (
+        thread_names_calling_requests_get.qsize()
+        == 1
+        == len(set(thread_names_calling_requests_get.queue))
+        == mock_requests_get.call_count
+    ), f"{thread_names_calling_requests_get.queue=}"
+    assert (
+        thread_names_emitting_lock_warnings.qsize()
+        == 4
+        == len(set(thread_names_emitting_lock_warnings.queue))
+    ), f"{thread_names_emitting_lock_warnings.queue=}"
+    assert set(thread_names) == set(
+        thread_names_calling_requests_get.queue
+        + thread_names_emitting_lock_warnings.queue
+    )
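The concurrency test pins down the intended locking behavior: exactly one thread downloads while the others block on the lock, time out, warn, and retry. Reduced to its core, the pattern looks like this (a sketch against the real filelock API; fetch_or_read_cache is a hypothetical stand-in for _download_to_or_read_from_cache):

    from filelock import FileLock, Timeout

    def fetch_once(lock_path: str) -> bytes:
        while True:
            try:
                # Only one process or thread at a time gets past this point.
                with FileLock(lock_path, timeout=5):
                    return fetch_or_read_cache()  # hypothetical stand-in
            except Timeout:
                # Someone else holds the lock and is probably downloading;
                # loop around and try again instead of failing outright.
                pass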