diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 299f0d36..884ebbea 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -47,3 +47,11 @@ repos: - flake8-comprehensions - flake8-tidy-imports - flake8-typing-imports +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.990 + hooks: + - id: mypy + additional_dependencies: + - django-stubs==1.12.0 + - requests + - types-requests diff --git a/pyproject.toml b/pyproject.toml index 99efa39d..16d0e470 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,21 @@ build-backend = "setuptools.build_meta" [tool.black] target-version = ['py37'] +[tool.django-stubs] +django_settings_module = "tests.django_settings" + +[tool.mypy] +mypy_path = "src/" +namespace_packages = false +plugins = ["mypy_django_plugin.main"] +show_error_codes = true +strict = true +warn_unreachable = true + +[[tool.mypy.overrides]] +module = "tests.*" +allow_untyped_defs = true + [tool.pytest.ini_options] addopts = """\ --strict-config diff --git a/requirements/py37-django32.txt b/requirements/py37-django32.txt index a5ed4a0f..ff689ec5 100644 --- a/requirements/py37-django32.txt +++ b/requirements/py37-django32.txt @@ -49,8 +49,9 @@ sqlparse==0.4.3 # via django tomli==2.0.1 # via pytest -typing-extensions==4.4.0 +typing-extensions==4.4.0 ; python_version < "3.10" # via + # -r requirements.in # asgiref # importlib-metadata urllib3==1.26.12 diff --git a/requirements/py38-django32.txt b/requirements/py38-django32.txt index 1bd9208c..0d22c632 100644 --- a/requirements/py38-django32.txt +++ b/requirements/py38-django32.txt @@ -46,6 +46,8 @@ sqlparse==0.4.3 # via django tomli==2.0.1 # via pytest +typing-extensions==4.4.0 ; python_version < "3.10" + # via -r requirements.in urllib3==1.26.12 # via requests zipp==3.10.0 diff --git a/requirements/py38-django40.txt b/requirements/py38-django40.txt index d3a214a6..8a6474c7 100644 --- a/requirements/py38-django40.txt +++ b/requirements/py38-django40.txt @@ -46,6 +46,8 @@ sqlparse==0.4.3 # via django tomli==2.0.1 # via pytest +typing-extensions==4.4.0 ; python_version < "3.10" + # via -r requirements.in urllib3==1.26.12 # via requests zipp==3.10.0 diff --git a/requirements/py38-django41.txt b/requirements/py38-django41.txt index 9d63e4ae..e1f37316 100644 --- a/requirements/py38-django41.txt +++ b/requirements/py38-django41.txt @@ -46,6 +46,8 @@ sqlparse==0.4.3 # via django tomli==2.0.1 # via pytest +typing-extensions==4.4.0 ; python_version < "3.10" + # via -r requirements.in urllib3==1.26.12 # via requests zipp==3.10.0 diff --git a/requirements/py39-django32.txt b/requirements/py39-django32.txt index d5b83594..ffc6f5e8 100644 --- a/requirements/py39-django32.txt +++ b/requirements/py39-django32.txt @@ -46,6 +46,8 @@ sqlparse==0.4.3 # via django tomli==2.0.1 # via pytest +typing-extensions==4.4.0 ; python_version < "3.10" + # via -r requirements.in urllib3==1.26.12 # via requests zipp==3.10.0 diff --git a/requirements/py39-django40.txt b/requirements/py39-django40.txt index fd93fb9a..6f6f0287 100644 --- a/requirements/py39-django40.txt +++ b/requirements/py39-django40.txt @@ -44,6 +44,8 @@ sqlparse==0.4.3 # via django tomli==2.0.1 # via pytest +typing-extensions==4.4.0 ; python_version < "3.10" + # via -r requirements.in urllib3==1.26.12 # via requests zipp==3.10.0 diff --git a/requirements/py39-django41.txt b/requirements/py39-django41.txt index 8ee7a081..147a7621 100644 --- a/requirements/py39-django41.txt +++ b/requirements/py39-django41.txt @@ -44,6 +44,8 @@ sqlparse==0.4.3 # via django tomli==2.0.1 # via pytest +typing-extensions==4.4.0 ; python_version < "3.10" + # via -r requirements.in urllib3==1.26.12 # via requests zipp==3.10.0 diff --git a/requirements/requirements.in b/requirements/requirements.in index 11b521d4..e97f34ed 100644 --- a/requirements/requirements.in +++ b/requirements/requirements.in @@ -4,3 +4,4 @@ django pytest pytest-randomly requests +typing-extensions ; python_version < "3.10" diff --git a/setup.cfg b/setup.cfg index f08093fb..375acd86 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,6 +36,8 @@ package_dir= =src packages = find: include_package_data = True +install_requires = + typing-extensions ; python_version < "3.10" python_requires = >=3.7 zip_safe = False diff --git a/src/whitenoise/base.py b/src/whitenoise/base.py index d9063fa9..81639a65 100644 --- a/src/whitenoise/base.py +++ b/src/whitenoise/base.py @@ -4,10 +4,11 @@ import re import warnings from posixpath import normpath -from typing import Callable +from typing import Callable, Generator, Iterable from wsgiref.headers import Headers from wsgiref.util import FileWrapper +from .compat import StartResponse, WSGIApplication, WSGIEnvironment from .media_types import MediaTypes from .responders import IsDirectoryError from .responders import MissingFileError @@ -26,9 +27,9 @@ class WhiteNoise: def __init__( self, - application, - root=None, - prefix=None, + application: WSGIApplication, + root: str | os.PathLike[str] | None = None, + prefix: str | None = None, *, # Re-check the filesystem on every request so that any changes are # automatically picked up. NOTE: For use in development only, not supported @@ -45,7 +46,7 @@ def __init__( mimetypes: dict[str, str] | None = None, add_headers_function: Callable[[Headers, str, str], None] | None = None, index_file: str | bool | None = None, - immutable_file_test: Callable | str | None = None, + immutable_file_test: Callable[[str, str], bool] | str | None = None, ): self.autorefresh = autorefresh self.max_age = max_age @@ -68,12 +69,14 @@ def __init__( self.media_types = MediaTypes(extra_types=mimetypes) self.application = application - self.files = {} - self.directories = [] + self.files: dict[str, Redirect | StaticFile] = {} + self.directories: list[tuple[str, str]] = [] if root is not None: self.add_files(root, prefix) - def __call__(self, environ, start_response): + def __call__( + self, environ: WSGIEnvironment, start_response: StartResponse + ) -> Iterable[bytes]: path = decode_path_info(environ.get("PATH_INFO", "")) if self.autorefresh: static_file = self.find_file(path) @@ -85,17 +88,23 @@ def __call__(self, environ, start_response): return self.serve(static_file, environ, start_response) @staticmethod - def serve(static_file, environ, start_response): + def serve( + static_file: Redirect | StaticFile, + environ: WSGIEnvironment, + start_response: StartResponse, + ) -> Iterable[bytes]: response = static_file.get_response(environ["REQUEST_METHOD"], environ) status_line = f"{response.status} {response.status.phrase}" start_response(status_line, list(response.headers)) if response.file is not None: - file_wrapper = environ.get("wsgi.file_wrapper", FileWrapper) + file_wrapper: type[FileWrapper] = environ.get( + "wsgi.file_wrapper", FileWrapper + ) return file_wrapper(response.file) else: return [] - def add_files(self, root, prefix=None): + def add_files(self, root: str, prefix: str | None = None) -> None: root = decode_if_byte_string(root, force_text=True) root = os.path.abspath(root) root = root.rstrip(os.path.sep) + os.path.sep @@ -112,7 +121,7 @@ def add_files(self, root, prefix=None): else: warnings.warn(f"No directory at: {root}") - def update_files_dictionary(self, root, prefix): + def update_files_dictionary(self, root: str, prefix: str) -> None: # Build a mapping from paths to the results of `os.stat` calls # so we only have to touch the filesystem once stat_cache = dict(scantree(root)) @@ -122,7 +131,12 @@ def update_files_dictionary(self, root, prefix): url = prefix + relative_url self.add_file_to_dictionary(url, path, stat_cache=stat_cache) - def add_file_to_dictionary(self, url, path, stat_cache=None): + def add_file_to_dictionary( + self, + url: str, + path: str, + stat_cache: dict[str, os.stat_result] | None = None, + ) -> None: if self.is_compressed_variant(path, stat_cache=stat_cache): return if self.index_file is not None and url.endswith("/" + self.index_file): @@ -134,26 +148,27 @@ def add_file_to_dictionary(self, url, path, stat_cache=None): static_file = self.get_static_file(path, url, stat_cache=stat_cache) self.files[url] = static_file - def find_file(self, url): + def find_file(self, url: str) -> Redirect | StaticFile | None: # Optimization: bail early if the URL can never match a file if self.index_file is None and url.endswith("/"): return if not self.url_is_canonical(url): - return + return None for path in self.candidate_paths_for_url(url): try: return self.find_file_at_path(path, url) except MissingFileError: pass + return None - def candidate_paths_for_url(self, url): + def candidate_paths_for_url(self, url: str) -> Generator[str, None, None]: for root, prefix in self.directories: if url.startswith(prefix): path = os.path.join(root, url[len(prefix) :]) if os.path.commonprefix((root, path)) == root: yield path - def find_file_at_path(self, path, url): + def find_file_at_path(self, path: str, url: str) -> Redirect | StaticFile: if self.is_compressed_variant(path): raise MissingFileError(path) @@ -175,7 +190,7 @@ def find_file_at_path(self, path, url): return self.get_static_file(path, url) @staticmethod - def url_is_canonical(url): + def url_is_canonical(url: str) -> bool: """ Check that the URL path is in canonical format i.e. has normalised slashes and no path traversal elements @@ -188,7 +203,9 @@ def url_is_canonical(url): return normalised == url @staticmethod - def is_compressed_variant(path, stat_cache=None): + def is_compressed_variant( + path: str, stat_cache: dict[str, os.stat_result] | None = None + ) -> bool: if path[-3:] in (".gz", ".br"): uncompressed_path = path[:-3] if stat_cache is None: @@ -197,7 +214,12 @@ def is_compressed_variant(path, stat_cache=None): return uncompressed_path in stat_cache return False - def get_static_file(self, path, url, stat_cache=None): + def get_static_file( + self, + path: str, + url: str, + stat_cache: dict[str, os.stat_result] | None = None, + ) -> StaticFile: # Optimization: bail early if file does not exist if stat_cache is None and not os.path.exists(path): raise MissingFileError(path) @@ -215,7 +237,7 @@ def get_static_file(self, path, url, stat_cache=None): encodings={"gzip": path + ".gz", "br": path + ".br"}, ) - def add_mime_headers(self, headers, path, url): + def add_mime_headers(self, headers: Headers, path: str, url: str) -> None: media_type = self.media_types.get_type(path) if media_type.startswith("text/"): params = {"charset": str(self.charset)} @@ -223,7 +245,7 @@ def add_mime_headers(self, headers, path, url): params = {} headers.add_header("Content-Type", str(media_type), **params) - def add_cache_headers(self, headers, path, url): + def add_cache_headers(self, headers: Headers, path: str, url: str) -> None: if self.immutable_file_test(path, url): headers["Cache-Control"] = "max-age={}, public, immutable".format( self.FOREVER @@ -231,14 +253,14 @@ def add_cache_headers(self, headers, path, url): elif self.max_age is not None: headers["Cache-Control"] = f"max-age={self.max_age}, public" - def immutable_file_test(self, path, url): + def immutable_file_test(self, path: str, url: str) -> bool: """ This should be implemented by sub-classes (see e.g. WhiteNoiseMiddleware) or by setting the `immutable_file_test` config option """ return False - def redirect(self, from_url, to_url): + def redirect(self, from_url: str, to_url: str) -> Redirect: """ Return a relative 302 redirect @@ -258,7 +280,7 @@ def redirect(self, from_url, to_url): return Redirect(relative_url, headers=headers) -def scantree(root): +def scantree(root: str) -> Generator[tuple[str, os.stat_result], None, None]: """ Recurse the given directory yielding (pathname, os.stat(pathname)) pairs """ diff --git a/src/whitenoise/compat.py b/src/whitenoise/compat.py new file mode 100644 index 00000000..9c8f736c --- /dev/null +++ b/src/whitenoise/compat.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import sys + +if sys.version_info >= (3, 11): + from wsgiref.types import WSGIApplication +else: + from collections.abc import Callable, Iterable, Iterator + from types import TracebackType + from typing import Any, Dict, Protocol, Tuple, Type, Union + + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + + __all__ = [ + "StartResponse", + "WSGIEnvironment", + "WSGIApplication", + "InputStream", + "ErrorStream", + "FileWrapper", + ] + + _ExcInfo: TypeAlias = Tuple[Type[BaseException], BaseException, TracebackType] + _OptExcInfo: TypeAlias = Union[_ExcInfo, Tuple[None, None, None]] + + class StartResponse(Protocol): + def __call__( + self, + __status: str, + __headers: list[tuple[str, str]], + __exc_info: _OptExcInfo | None = ..., + ) -> Callable[[bytes], object]: + ... + + WSGIEnvironment: TypeAlias = Dict[str, Any] + WSGIApplication: TypeAlias = Callable[ + [WSGIEnvironment, StartResponse], Iterable[bytes] + ] + + class InputStream(Protocol): + def read(self, __size: int = ...) -> bytes: + ... + + def readline(self, __size: int = ...) -> bytes: + ... + + def readlines(self, __hint: int = ...) -> list[bytes]: + ... + + def __iter__(self) -> Iterator[bytes]: + ... + + class ErrorStream(Protocol): + def flush(self) -> object: + ... + + def write(self, __s: str) -> object: + ... + + def writelines(self, __seq: list[str]) -> object: + ... + + class _Readable(Protocol): + def read(self, __size: int = ...) -> bytes: + ... + + # Optional: def close(self) -> object: ... + + class FileWrapper(Protocol): + def __call__( + self, __file: _Readable, __block_size: int = ... + ) -> Iterable[bytes]: + ... diff --git a/src/whitenoise/compress.py b/src/whitenoise/compress.py index 74b98ec5..5efba4f5 100644 --- a/src/whitenoise/compress.py +++ b/src/whitenoise/compress.py @@ -5,6 +5,7 @@ import os import re from io import BytesIO +from typing import Callable, Generator, Pattern, Sequence try: import brotli @@ -14,6 +15,10 @@ brotli_installed = False +def noop_log(message: str) -> None: + pass + + class Compressor: # Extensions that it's not worth trying to compress @@ -53,8 +58,13 @@ class Compressor: ) def __init__( - self, extensions=None, use_gzip=True, use_brotli=True, log=print, quiet=False - ): + self, + extensions: Sequence[str] | None = None, + use_gzip: bool = True, + use_brotli: bool = True, + log: Callable[[str], None] = print, + quiet: bool = False, + ) -> None: if extensions is None: extensions = self.SKIP_COMPRESS_EXTENSIONS self.extension_re = self.get_extension_re(extensions) @@ -62,9 +72,11 @@ def __init__( self.use_brotli = use_brotli and brotli_installed if not quiet: self.log = log + else: + self.log = noop_log @staticmethod - def get_extension_re(extensions): + def get_extension_re(extensions: Sequence[str]) -> Pattern[str]: if not extensions: return re.compile("^$") else: @@ -72,13 +84,10 @@ def get_extension_re(extensions): r"\.({})$".format("|".join(map(re.escape, extensions))), re.IGNORECASE ) - def should_compress(self, filename): + def should_compress(self, filename: str) -> bool: return not self.extension_re.search(filename) - def log(self, message): - pass - - def compress(self, path): + def compress(self, path: str) -> Generator[str, None, None]: with open(path, "rb") as f: stat_result = os.fstat(f.fileno()) data = f.read() @@ -96,7 +105,7 @@ def compress(self, path): yield self.write_data(path, compressed, ".gz", stat_result) @staticmethod - def compress_gzip(data): + def compress_gzip(data: bytes) -> bytes: output = BytesIO() # Explicitly set mtime to 0 so gzip content is fully determined # by file content (0 = "no timestamp" according to gzip spec) @@ -107,10 +116,13 @@ def compress_gzip(data): return output.getvalue() @staticmethod - def compress_brotli(data): - return brotli.compress(data) + def compress_brotli(data: bytes) -> bytes: + result: bytes = brotli.compress(data) + return result - def is_compressed_effectively(self, encoding_name, path, orig_size, data): + def is_compressed_effectively( + self, encoding_name: str, path: str, orig_size: int, data: bytes + ) -> bool: compressed_size = len(data) if orig_size == 0: is_effective = False @@ -127,7 +139,9 @@ def is_compressed_effectively(self, encoding_name, path, orig_size, data): self.log(f"Skipping {path} ({encoding_name} compression not effective)") return is_effective - def write_data(self, path, data, suffix, stat_result): + def write_data( + self, path: str, data: bytes, suffix: str, stat_result: os.stat_result + ) -> str: filename = path + suffix with open(filename, "wb") as f: f.write(data) @@ -135,7 +149,7 @@ def write_data(self, path, data, suffix, stat_result): return filename -def main(argv=None): +def main(argv: Sequence[str] | None = None) -> int: parser = argparse.ArgumentParser( description="Search for all files inside *not* matching " " and produce compressed versions with " diff --git a/src/whitenoise/middleware.py b/src/whitenoise/middleware.py index 23eb4382..402b31e4 100644 --- a/src/whitenoise/middleware.py +++ b/src/whitenoise/middleware.py @@ -135,7 +135,7 @@ def serve(static_file, request): http_response[key] = value return http_response - def add_files_from_finders(self): + def add_files_from_finders(self) -> None: files = {} for finder in finders.get_finders(): for path, storage in finder.list(None): @@ -163,7 +163,7 @@ def candidate_paths_for_url(self, url): for path in paths: yield path - def immutable_file_test(self, path, url): + def immutable_file_test(self, path, url) -> bool: """ Determine whether given URL represents an immutable file (i.e. a file with a hash of its contents as part of its name) which can diff --git a/src/whitenoise/responders.py b/src/whitenoise/responders.py index 9501ea65..af5c2f53 100644 --- a/src/whitenoise/responders.py +++ b/src/whitenoise/responders.py @@ -9,6 +9,7 @@ from http import HTTPStatus from io import BufferedIOBase from time import mktime +from typing import Callable, Pattern, Sequence from urllib.parse import quote from wsgiref.headers import Headers @@ -16,7 +17,12 @@ class Response: __slots__ = ("status", "headers", "file") - def __init__(self, status, headers, file): + def __init__( + self, + status: HTTPStatus, + headers: Sequence[tuple[str, str]], + file: BufferedIOBase | None, + ) -> None: self.status = status self.headers = headers self.file = file @@ -47,15 +53,15 @@ class SlicedFile(BufferedIOBase): been reached. """ - def __init__(self, fileobj, start, end): + def __init__(self, fileobj: BufferedIOBase, start: int, end: int) -> None: fileobj.seek(start) self.fileobj = fileobj self.remaining = end - start + 1 - def read(self, size=-1): + def read(self, size: int | None = -1) -> bytes: if self.remaining <= 0: return b"" - if size < 0: + if size is None or size < 0: size = self.remaining else: size = min(size, self.remaining) @@ -63,20 +69,26 @@ def read(self, size=-1): self.remaining -= len(data) return data - def close(self): + def close(self) -> None: self.fileobj.close() class StaticFile: - def __init__(self, path, headers, encodings=None, stat_cache=None): + def __init__( + self, + path: str, + headers: list[tuple[str, str]], + encodings: dict[str, str] | None = None, + stat_cache: dict[str, os.stat_result] | None = None, + ) -> None: files = self.get_file_stats(path, encodings, stat_cache) - headers = self.get_headers(headers, files) - self.last_modified = parsedate(headers["Last-Modified"]) - self.etag = headers["ETag"] - self.not_modified_response = self.get_not_modified_response(headers) - self.alternatives = self.get_alternatives(headers, files) + parsed_headers = self.get_headers(headers, files) + self.last_modified = parsedate(parsed_headers["Last-Modified"]) + self.etag = parsed_headers["ETag"] + self.not_modified_response = self.get_not_modified_response(parsed_headers) + self.alternatives = self.get_alternatives(parsed_headers, files) - def get_response(self, method, request_headers): + def get_response(self, method: str, request_headers: dict[str, str]) -> Response: if method not in ("GET", "HEAD"): return NOT_ALLOWED_RESPONSE if self.is_not_modified(request_headers): @@ -87,7 +99,7 @@ def get_response(self, method, request_headers): else: file_handle = None range_header = request_headers.get("HTTP_RANGE") - if range_header: + if range_header is not None: try: return self.get_range_response(range_header, headers, file_handle) except ValueError: @@ -97,7 +109,12 @@ def get_response(self, method, request_headers): pass return Response(HTTPStatus.OK, headers, file_handle) - def get_range_response(self, range_header, base_headers, file_handle): + def get_range_response( + self, + range_header: str, + base_headers: list[tuple[str, str]], + file_handle: BufferedIOBase | None, + ) -> Response: headers = [] for item in base_headers: if item[0] == "Content-Length": @@ -113,7 +130,7 @@ def get_range_response(self, range_header, base_headers, file_handle): headers.append(("Content-Length", str(end - start + 1))) return Response(HTTPStatus.PARTIAL_CONTENT, headers, file_handle) - def get_byte_range(self, range_header, size): + def get_byte_range(self, range_header: str, size: int) -> tuple[int, int]: start, end = self.parse_byte_range(range_header) if start < 0: start = max(start + size, 0) @@ -124,7 +141,7 @@ def get_byte_range(self, range_header, size): return start, end @staticmethod - def parse_byte_range(range_header): + def parse_byte_range(range_header: str) -> tuple[int, int | None]: units, _, range_spec = range_header.strip().partition("=") if units != "bytes": raise ValueError() @@ -142,7 +159,9 @@ def parse_byte_range(range_header): return start, end @staticmethod - def get_range_not_satisfiable_response(file_handle, size): + def get_range_not_satisfiable_response( + file_handle: BufferedIOBase | None, size: int + ) -> Response: if file_handle is not None: file_handle.close() return Response( @@ -152,9 +171,13 @@ def get_range_not_satisfiable_response(file_handle, size): ) @staticmethod - def get_file_stats(path, encodings, stat_cache): + def get_file_stats( + path: str, + encodings: dict[str, str] | None, + stat_cache: dict[str, os.stat_result] | None, + ) -> dict[str | None, FileEntry]: # Primary file has an encoding of None - files = {None: FileEntry(path, stat_cache)} + files: dict[str | None, FileEntry] = {None: FileEntry(path, stat_cache)} if encodings: for encoding, alt_path in encodings.items(): try: @@ -163,7 +186,9 @@ def get_file_stats(path, encodings, stat_cache): continue return files - def get_headers(self, headers_list, files): + def get_headers( + self, headers_list: list[tuple[str, str]], files: dict[str | None, FileEntry] + ) -> Headers: headers = Headers(headers_list) main_file = files[None] if len(files) > 1: @@ -182,17 +207,21 @@ def get_headers(self, headers_list, files): return headers @staticmethod - def get_not_modified_response(headers): - not_modified_headers = [] + def get_not_modified_response(headers: Headers) -> Response: + not_modified_headers: list[tuple[str, str]] = [] for key in NOT_MODIFIED_HEADERS: if key in headers: - not_modified_headers.append((key, headers[key])) + value = headers[key] + assert value is not None + not_modified_headers.append((key, value)) return Response( status=HTTPStatus.NOT_MODIFIED, headers=not_modified_headers, file=None ) @staticmethod - def get_alternatives(base_headers, files): + def get_alternatives( + base_headers: Headers, files: dict[str | None, FileEntry] + ) -> list[tuple[Pattern[str], str, list[tuple[str, str]]]]: # Sort by size so that the smallest compressed alternative matches first alternatives = [] files_by_size = sorted(files.items(), key=lambda i: i[1].size) @@ -207,7 +236,7 @@ def get_alternatives(base_headers, files): alternatives.append((encoding_re, file_entry.path, headers.items())) return alternatives - def is_not_modified(self, request_headers): + def is_not_modified(self, request_headers: dict[str, str]) -> bool: previous_etag = request_headers.get("HTTP_IF_NONE_MATCH") if previous_etag is not None: return previous_etag == self.etag @@ -222,7 +251,9 @@ def is_not_modified(self, request_headers): return last_requested_ts >= self.last_modified return False - def get_path_and_headers(self, request_headers): + def get_path_and_headers( + self, request_headers: dict[str, str] + ) -> tuple[str, list[tuple[str, str]]]: accept_encoding = request_headers.get("HTTP_ACCEPT_ENCODING", "") if accept_encoding == "*": accept_encoding = "" @@ -230,15 +261,19 @@ def get_path_and_headers(self, request_headers): for encoding_re, path, headers in self.alternatives: if encoding_re.search(accept_encoding): return path, headers + raise AssertionError("Unreachable") class Redirect: - def __init__(self, location, headers=None): - headers = list(headers.items()) if headers else [] - headers.append(("Location", quote(location.encode("utf8")))) - self.response = Response(HTTPStatus.FOUND, headers, None) + def __init__(self, location: str, headers: dict[str, str] | None = None) -> None: + if headers is None: + header_list = [] + else: + header_list = list(headers.items()) + header_list.append(("Location", quote(location.encode("utf8")))) + self.response = Response(HTTPStatus.FOUND, header_list, None) - def get_response(self, method, request_headers): + def get_response(self, method: str, request_headers: dict[str, str]) -> Response: return self.response @@ -257,15 +292,23 @@ class IsDirectoryError(MissingFileError): class FileEntry: __slots__ = ("path", "size", "mtime") - def __init__(self, path, stat_cache=None): + def __init__( + self, path: str, stat_cache: dict[str, os.stat_result] | None = None + ) -> None: self.path = path - stat_function = os.stat if stat_cache is None else stat_cache.__getitem__ + stat_function: Callable[[str], os.stat_result] + if stat_cache is None: + stat_function = os.stat + else: + stat_function = stat_cache.__getitem__ stat = self.stat_regular_file(path, stat_function) self.size = stat.st_size self.mtime = stat.st_mtime @staticmethod - def stat_regular_file(path, stat_function): + def stat_regular_file( + path: str, stat_function: Callable[[str], os.stat_result] + ) -> os.stat_result: """ Wrap `stat_function` to raise appropriate errors if `path` is not a regular file diff --git a/src/whitenoise/runserver_nostatic/management/commands/runserver.py b/src/whitenoise/runserver_nostatic/management/commands/runserver.py index 484b1442..83f0f65f 100644 --- a/src/whitenoise/runserver_nostatic/management/commands/runserver.py +++ b/src/whitenoise/runserver_nostatic/management/commands/runserver.py @@ -8,24 +8,29 @@ """ from __future__ import annotations +import argparse from importlib import import_module +from typing import Generator from django.apps import apps +from django.core.management import BaseCommand -def get_next_runserver_command(): +def get_next_runserver_command() -> type[BaseCommand]: """ Return the next highest priority "runserver" command class """ for app_name in get_lower_priority_apps(): module_path = "%s.management.commands.runserver" % app_name try: - return import_module(module_path).Command + klass: type[BaseCommand] = import_module(module_path).Command + return klass except (ImportError, AttributeError): pass + raise ValueError("No lower priority app has a 'runserver' command") -def get_lower_priority_apps(): +def get_lower_priority_apps() -> Generator[str, None, None]: """ Yield all app module names below the current app in the INSTALLED_APPS list """ @@ -42,11 +47,12 @@ def get_lower_priority_apps(): RunserverCommand = get_next_runserver_command() -class Command(RunserverCommand): - def add_arguments(self, parser): +class Command(RunserverCommand): # type: ignore [misc,valid-type] + def add_arguments(self, parser: argparse.ArgumentParser) -> None: super().add_arguments(parser) if parser.get_default("use_static_handler") is True: parser.set_defaults(use_static_handler=False) + assert parser.description is not None parser.description += ( "\n(Wrapped by 'whitenoise.runserver_nostatic' to always" " enable '--nostatic')" diff --git a/src/whitenoise/storage.py b/src/whitenoise/storage.py index ae3d7de7..4e5076a7 100644 --- a/src/whitenoise/storage.py +++ b/src/whitenoise/storage.py @@ -35,7 +35,7 @@ def fallback_post_process(self, paths, dry_run=False, **options): for path in paths: yield path, None, False - def create_compressor(self, **kwargs): + def create_compressor(self, **kwargs) -> Compressor: return Compressor(**kwargs) def post_process_with_compression(self, files): @@ -70,13 +70,12 @@ class CompressedManifestStaticFilesStorage(ManifestStaticFilesStorage): those without the hash in their name) """ - _new_files = None - def __init__(self, *args, **kwargs): manifest_strict = getattr(settings, "WHITENOISE_MANIFEST_STRICT", None) if manifest_strict is not None: self.manifest_strict = manifest_strict super().__init__(*args, **kwargs) + self._new_files: set[str] | None = None def post_process(self, *args, **kwargs): files = super().post_process(*args, **kwargs) @@ -97,7 +96,7 @@ def post_process_with_compression(self, files): # file is yielded we have to hook in to the `hashed_name` method to # keep track of them all. hashed_names = {} - new_files = set() + new_files: set[str] = set() self.start_tracking_new_files(new_files) for name, hashed_name, processed in files: if hashed_name and not isinstance(processed, Exception): @@ -122,17 +121,17 @@ def hashed_name(self, *args, **kwargs): self._new_files.add(self.clean_name(name)) return name - def start_tracking_new_files(self, new_files): + def start_tracking_new_files(self, new_files: set[str]) -> None: self._new_files = new_files - def stop_tracking_new_files(self): + def stop_tracking_new_files(self) -> None: self._new_files = None @property - def keep_only_hashed_files(self): + def keep_only_hashed_files(self) -> bool: return getattr(settings, "WHITENOISE_KEEP_ONLY_HASHED_FILES", False) - def delete_files(self, files_to_delete): + def delete_files(self, files_to_delete) -> None: for name in files_to_delete: try: os.unlink(self.path(name)) @@ -140,7 +139,7 @@ def delete_files(self, files_to_delete): if e.errno != errno.ENOENT: raise - def create_compressor(self, **kwargs): + def create_compressor(self, **kwargs) -> Compressor: return Compressor(**kwargs) def compress_files(self, names): @@ -154,7 +153,7 @@ def compress_files(self, names): compressed_name = compressed_path[prefix_len:] yield name, compressed_name - def make_helpful_exception(self, exception, name): + def make_helpful_exception(self, exception: Exception, name: str) -> Exception: """ If a CSS file contains references to images, fonts etc that can't be found then Django's `post_process` blows up with a not particularly helpful diff --git a/src/whitenoise/string_utils.py b/src/whitenoise/string_utils.py index b56d2624..62e14a23 100644 --- a/src/whitenoise/string_utils.py +++ b/src/whitenoise/string_utils.py @@ -1,21 +1,26 @@ from __future__ import annotations +from typing import Any -def decode_if_byte_string(s, force_text=False): - if isinstance(s, bytes): - s = s.decode() - if force_text and not isinstance(s, str): - s = str(s) - return s + +def decode_if_byte_string(value: Any, force_text: bool = False) -> str: + result: str + if isinstance(value, bytes): + result = value.decode() + elif force_text and not isinstance(value, str): + result = str(value) + else: + result = value + return result # Follow Django in treating URLs as UTF-8 encoded (which requires undoing the # implicit ISO-8859-1 decoding applied in Python 3). Strictly speaking, URLs # should only be ASCII anyway, but UTF-8 can be found in the wild. -def decode_path_info(path_info): +def decode_path_info(path_info: str) -> str: return path_info.encode("iso-8859-1", "replace").decode("utf-8", "replace") -def ensure_leading_trailing_slash(path): +def ensure_leading_trailing_slash(path: str | None) -> str: path = (path or "").strip("/") return f"/{path}/" if path else "/" diff --git a/tests/django_urls.py b/tests/django_urls.py index 562b8d6f..a39253fb 100644 --- a/tests/django_urls.py +++ b/tests/django_urls.py @@ -1,3 +1,5 @@ from __future__ import annotations -urlpatterns = [] +from django.urls import URLPattern + +urlpatterns: list[URLPattern] = [] diff --git a/tests/test_compress.py b/tests/test_compress.py index 11ea9471..b483c1df 100644 --- a/tests/test_compress.py +++ b/tests/test_compress.py @@ -70,8 +70,12 @@ def test_with_falsey_extensions(): def test_custom_log(): - compressor = Compressor(log="test") - assert compressor.log == "test" + def logit(message: str) -> None: + pass + + compressor = Compressor(log=logit) + + assert compressor.log is logit def test_compress(): @@ -82,5 +86,5 @@ def test_compress(): def test_compressed_effectively_no_orig_size(): compressor = Compressor(quiet=True) assert not compressor.is_compressed_effectively( - "test_encoding", "test_path", 0, "test_data" + "test_encoding", "test_path", 0, b"test_data" ) diff --git a/tests/test_django_whitenoise.py b/tests/test_django_whitenoise.py index 2654424f..8c5ab108 100644 --- a/tests/test_django_whitenoise.py +++ b/tests/test_django_whitenoise.py @@ -3,6 +3,7 @@ import shutil import tempfile from contextlib import closing +from typing import Any from urllib.parse import urljoin from urllib.parse import urlparse @@ -21,11 +22,11 @@ from whitenoise.middleware import WhiteNoiseMiddleware -def reset_lazy_object(obj): +def reset_lazy_object(obj: Any) -> None: obj._wrapped = empty -def get_url_path(base, url): +def get_url_path(base: str, url: str) -> str: return urlparse(urljoin(base, url)).path @@ -130,7 +131,8 @@ def finder_static_files(request): WHITENOISE_INDEX_FILE=True, STATIC_ROOT=None, ): - finders.get_finder.cache_clear() + # django-stubs doesn’t mark get_finder() as @lru_cache + finders.get_finder.cache_clear() # type: ignore [attr-defined] yield files diff --git a/tests/test_runserver_nostatic.py b/tests/test_runserver_nostatic.py index 65ed458b..da12e858 100644 --- a/tests/test_runserver_nostatic.py +++ b/tests/test_runserver_nostatic.py @@ -1,12 +1,15 @@ from __future__ import annotations +from django.core.management import BaseCommand from django.core.management import get_commands from django.core.management import load_command_class -def get_command_instance(name): +def get_command_instance(name: str) -> BaseCommand: app_name = get_commands()[name] - return load_command_class(app_name, name) + # django-stubs incorrect type for get_commands() fixed in: + # https://github.com/typeddjango/django-stubs/pull/1074 + return load_command_class(app_name, name) # type: ignore [arg-type] def test_command_output(): diff --git a/tests/test_storage.py b/tests/test_storage.py index 5ca7c224..eb6caf91 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -61,7 +61,7 @@ class TriggerException(HashedFilesMixin): def exists(self, path): return False - exception = None + exception: ValueError try: TriggerException().hashed_name("/missing/file.png") except ValueError as e: diff --git a/tests/test_whitenoise.py b/tests/test_whitenoise.py index 3e98f030..be0cef2a 100644 --- a/tests/test_whitenoise.py +++ b/tests/test_whitenoise.py @@ -13,6 +13,7 @@ from wsgiref.simple_server import demo_app import pytest +import requests from .utils import AppServer from .utils import Files @@ -71,7 +72,7 @@ def server(application): yield app_server -def assert_is_default_response(response): +def assert_is_default_response(response: requests.Response) -> None: assert "Hello world!" in response.text @@ -310,7 +311,7 @@ def test_no_error_on_very_long_filename(server): assert response.status_code != 500 -def copytree(src, dst): +def copytree(src: str, dst: str) -> None: for name in os.listdir(src): src_path = os.path.join(src, name) dst_path = os.path.join(dst, name) diff --git a/tests/utils.py b/tests/utils.py index 0db45d75..8974f891 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,17 +2,21 @@ import os import threading +from typing import Any +from typing import Iterable from wsgiref.simple_server import make_server from wsgiref.simple_server import WSGIRequestHandler from wsgiref.util import shift_path_info import requests +from whitenoise.compat import StartResponse, WSGIApplication, WSGIEnvironment + TEST_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_files") class SilentWSGIHandler(WSGIRequestHandler): - def log_message(*args): + def log_message(self, format: str, *args: Any) -> None: pass @@ -24,13 +28,15 @@ class AppServer: PREFIX = "subdir" - def __init__(self, application): + def __init__(self, application: WSGIApplication) -> None: self.application = application self.server = make_server( "127.0.0.1", 0, self.serve_under_prefix, handler_class=SilentWSGIHandler ) - def serve_under_prefix(self, environ, start_response): + def serve_under_prefix( + self, environ: WSGIEnvironment, start_response: StartResponse + ) -> Iterable[bytes]: prefix = shift_path_info(environ) if prefix != self.PREFIX: start_response("404 Not Found", []) @@ -38,10 +44,12 @@ def serve_under_prefix(self, environ, start_response): else: return self.application(environ, start_response) - def get(self, *args, **kwargs): + def get(self, *args: Any, **kwargs: Any) -> requests.Response: return self.request("get", *args, **kwargs) - def request(self, method, path, *args, **kwargs): + def request( + self, method: str, path: str, *args: Any, **kwargs: Any + ) -> requests.Response: url = "http://{0[0]}:{0[1]}{1}".format(self.server.server_address, path) thread = threading.Thread(target=self.server.handle_request) thread.start() @@ -49,12 +57,12 @@ def request(self, method, path, *args, **kwargs): thread.join() return response - def close(self): + def close(self) -> None: self.server.server_close() class Files: - def __init__(self, directory, **files): + def __init__(self, directory: str, **files: str) -> None: self.directory = os.path.join(TEST_FILE_PATH, directory) for name, path in files.items(): url = f"/{AppServer.PREFIX}/{path}"