diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f7a50881..8737b2af 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,3 +53,11 @@ repos: - flake8-comprehensions - flake8-tidy-imports - flake8-typing-imports +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.0.0 + hooks: + - id: mypy + additional_dependencies: + - django-stubs==1.14.0 + - requests + - types-requests diff --git a/pyproject.toml b/pyproject.toml index 99efa39d..16d0e470 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,21 @@ build-backend = "setuptools.build_meta" [tool.black] target-version = ['py37'] +[tool.django-stubs] +django_settings_module = "tests.django_settings" + +[tool.mypy] +mypy_path = "src/" +namespace_packages = false +plugins = ["mypy_django_plugin.main"] +show_error_codes = true +strict = true +warn_unreachable = true + +[[tool.mypy.overrides]] +module = "tests.*" +allow_untyped_defs = true + [tool.pytest.ini_options] addopts = """\ --strict-config diff --git a/requirements/py37-django32.txt b/requirements/py37-django32.txt index 220d7e1e..1459bdab 100644 --- a/requirements/py37-django32.txt +++ b/requirements/py37-django32.txt @@ -47,8 +47,9 @@ sqlparse==0.4.3 # via django tomli==2.0.1 # via pytest -typing-extensions==4.4.0 +typing-extensions==4.4.0 ; python_version < "3.10" # via + # -r requirements.in # asgiref # importlib-metadata urllib3==1.26.14 diff --git a/requirements/py38-django32.txt b/requirements/py38-django32.txt index d093b450..990eca12 100644 --- a/requirements/py38-django32.txt +++ b/requirements/py38-django32.txt @@ -44,6 +44,8 @@ sqlparse==0.4.3 # via django tomli==2.0.1 # via pytest +typing-extensions==4.4.0 ; python_version < "3.10" + # via -r requirements.in urllib3==1.26.14 # via requests zipp==3.13.0 diff --git a/requirements/py38-django40.txt b/requirements/py38-django40.txt index 35887a04..a0496309 100644 --- a/requirements/py38-django40.txt +++ b/requirements/py38-django40.txt @@ -44,6 +44,8 @@ sqlparse==0.4.3 # via django tomli==2.0.1 # via pytest +typing-extensions==4.4.0 ; python_version < "3.10" + # via -r requirements.in urllib3==1.26.14 # via requests zipp==3.13.0 diff --git a/requirements/py38-django41.txt b/requirements/py38-django41.txt index f2d385df..32e455d9 100644 --- a/requirements/py38-django41.txt +++ b/requirements/py38-django41.txt @@ -44,6 +44,8 @@ sqlparse==0.4.3 # via django tomli==2.0.1 # via pytest +typing-extensions==4.4.0 ; python_version < "3.10" + # via -r requirements.in urllib3==1.26.14 # via requests zipp==3.13.0 diff --git a/requirements/py38-django42.txt b/requirements/py38-django42.txt index 0788219e..6140db25 100644 --- a/requirements/py38-django42.txt +++ b/requirements/py38-django42.txt @@ -44,6 +44,8 @@ sqlparse==0.4.3 # via django tomli==2.0.1 # via pytest +typing-extensions==4.4.0 ; python_version < "3.10" + # via -r requirements.in urllib3==1.26.14 # via requests zipp==3.13.0 diff --git a/requirements/py39-django32.txt b/requirements/py39-django32.txt index 80f81c5f..c2665a8d 100644 --- a/requirements/py39-django32.txt +++ b/requirements/py39-django32.txt @@ -44,6 +44,8 @@ sqlparse==0.4.3 # via django tomli==2.0.1 # via pytest +typing-extensions==4.4.0 ; python_version < "3.10" + # via -r requirements.in urllib3==1.26.14 # via requests zipp==3.13.0 diff --git a/requirements/py39-django40.txt b/requirements/py39-django40.txt index 2b493aac..ec62cc4e 100644 --- a/requirements/py39-django40.txt +++ b/requirements/py39-django40.txt @@ -42,6 +42,8 @@ sqlparse==0.4.3 # via django tomli==2.0.1 # via pytest +typing-extensions==4.4.0 ; python_version < "3.10" + # via -r requirements.in urllib3==1.26.14 # via requests zipp==3.13.0 diff --git a/requirements/py39-django41.txt b/requirements/py39-django41.txt index 21b99099..c492d35a 100644 --- a/requirements/py39-django41.txt +++ b/requirements/py39-django41.txt @@ -42,6 +42,8 @@ sqlparse==0.4.3 # via django tomli==2.0.1 # via pytest +typing-extensions==4.4.0 ; python_version < "3.10" + # via -r requirements.in urllib3==1.26.14 # via requests zipp==3.13.0 diff --git a/requirements/py39-django42.txt b/requirements/py39-django42.txt index bb54284a..edb87fea 100644 --- a/requirements/py39-django42.txt +++ b/requirements/py39-django42.txt @@ -42,6 +42,8 @@ sqlparse==0.4.3 # via django tomli==2.0.1 # via pytest +typing-extensions==4.4.0 ; python_version < "3.10" + # via -r requirements.in urllib3==1.26.14 # via requests zipp==3.13.0 diff --git a/requirements/requirements.in b/requirements/requirements.in index 11b521d4..e97f34ed 100644 --- a/requirements/requirements.in +++ b/requirements/requirements.in @@ -4,3 +4,4 @@ django pytest pytest-randomly requests +typing-extensions ; python_version < "3.10" diff --git a/setup.cfg b/setup.cfg index c15fc174..2e821e20 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,6 +36,8 @@ project_urls = [options] packages = find: +install_requires = + typing-extensions;python_version < "3.10" python_requires = >=3.7 include_package_data = True package_dir = diff --git a/src/whitenoise/base.py b/src/whitenoise/base.py index 45c58f92..a7e1161d 100644 --- a/src/whitenoise/base.py +++ b/src/whitenoise/base.py @@ -5,9 +5,14 @@ import warnings from posixpath import normpath from typing import Callable +from typing import Generator +from typing import Iterable from wsgiref.headers import Headers from wsgiref.util import FileWrapper +from .compat import StartResponse +from .compat import WSGIApplication +from .compat import WSGIEnvironment from .media_types import MediaTypes from .responders import IsDirectoryError from .responders import MissingFileError @@ -24,9 +29,9 @@ class WhiteNoise: def __init__( self, - application, - root=None, - prefix=None, + application: WSGIApplication, + root: str | os.PathLike[str] | None = None, + prefix: str | None = None, *, # Re-check the filesystem on every request so that any changes are # automatically picked up. NOTE: For use in development only, not supported @@ -43,7 +48,7 @@ def __init__( mimetypes: dict[str, str] | None = None, add_headers_function: Callable[[Headers, str, str], None] | None = None, index_file: str | bool | None = None, - immutable_file_test: Callable | str | None = None, + immutable_file_test: Callable[[str, str], bool] | str | None = None, ): self.autorefresh = autorefresh self.max_age = max_age @@ -60,18 +65,24 @@ def __init__( if immutable_file_test is not None: if not callable(immutable_file_test): regex = re.compile(immutable_file_test) - self.immutable_file_test = lambda path, url: bool(regex.search(url)) + self.immutable_file_test: Callable[ + [str, str], bool + ] = lambda path, url: bool(regex.search(url)) else: self.immutable_file_test = immutable_file_test + else: + self.immutable_file_test = lambda path, url: False self.media_types = MediaTypes(extra_types=mimetypes) self.application = application - self.files = {} - self.directories = [] + self.files: dict[str, Redirect | StaticFile] = {} + self.directories: list[tuple[str, str]] = [] if root is not None: self.add_files(root, prefix) - def __call__(self, environ, start_response): + def __call__( + self, environ: WSGIEnvironment, start_response: StartResponse + ) -> Iterable[bytes]: path = decode_path_info(environ.get("PATH_INFO", "")) if self.autorefresh: static_file = self.find_file(path) @@ -83,17 +94,26 @@ def __call__(self, environ, start_response): return self.serve(static_file, environ, start_response) @staticmethod - def serve(static_file, environ, start_response): + def serve( + static_file: Redirect | StaticFile, + environ: WSGIEnvironment, + start_response: StartResponse, + ) -> Iterable[bytes]: response = static_file.get_response(environ["REQUEST_METHOD"], environ) status_line = f"{response.status} {response.status.phrase}" start_response(status_line, list(response.headers)) if response.file is not None: - file_wrapper = environ.get("wsgi.file_wrapper", FileWrapper) - return file_wrapper(response.file) + file_wrapper: type[FileWrapper] = environ.get( + "wsgi.file_wrapper", FileWrapper + ) + # It's fine to pass BufferedIOBase to FileWrapper + return file_wrapper(response.file) # type: ignore [arg-type] else: return [] - def add_files(self, root, prefix=None): + def add_files( + self, root: str | os.PathLike[str], prefix: str | None = None + ) -> None: root = os.path.abspath(root) root = root.rstrip(os.path.sep) + os.path.sep prefix = ensure_leading_trailing_slash(prefix) @@ -108,7 +128,7 @@ def add_files(self, root, prefix=None): else: warnings.warn(f"No directory at: {root}") - def update_files_dictionary(self, root, prefix): + def update_files_dictionary(self, root: str, prefix: str) -> None: # Build a mapping from paths to the results of `os.stat` calls # so we only have to touch the filesystem once stat_cache = dict(scantree(root)) @@ -118,7 +138,12 @@ def update_files_dictionary(self, root, prefix): url = prefix + relative_url self.add_file_to_dictionary(url, path, stat_cache=stat_cache) - def add_file_to_dictionary(self, url, path, stat_cache=None): + def add_file_to_dictionary( + self, + url: str, + path: str, + stat_cache: dict[str, os.stat_result] | None = None, + ) -> None: if self.is_compressed_variant(path, stat_cache=stat_cache): return if self.index_file is not None and url.endswith("/" + self.index_file): @@ -130,26 +155,27 @@ def add_file_to_dictionary(self, url, path, stat_cache=None): static_file = self.get_static_file(path, url, stat_cache=stat_cache) self.files[url] = static_file - def find_file(self, url): + def find_file(self, url: str) -> Redirect | StaticFile | None: # Optimization: bail early if the URL can never match a file if self.index_file is None and url.endswith("/"): - return + return None if not self.url_is_canonical(url): - return + return None for path in self.candidate_paths_for_url(url): try: return self.find_file_at_path(path, url) except MissingFileError: pass + return None - def candidate_paths_for_url(self, url): + def candidate_paths_for_url(self, url: str) -> Generator[str, None, None]: for root, prefix in self.directories: if url.startswith(prefix): path = os.path.join(root, url[len(prefix) :]) if os.path.commonprefix((root, path)) == root: yield path - def find_file_at_path(self, path, url): + def find_file_at_path(self, path: str, url: str) -> Redirect | StaticFile: if self.is_compressed_variant(path): raise MissingFileError(path) @@ -171,7 +197,7 @@ def find_file_at_path(self, path, url): return self.get_static_file(path, url) @staticmethod - def url_is_canonical(url): + def url_is_canonical(url: str) -> bool: """ Check that the URL path is in canonical format i.e. has normalised slashes and no path traversal elements @@ -184,7 +210,9 @@ def url_is_canonical(url): return normalised == url @staticmethod - def is_compressed_variant(path, stat_cache=None): + def is_compressed_variant( + path: str, stat_cache: dict[str, os.stat_result] | None = None + ) -> bool: if path[-3:] in (".gz", ".br"): uncompressed_path = path[:-3] if stat_cache is None: @@ -193,7 +221,12 @@ def is_compressed_variant(path, stat_cache=None): return uncompressed_path in stat_cache return False - def get_static_file(self, path, url, stat_cache=None): + def get_static_file( + self, + path: str, + url: str, + stat_cache: dict[str, os.stat_result] | None = None, + ) -> StaticFile: # Optimization: bail early if file does not exist if stat_cache is None and not os.path.exists(path): raise MissingFileError(path) @@ -211,7 +244,7 @@ def get_static_file(self, path, url, stat_cache=None): encodings={"gzip": path + ".gz", "br": path + ".br"}, ) - def add_mime_headers(self, headers, path, url): + def add_mime_headers(self, headers: Headers, path: str, url: str) -> None: media_type = self.media_types.get_type(path) if media_type.startswith("text/"): params = {"charset": str(self.charset)} @@ -219,7 +252,7 @@ def add_mime_headers(self, headers, path, url): params = {} headers.add_header("Content-Type", str(media_type), **params) - def add_cache_headers(self, headers, path, url): + def add_cache_headers(self, headers: Headers, path: str, url: str) -> None: if self.immutable_file_test(path, url): headers["Cache-Control"] = "max-age={}, public, immutable".format( self.FOREVER @@ -227,20 +260,14 @@ def add_cache_headers(self, headers, path, url): elif self.max_age is not None: headers["Cache-Control"] = f"max-age={self.max_age}, public" - def immutable_file_test(self, path, url): - """ - This should be implemented by sub-classes (see e.g. WhiteNoiseMiddleware) - or by setting the `immutable_file_test` config option - """ - return False - - def redirect(self, from_url, to_url): + def redirect(self, from_url: str, to_url: str) -> Redirect: """ Return a relative 302 redirect We use relative redirects as we don't know the absolute URL the app is being hosted under """ + assert self.index_file is not None if to_url == from_url + "/": relative_url = from_url.split("/")[-1] + "/" elif from_url == to_url + self.index_file: @@ -254,7 +281,7 @@ def redirect(self, from_url, to_url): return Redirect(relative_url, headers=headers) -def scantree(root): +def scantree(root: str) -> Generator[tuple[str, os.stat_result], None, None]: """ Recurse the given directory yielding (pathname, os.stat(pathname)) pairs """ diff --git a/src/whitenoise/compat.py b/src/whitenoise/compat.py new file mode 100644 index 00000000..8f61e940 --- /dev/null +++ b/src/whitenoise/compat.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import sys + +if sys.version_info >= (3, 11): + from wsgiref.types import StartResponse + from wsgiref.types import WSGIApplication + from wsgiref.types import WSGIEnvironment +else: + from collections.abc import Callable, Iterable, Iterator + from types import TracebackType + from typing import Any, Dict, Protocol, Tuple, Type, Union + + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + + _ExcInfo: TypeAlias = Tuple[Type[BaseException], BaseException, TracebackType] + _OptExcInfo: TypeAlias = Union[_ExcInfo, Tuple[None, None, None]] + + class StartResponse(Protocol): + def __call__( + self, + __status: str, + __headers: list[tuple[str, str]], + __exc_info: _OptExcInfo | None = ..., + ) -> Callable[[bytes], object]: + ... + + WSGIEnvironment: TypeAlias = Dict[str, Any] + WSGIApplication: TypeAlias = Callable[ + [WSGIEnvironment, StartResponse], Iterable[bytes] + ] + + class InputStream(Protocol): + def read(self, __size: int = ...) -> bytes: + ... + + def readline(self, __size: int = ...) -> bytes: + ... + + def readlines(self, __hint: int = ...) -> list[bytes]: + ... + + def __iter__(self) -> Iterator[bytes]: + ... + + class ErrorStream(Protocol): + def flush(self) -> object: + ... + + def write(self, __s: str) -> object: + ... + + def writelines(self, __seq: list[str]) -> object: + ... + + class _Readable(Protocol): + def read(self, __size: int = ...) -> bytes: + ... + + # Optional: def close(self) -> object: ... + + class FileWrapper(Protocol): + def __call__( + self, __file: _Readable, __block_size: int = ... + ) -> Iterable[bytes]: + ... + + +__all__ = [ + "StartResponse", + "WSGIApplication", + "WSGIEnvironment", +] diff --git a/src/whitenoise/compress.py b/src/whitenoise/compress.py index 143e1e44..0cf7c20b 100644 --- a/src/whitenoise/compress.py +++ b/src/whitenoise/compress.py @@ -5,6 +5,10 @@ import os import re from io import BytesIO +from typing import Callable +from typing import Generator +from typing import Pattern +from typing import Sequence try: import brotli @@ -14,6 +18,10 @@ brotli_installed = False +def noop_log(message: str) -> None: + pass + + class Compressor: # Extensions that it's not worth trying to compress SKIP_COMPRESS_EXTENSIONS = ( @@ -52,8 +60,13 @@ class Compressor: ) def __init__( - self, extensions=None, use_gzip=True, use_brotli=True, log=print, quiet=False - ): + self, + extensions: Sequence[str] | None = None, + use_gzip: bool = True, + use_brotli: bool = True, + log: Callable[[str], None] = print, + quiet: bool = False, + ) -> None: if extensions is None: extensions = self.SKIP_COMPRESS_EXTENSIONS self.extension_re = self.get_extension_re(extensions) @@ -61,9 +74,11 @@ def __init__( self.use_brotli = use_brotli and brotli_installed if not quiet: self.log = log + else: + self.log = noop_log @staticmethod - def get_extension_re(extensions): + def get_extension_re(extensions: Sequence[str]) -> Pattern[str]: if not extensions: return re.compile("^$") else: @@ -71,13 +86,10 @@ def get_extension_re(extensions): r"\.({})$".format("|".join(map(re.escape, extensions))), re.IGNORECASE ) - def should_compress(self, filename): + def should_compress(self, filename: str) -> bool: return not self.extension_re.search(filename) - def log(self, message): - pass - - def compress(self, path): + def compress(self, path: str) -> Generator[str, None, None]: with open(path, "rb") as f: stat_result = os.fstat(f.fileno()) data = f.read() @@ -95,7 +107,7 @@ def compress(self, path): yield self.write_data(path, compressed, ".gz", stat_result) @staticmethod - def compress_gzip(data): + def compress_gzip(data: bytes) -> bytes: output = BytesIO() # Explicitly set mtime to 0 so gzip content is fully determined # by file content (0 = "no timestamp" according to gzip spec) @@ -106,10 +118,13 @@ def compress_gzip(data): return output.getvalue() @staticmethod - def compress_brotli(data): - return brotli.compress(data) + def compress_brotli(data: bytes) -> bytes: + result: bytes = brotli.compress(data) + return result - def is_compressed_effectively(self, encoding_name, path, orig_size, data): + def is_compressed_effectively( + self, encoding_name: str, path: str, orig_size: int, data: bytes + ) -> bool: compressed_size = len(data) if orig_size == 0: is_effective = False @@ -126,7 +141,9 @@ def is_compressed_effectively(self, encoding_name, path, orig_size, data): self.log(f"Skipping {path} ({encoding_name} compression not effective)") return is_effective - def write_data(self, path, data, suffix, stat_result): + def write_data( + self, path: str, data: bytes, suffix: str, stat_result: os.stat_result + ) -> str: filename = path + suffix with open(filename, "wb") as f: f.write(data) @@ -134,7 +151,7 @@ def write_data(self, path, data, suffix, stat_result): return filename -def main(argv=None): +def main(argv: Sequence[str] | None = None) -> int: parser = argparse.ArgumentParser( description="Search for all files inside *not* matching " " and produce compressed versions with " diff --git a/src/whitenoise/middleware.py b/src/whitenoise/middleware.py index 3f5a8091..e474e265 100644 --- a/src/whitenoise/middleware.py +++ b/src/whitenoise/middleware.py @@ -1,16 +1,23 @@ from __future__ import annotations import os +from io import BytesIO from posixpath import basename +from typing import Callable +from typing import Generator from urllib.parse import urlparse from django.conf import settings from django.contrib.staticfiles import finders from django.contrib.staticfiles.storage import staticfiles_storage from django.http import FileResponse +from django.http import HttpRequest +from django.http.response import HttpResponseBase from django.urls import get_script_prefix from .base import WhiteNoise +from .responders import Redirect +from .responders import StaticFile from .string_utils import ensure_leading_trailing_slash __all__ = ["WhiteNoiseMiddleware"] @@ -24,7 +31,7 @@ class WhiteNoiseFileResponse(FileResponse): are actively harmful. """ - def set_headers(self, *args, **kwargs): + def set_headers(self, filelike: BytesIO) -> None: pass @@ -34,7 +41,7 @@ class WhiteNoiseMiddleware(WhiteNoise): than WSGI middleware. """ - def __init__(self, get_response=None, settings=settings): + def __init__(self, get_response: Callable[[HttpRequest], HttpResponseBase]) -> None: self.get_response = get_response try: @@ -114,7 +121,7 @@ def __init__(self, get_response=None, settings=settings): if self.use_finders and not self.autorefresh: self.add_files_from_finders() - def __call__(self, request): + def __call__(self, request: HttpRequest) -> HttpResponseBase: if self.autorefresh: static_file = self.find_file(request.path_info) else: @@ -124,8 +131,10 @@ def __call__(self, request): return self.get_response(request) @staticmethod - def serve(static_file, request): - response = static_file.get_response(request.method, request.META) + def serve( + static_file: Redirect | StaticFile, request: HttpRequest + ) -> WhiteNoiseFileResponse: + response = static_file.get_response(request.method or "GET", request.META) status = int(response.status) http_response = WhiteNoiseFileResponse(response.file or (), status=status) # Remove default content-type @@ -134,8 +143,8 @@ def serve(static_file, request): http_response[key] = value return http_response - def add_files_from_finders(self): - files = {} + def add_files_from_finders(self) -> None: + files: dict[str, str] = {} for finder in finders.get_finders(): for path, storage in finder.list(None): prefix = (getattr(storage, "prefix", None) or "").strip("/") @@ -153,7 +162,7 @@ def add_files_from_finders(self): for url, path in files.items(): self.add_file_to_dictionary(url, path, stat_cache=stat_cache) - def candidate_paths_for_url(self, url): + def candidate_paths_for_url(self, url: str) -> Generator[str, None, None]: if self.use_finders and url.startswith(self.static_prefix): path = finders.find(url[len(self.static_prefix) :]) if path: @@ -162,7 +171,7 @@ def candidate_paths_for_url(self, url): for path in paths: yield path - def immutable_file_test(self, path, url): + def immutable_file_test(self, path: str, url: str) -> bool: """ Determine whether given URL represents an immutable file (i.e. a file with a hash of its contents as part of its name) which can @@ -182,7 +191,7 @@ def immutable_file_test(self, path, url): return True return False - def get_name_without_hash(self, filename): + def get_name_without_hash(self, filename: str) -> str: """ Removes the version hash from a filename e.g, transforms 'css/application.f3ea4bcc2.css' into 'css/application.css' @@ -195,7 +204,7 @@ def get_name_without_hash(self, filename): name = os.path.splitext(name_with_hash)[0] return name + ext - def get_static_url(self, name): + def get_static_url(self, name: str) -> str | None: try: return staticfiles_storage.url(name) except ValueError: diff --git a/src/whitenoise/responders.py b/src/whitenoise/responders.py index 9501ea65..8c40e9e4 100644 --- a/src/whitenoise/responders.py +++ b/src/whitenoise/responders.py @@ -9,6 +9,9 @@ from http import HTTPStatus from io import BufferedIOBase from time import mktime +from typing import Callable +from typing import Pattern +from typing import Sequence from urllib.parse import quote from wsgiref.headers import Headers @@ -16,7 +19,12 @@ class Response: __slots__ = ("status", "headers", "file") - def __init__(self, status, headers, file): + def __init__( + self, + status: HTTPStatus, + headers: Sequence[tuple[str, str]], + file: BufferedIOBase | None, + ) -> None: self.status = status self.headers = headers self.file = file @@ -47,15 +55,15 @@ class SlicedFile(BufferedIOBase): been reached. """ - def __init__(self, fileobj, start, end): + def __init__(self, fileobj: BufferedIOBase, start: int, end: int) -> None: fileobj.seek(start) self.fileobj = fileobj self.remaining = end - start + 1 - def read(self, size=-1): + def read(self, size: int | None = -1) -> bytes: if self.remaining <= 0: return b"" - if size < 0: + if size is None or size < 0: size = self.remaining else: size = min(size, self.remaining) @@ -63,20 +71,26 @@ def read(self, size=-1): self.remaining -= len(data) return data - def close(self): + def close(self) -> None: self.fileobj.close() class StaticFile: - def __init__(self, path, headers, encodings=None, stat_cache=None): + def __init__( + self, + path: str, + headers: list[tuple[str, str]], + encodings: dict[str, str] | None = None, + stat_cache: dict[str, os.stat_result] | None = None, + ) -> None: files = self.get_file_stats(path, encodings, stat_cache) - headers = self.get_headers(headers, files) - self.last_modified = parsedate(headers["Last-Modified"]) - self.etag = headers["ETag"] - self.not_modified_response = self.get_not_modified_response(headers) - self.alternatives = self.get_alternatives(headers, files) + parsed_headers = self.get_headers(headers, files) + self.last_modified = parsedate(parsed_headers["Last-Modified"]) + self.etag = parsed_headers["ETag"] + self.not_modified_response = self.get_not_modified_response(parsed_headers) + self.alternatives = self.get_alternatives(parsed_headers, files) - def get_response(self, method, request_headers): + def get_response(self, method: str, request_headers: dict[str, str]) -> Response: if method not in ("GET", "HEAD"): return NOT_ALLOWED_RESPONSE if self.is_not_modified(request_headers): @@ -87,7 +101,7 @@ def get_response(self, method, request_headers): else: file_handle = None range_header = request_headers.get("HTTP_RANGE") - if range_header: + if range_header is not None: try: return self.get_range_response(range_header, headers, file_handle) except ValueError: @@ -97,7 +111,12 @@ def get_response(self, method, request_headers): pass return Response(HTTPStatus.OK, headers, file_handle) - def get_range_response(self, range_header, base_headers, file_handle): + def get_range_response( + self, + range_header: str, + base_headers: list[tuple[str, str]], + file_handle: BufferedIOBase | None, + ) -> Response: headers = [] for item in base_headers: if item[0] == "Content-Length": @@ -113,7 +132,7 @@ def get_range_response(self, range_header, base_headers, file_handle): headers.append(("Content-Length", str(end - start + 1))) return Response(HTTPStatus.PARTIAL_CONTENT, headers, file_handle) - def get_byte_range(self, range_header, size): + def get_byte_range(self, range_header: str, size: int) -> tuple[int, int]: start, end = self.parse_byte_range(range_header) if start < 0: start = max(start + size, 0) @@ -124,7 +143,7 @@ def get_byte_range(self, range_header, size): return start, end @staticmethod - def parse_byte_range(range_header): + def parse_byte_range(range_header: str) -> tuple[int, int | None]: units, _, range_spec = range_header.strip().partition("=") if units != "bytes": raise ValueError() @@ -142,7 +161,9 @@ def parse_byte_range(range_header): return start, end @staticmethod - def get_range_not_satisfiable_response(file_handle, size): + def get_range_not_satisfiable_response( + file_handle: BufferedIOBase | None, size: int + ) -> Response: if file_handle is not None: file_handle.close() return Response( @@ -152,9 +173,13 @@ def get_range_not_satisfiable_response(file_handle, size): ) @staticmethod - def get_file_stats(path, encodings, stat_cache): + def get_file_stats( + path: str, + encodings: dict[str, str] | None, + stat_cache: dict[str, os.stat_result] | None, + ) -> dict[str | None, FileEntry]: # Primary file has an encoding of None - files = {None: FileEntry(path, stat_cache)} + files: dict[str | None, FileEntry] = {None: FileEntry(path, stat_cache)} if encodings: for encoding, alt_path in encodings.items(): try: @@ -163,7 +188,9 @@ def get_file_stats(path, encodings, stat_cache): continue return files - def get_headers(self, headers_list, files): + def get_headers( + self, headers_list: list[tuple[str, str]], files: dict[str | None, FileEntry] + ) -> Headers: headers = Headers(headers_list) main_file = files[None] if len(files) > 1: @@ -182,17 +209,21 @@ def get_headers(self, headers_list, files): return headers @staticmethod - def get_not_modified_response(headers): - not_modified_headers = [] + def get_not_modified_response(headers: Headers) -> Response: + not_modified_headers: list[tuple[str, str]] = [] for key in NOT_MODIFIED_HEADERS: if key in headers: - not_modified_headers.append((key, headers[key])) + value = headers[key] + assert value is not None + not_modified_headers.append((key, value)) return Response( status=HTTPStatus.NOT_MODIFIED, headers=not_modified_headers, file=None ) @staticmethod - def get_alternatives(base_headers, files): + def get_alternatives( + base_headers: Headers, files: dict[str | None, FileEntry] + ) -> list[tuple[Pattern[str], str, list[tuple[str, str]]]]: # Sort by size so that the smallest compressed alternative matches first alternatives = [] files_by_size = sorted(files.items(), key=lambda i: i[1].size) @@ -207,7 +238,7 @@ def get_alternatives(base_headers, files): alternatives.append((encoding_re, file_entry.path, headers.items())) return alternatives - def is_not_modified(self, request_headers): + def is_not_modified(self, request_headers: dict[str, str]) -> bool: previous_etag = request_headers.get("HTTP_IF_NONE_MATCH") if previous_etag is not None: return previous_etag == self.etag @@ -222,7 +253,9 @@ def is_not_modified(self, request_headers): return last_requested_ts >= self.last_modified return False - def get_path_and_headers(self, request_headers): + def get_path_and_headers( + self, request_headers: dict[str, str] + ) -> tuple[str, list[tuple[str, str]]]: accept_encoding = request_headers.get("HTTP_ACCEPT_ENCODING", "") if accept_encoding == "*": accept_encoding = "" @@ -230,15 +263,19 @@ def get_path_and_headers(self, request_headers): for encoding_re, path, headers in self.alternatives: if encoding_re.search(accept_encoding): return path, headers + raise AssertionError("Unreachable") class Redirect: - def __init__(self, location, headers=None): - headers = list(headers.items()) if headers else [] - headers.append(("Location", quote(location.encode("utf8")))) - self.response = Response(HTTPStatus.FOUND, headers, None) + def __init__(self, location: str, headers: dict[str, str] | None = None) -> None: + if headers is None: + header_list = [] + else: + header_list = list(headers.items()) + header_list.append(("Location", quote(location.encode("utf8")))) + self.response = Response(HTTPStatus.FOUND, header_list, None) - def get_response(self, method, request_headers): + def get_response(self, method: str, request_headers: dict[str, str]) -> Response: return self.response @@ -257,15 +294,23 @@ class IsDirectoryError(MissingFileError): class FileEntry: __slots__ = ("path", "size", "mtime") - def __init__(self, path, stat_cache=None): + def __init__( + self, path: str, stat_cache: dict[str, os.stat_result] | None = None + ) -> None: self.path = path - stat_function = os.stat if stat_cache is None else stat_cache.__getitem__ + stat_function: Callable[[str], os.stat_result] + if stat_cache is None: + stat_function = os.stat + else: + stat_function = stat_cache.__getitem__ stat = self.stat_regular_file(path, stat_function) self.size = stat.st_size self.mtime = stat.st_mtime @staticmethod - def stat_regular_file(path, stat_function): + def stat_regular_file( + path: str, stat_function: Callable[[str], os.stat_result] + ) -> os.stat_result: """ Wrap `stat_function` to raise appropriate errors if `path` is not a regular file diff --git a/src/whitenoise/runserver_nostatic/management/commands/runserver.py b/src/whitenoise/runserver_nostatic/management/commands/runserver.py index 484b1442..83f0f65f 100644 --- a/src/whitenoise/runserver_nostatic/management/commands/runserver.py +++ b/src/whitenoise/runserver_nostatic/management/commands/runserver.py @@ -8,24 +8,29 @@ """ from __future__ import annotations +import argparse from importlib import import_module +from typing import Generator from django.apps import apps +from django.core.management import BaseCommand -def get_next_runserver_command(): +def get_next_runserver_command() -> type[BaseCommand]: """ Return the next highest priority "runserver" command class """ for app_name in get_lower_priority_apps(): module_path = "%s.management.commands.runserver" % app_name try: - return import_module(module_path).Command + klass: type[BaseCommand] = import_module(module_path).Command + return klass except (ImportError, AttributeError): pass + raise ValueError("No lower priority app has a 'runserver' command") -def get_lower_priority_apps(): +def get_lower_priority_apps() -> Generator[str, None, None]: """ Yield all app module names below the current app in the INSTALLED_APPS list """ @@ -42,11 +47,12 @@ def get_lower_priority_apps(): RunserverCommand = get_next_runserver_command() -class Command(RunserverCommand): - def add_arguments(self, parser): +class Command(RunserverCommand): # type: ignore [misc,valid-type] + def add_arguments(self, parser: argparse.ArgumentParser) -> None: super().add_arguments(parser) if parser.get_default("use_static_handler") is True: parser.set_defaults(use_static_handler=False) + assert parser.description is not None parser.description += ( "\n(Wrapped by 'whitenoise.runserver_nostatic' to always" " enable '--nostatic')" diff --git a/src/whitenoise/storage.py b/src/whitenoise/storage.py index 029b0ff8..a46e209b 100644 --- a/src/whitenoise/storage.py +++ b/src/whitenoise/storage.py @@ -55,18 +55,19 @@ class CompressedManifestStaticFilesStorage(ManifestStaticFilesStorage): those without the hash in their name) """ - _new_files = None - - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: manifest_strict = getattr(settings, "WHITENOISE_MANIFEST_STRICT", None) if manifest_strict is not None: self.manifest_strict = manifest_strict super().__init__(*args, **kwargs) + self._new_files: set[str] | None = None - def post_process(self, *args, **kwargs): - files = super().post_process(*args, **kwargs) + def post_process( + self, paths: dict[str, Any], dry_run: bool = False, **options: Any + ) -> _PostProcessT: + files = super().post_process(paths, dry_run=dry_run, **options) - if not kwargs.get("dry_run"): + if not dry_run: files = self.post_process_with_compression(files) # Make exception messages helpful @@ -75,14 +76,14 @@ def post_process(self, *args, **kwargs): processed = self.make_helpful_exception(processed, name) yield name, hashed_name, processed - def post_process_with_compression(self, files): + def post_process_with_compression(self, files: _PostProcessT) -> _PostProcessT: # Files may get hashed multiple times, we want to keep track of all the # intermediate files generated during the process and which of these # are the final names used for each file. As not every intermediate # file is yielded we have to hook in to the `hashed_name` method to # keep track of them all. hashed_names = {} - new_files = set() + new_files: set[str] = set() self.start_tracking_new_files(new_files) for name, hashed_name, processed in files: if hashed_name and not isinstance(processed, Exception): @@ -107,17 +108,17 @@ def hashed_name(self, *args, **kwargs): self._new_files.add(self.clean_name(name)) return name - def start_tracking_new_files(self, new_files): + def start_tracking_new_files(self, new_files: set[str]) -> None: self._new_files = new_files - def stop_tracking_new_files(self): + def stop_tracking_new_files(self) -> None: self._new_files = None @property - def keep_only_hashed_files(self): + def keep_only_hashed_files(self) -> bool: return getattr(settings, "WHITENOISE_KEEP_ONLY_HASHED_FILES", False) - def delete_files(self, files_to_delete): + def delete_files(self, files_to_delete) -> None: for name in files_to_delete: try: os.unlink(self.path(name)) @@ -125,7 +126,7 @@ def delete_files(self, files_to_delete): if e.errno != errno.ENOENT: raise - def create_compressor(self, **kwargs): + def create_compressor(self, **kwargs) -> Compressor: return Compressor(**kwargs) def compress_files(self, names): @@ -139,7 +140,7 @@ def compress_files(self, names): compressed_name = compressed_path[prefix_len:] yield name, compressed_name - def make_helpful_exception(self, exception, name): + def make_helpful_exception(self, exception: Exception, name: str) -> Exception: """ If a CSS file contains references to images, fonts etc that can't be found then Django's `post_process` blows up with a not particularly helpful diff --git a/src/whitenoise/string_utils.py b/src/whitenoise/string_utils.py index 6be90620..f26df1cf 100644 --- a/src/whitenoise/string_utils.py +++ b/src/whitenoise/string_utils.py @@ -4,10 +4,10 @@ # Follow Django in treating URLs as UTF-8 encoded (which requires undoing the # implicit ISO-8859-1 decoding applied in Python 3). Strictly speaking, URLs # should only be ASCII anyway, but UTF-8 can be found in the wild. -def decode_path_info(path_info): +def decode_path_info(path_info: str) -> str: return path_info.encode("iso-8859-1", "replace").decode("utf-8", "replace") -def ensure_leading_trailing_slash(path): +def ensure_leading_trailing_slash(path: str | None) -> str: path = (path or "").strip("/") return f"/{path}/" if path else "/" diff --git a/tests/django_urls.py b/tests/django_urls.py index 562b8d6f..a39253fb 100644 --- a/tests/django_urls.py +++ b/tests/django_urls.py @@ -1,3 +1,5 @@ from __future__ import annotations -urlpatterns = [] +from django.urls import URLPattern + +urlpatterns: list[URLPattern] = [] diff --git a/tests/test_compress.py b/tests/test_compress.py index 11ea9471..b483c1df 100644 --- a/tests/test_compress.py +++ b/tests/test_compress.py @@ -70,8 +70,12 @@ def test_with_falsey_extensions(): def test_custom_log(): - compressor = Compressor(log="test") - assert compressor.log == "test" + def logit(message: str) -> None: + pass + + compressor = Compressor(log=logit) + + assert compressor.log is logit def test_compress(): @@ -82,5 +86,5 @@ def test_compress(): def test_compressed_effectively_no_orig_size(): compressor = Compressor(quiet=True) assert not compressor.is_compressed_effectively( - "test_encoding", "test_path", 0, "test_data" + "test_encoding", "test_path", 0, b"test_data" ) diff --git a/tests/test_django_whitenoise.py b/tests/test_django_whitenoise.py index 2654424f..8c5ab108 100644 --- a/tests/test_django_whitenoise.py +++ b/tests/test_django_whitenoise.py @@ -3,6 +3,7 @@ import shutil import tempfile from contextlib import closing +from typing import Any from urllib.parse import urljoin from urllib.parse import urlparse @@ -21,11 +22,11 @@ from whitenoise.middleware import WhiteNoiseMiddleware -def reset_lazy_object(obj): +def reset_lazy_object(obj: Any) -> None: obj._wrapped = empty -def get_url_path(base, url): +def get_url_path(base: str, url: str) -> str: return urlparse(urljoin(base, url)).path @@ -130,7 +131,8 @@ def finder_static_files(request): WHITENOISE_INDEX_FILE=True, STATIC_ROOT=None, ): - finders.get_finder.cache_clear() + # django-stubs doesn’t mark get_finder() as @lru_cache + finders.get_finder.cache_clear() # type: ignore [attr-defined] yield files diff --git a/tests/test_runserver_nostatic.py b/tests/test_runserver_nostatic.py index 65ed458b..04f026ff 100644 --- a/tests/test_runserver_nostatic.py +++ b/tests/test_runserver_nostatic.py @@ -1,10 +1,11 @@ from __future__ import annotations +from django.core.management import BaseCommand from django.core.management import get_commands from django.core.management import load_command_class -def get_command_instance(name): +def get_command_instance(name: str) -> BaseCommand: app_name = get_commands()[name] return load_command_class(app_name, name) diff --git a/tests/test_storage.py b/tests/test_storage.py index fca9a60f..8a0f74c0 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -5,6 +5,7 @@ import shutil import tempfile from posixpath import basename +from typing import Any import django import pytest @@ -22,7 +23,8 @@ @pytest.fixture() def setup(): - staticfiles_storage._wrapped = empty + # stubs don't mark staticfiles_storage as a lazy object with _wrapped + staticfiles_storage._wrapped = empty # type: ignore [attr-defined] files = Files("static") tmp = tempfile.mkdtemp() with override_settings( @@ -30,7 +32,7 @@ def setup(): STATIC_ROOT=tmp, ): yield settings - staticfiles_storage._wrapped = empty + staticfiles_storage._wrapped = empty # type: ignore [attr-defined] shutil.rmtree(tmp) @@ -38,16 +40,16 @@ def setup(): def _compressed_storage(setup): backend = "whitenoise.storage.CompressedStaticFilesStorage" if django.VERSION >= (4, 2): - storages = { + overrides: dict[str, Any] = { "STORAGES": { **settings.STORAGES, "staticfiles": {"BACKEND": backend}, } } else: - storages = {"STATICFILES_STORAGE": backend} + overrides = {"STATICFILES_STORAGE": backend} - with override_settings(**storages): + with override_settings(**overrides): yield @@ -55,16 +57,16 @@ def _compressed_storage(setup): def _compressed_manifest_storage(setup): backend = "whitenoise.storage.CompressedManifestStaticFilesStorage" if django.VERSION >= (4, 2): - storages = { + overrides: dict[str, Any] = { "STORAGES": { **settings.STORAGES, "staticfiles": {"BACKEND": backend}, } } else: - storages = {"STATICFILES_STORAGE": backend} + overrides = {"STATICFILES_STORAGE": backend} - with override_settings(**storages, WHITENOISE_KEEP_ONLY_HASHED_FILES=True): + with override_settings(**overrides, WHITENOISE_KEEP_ONLY_HASHED_FILES=True): call_command("collectstatic", verbosity=0, interactive=False) @@ -89,7 +91,7 @@ class TriggerException(HashedFilesMixin): def exists(self, path): return False - exception = None + exception: ValueError try: TriggerException().hashed_name("/missing/file.png") except ValueError as e: diff --git a/tests/test_whitenoise.py b/tests/test_whitenoise.py index 3e98f030..661cba10 100644 --- a/tests/test_whitenoise.py +++ b/tests/test_whitenoise.py @@ -8,14 +8,17 @@ import tempfile import warnings from contextlib import closing +from typing import Any from urllib.parse import urljoin from wsgiref.headers import Headers from wsgiref.simple_server import demo_app import pytest +import requests from .utils import AppServer from .utils import Files +from .utils import hello_world_app from whitenoise import WhiteNoise from whitenoise.responders import StaticFile @@ -48,7 +51,7 @@ def application(request, files): yield _init_application(files.directory) -def _init_application(directory, **kwargs): +def _init_application(directory: str, **kwargs: Any) -> WhiteNoise: def custom_headers(headers, path, url): if url.endswith(".css"): headers["X-Is-Css-File"] = "True" @@ -71,7 +74,7 @@ def server(application): yield app_server -def assert_is_default_response(response): +def assert_is_default_response(response: requests.Response) -> None: assert "Hello world!" in response.text @@ -310,7 +313,7 @@ def test_no_error_on_very_long_filename(server): assert response.status_code != 500 -def copytree(src, dst): +def copytree(src: str, dst: str) -> None: for name in os.listdir(src): src_path = os.path.join(src, name) dst_path = os.path.join(dst, name) @@ -321,7 +324,7 @@ def copytree(src, dst): def test_immutable_file_test_accepts_regex(): - instance = WhiteNoise(None, immutable_file_test=r"\.test$") + instance = WhiteNoise(hello_world_app, immutable_file_test=r"\.test$") assert instance.immutable_file_test("", "/myfile.test") assert not instance.immutable_file_test("", "file.test.txt") @@ -332,7 +335,7 @@ def test_directory_path_can_be_pathlib_instance(): root = Path(Files("root").directory) # Check we can construct instance without it blowing up - WhiteNoise(None, root=root, autorefresh=True) + WhiteNoise(hello_world_app, root=root, autorefresh=True) def fake_stat_entry( @@ -358,6 +361,7 @@ def test_last_modified_not_set_when_mtime_is_zero(): stat_cache = {__file__: fake_stat_entry()} responder = StaticFile(__file__, [], stat_cache=stat_cache) response = responder.get_response("GET", {}) + assert response.file is not None response.file.close() headers_dict = Headers(response.headers) assert "Last-Modified" not in headers_dict @@ -368,6 +372,7 @@ def test_file_size_matches_range_with_range_header(): stat_cache = {__file__: fake_stat_entry()} responder = StaticFile(__file__, [], stat_cache=stat_cache) response = responder.get_response("GET", {"HTTP_RANGE": "bytes=0-13"}) + assert response.file is not None file_size = len(response.file.read()) assert file_size == 14 diff --git a/tests/utils.py b/tests/utils.py index 0db45d75..76c1c1bf 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,17 +2,23 @@ import os import threading +from typing import Any +from typing import Iterable from wsgiref.simple_server import make_server from wsgiref.simple_server import WSGIRequestHandler from wsgiref.util import shift_path_info import requests +from whitenoise.compat import StartResponse +from whitenoise.compat import WSGIApplication +from whitenoise.compat import WSGIEnvironment + TEST_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_files") class SilentWSGIHandler(WSGIRequestHandler): - def log_message(*args): + def log_message(self, format: str, *args: Any) -> None: pass @@ -24,13 +30,15 @@ class AppServer: PREFIX = "subdir" - def __init__(self, application): + def __init__(self, application: WSGIApplication) -> None: self.application = application self.server = make_server( "127.0.0.1", 0, self.serve_under_prefix, handler_class=SilentWSGIHandler ) - def serve_under_prefix(self, environ, start_response): + def serve_under_prefix( + self, environ: WSGIEnvironment, start_response: StartResponse + ) -> Iterable[bytes]: prefix = shift_path_info(environ) if prefix != self.PREFIX: start_response("404 Not Found", []) @@ -38,10 +46,12 @@ def serve_under_prefix(self, environ, start_response): else: return self.application(environ, start_response) - def get(self, *args, **kwargs): + def get(self, *args: Any, **kwargs: Any) -> requests.Response: return self.request("get", *args, **kwargs) - def request(self, method, path, *args, **kwargs): + def request( + self, method: str, path: str, *args: Any, **kwargs: Any + ) -> requests.Response: url = "http://{0[0]}:{0[1]}{1}".format(self.server.server_address, path) thread = threading.Thread(target=self.server.handle_request) thread.start() @@ -49,12 +59,12 @@ def request(self, method, path, *args, **kwargs): thread.join() return response - def close(self): + def close(self) -> None: self.server.server_close() class Files: - def __init__(self, directory, **files): + def __init__(self, directory: str, **files: str) -> None: self.directory = os.path.join(TEST_FILE_PATH, directory) for name, path in files.items(): url = f"/{AppServer.PREFIX}/{path}" @@ -63,3 +73,10 @@ def __init__(self, directory, **files): setattr(self, name + "_path", path) setattr(self, name + "_url", url) setattr(self, name + "_content", content) + + +def hello_world_app( + environ: WSGIEnvironment, start_response: StartResponse +) -> Iterable[bytes]: + start_response("200 OK", [("Content-type", "text/plain; charset=utf-8")]) + return [b"Hello World"]