Skip to content

Commit

Permalink
many more types including wsgiref.types
Browse files Browse the repository at this point in the history
  • Loading branch information
adamchainz committed Nov 3, 2022
1 parent 706ab1c commit f6dafc9
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 18 deletions.
32 changes: 21 additions & 11 deletions src/whitenoise/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import re
import warnings
from posixpath import normpath
from typing import Callable, Generator
from typing import Callable, Generator, Iterable
from wsgiref.headers import Headers
from wsgiref.util import FileWrapper

from .compat import StartResponse, WSGIApplication, WSGIEnvironment
from .media_types import MediaTypes
from .responders import IsDirectoryError, MissingFileError, Redirect, StaticFile
from .string_utils import (
Expand All @@ -25,7 +26,7 @@ class WhiteNoise:

def __init__(
self,
application,
application: WSGIApplication,
root: str | None = None,
prefix: str | None = None,
*,
Expand Down Expand Up @@ -65,12 +66,14 @@ def __init__(

self.media_types = MediaTypes(extra_types=mimetypes)
self.application = application
self.files = {}
self.directories = []
self.files: dict[str, Redirect | StaticFile] = {}
self.directories: list[tuple[str, str]] = []
if root is not None:
self.add_files(root, prefix)

def __call__(self, environ, start_response):
def __call__(
self, environ: WSGIEnvironment, start_response: StartResponse
) -> Iterable[bytes]:
path = decode_path_info(environ.get("PATH_INFO", ""))
if self.autorefresh:
static_file = self.find_file(path)
Expand All @@ -82,12 +85,18 @@ def __call__(self, environ, start_response):
return self.serve(static_file, environ, start_response)

@staticmethod
def serve(static_file, environ, start_response):
def serve(
static_file: Redirect | StaticFile,
environ: WSGIEnvironment,
start_response: StartResponse,
) -> Iterable[bytes]:
response = static_file.get_response(environ["REQUEST_METHOD"], environ)
status_line = f"{response.status} {response.status.phrase}"
start_response(status_line, list(response.headers))
if response.file is not None:
file_wrapper = environ.get("wsgi.file_wrapper", FileWrapper)
file_wrapper: type[FileWrapper] = environ.get(
"wsgi.file_wrapper", FileWrapper
)
return file_wrapper(response.file)
else:
return []
Expand Down Expand Up @@ -139,14 +148,15 @@ def add_file_to_dictionary(
def find_file(self, url: str) -> Redirect | StaticFile | None:
# Optimization: bail early if the URL can never match a file
if not self.index_file and url.endswith("/"):
return
return None
if not self.url_is_canonical(url):
return
return None
for path in self.candidate_paths_for_url(url):
try:
return self.find_file_at_path(path, url)
except MissingFileError:
pass
return None

def candidate_paths_for_url(self, url: str) -> Generator[str, None, None]:
for root, prefix in self.directories:
Expand Down Expand Up @@ -181,7 +191,7 @@ def find_file_at_path_with_indexes(
raise MissingFileError(path)

@staticmethod
def url_is_canonical(url):
def url_is_canonical(url: str) -> bool:
"""
Check that the URL path is in canonical format i.e. has normalised
slashes and no path traversal elements
Expand All @@ -195,7 +205,7 @@ def url_is_canonical(url):

@staticmethod
def is_compressed_variant(
path, stat_cache: dict[str, os.stat_result] | None = None
path: str, stat_cache: dict[str, os.stat_result] | None = None
) -> bool:
if path[-3:] in (".gz", ".br"):
uncompressed_path = path[:-3]
Expand Down
73 changes: 73 additions & 0 deletions src/whitenoise/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from __future__ import annotations

import sys

if sys.version_info >= (3, 11):
from wsgiref.types import WSGIApplication
else:
from collections.abc import Callable, Iterable, Iterator
from types import TracebackType
from typing import Any, Dict, Protocol, Tuple, Type, Union

from typing_extensions import TypeAlias

__all__ = [
"StartResponse",
"WSGIEnvironment",
"WSGIApplication",
"InputStream",
"ErrorStream",
"FileWrapper",
]

_ExcInfo: TypeAlias = Tuple[Type[BaseException], BaseException, TracebackType]
_OptExcInfo: TypeAlias = Union[_ExcInfo, Tuple[None, None, None]]

class StartResponse(Protocol):
def __call__(
self,
__status: str,
__headers: list[tuple[str, str]],
__exc_info: _OptExcInfo | None = ...,
) -> Callable[[bytes], object]:
...

WSGIEnvironment: TypeAlias = Dict[str, Any]
WSGIApplication: TypeAlias = Callable[
[WSGIEnvironment, StartResponse], Iterable[bytes]
]

class InputStream(Protocol):
def read(self, __size: int = ...) -> bytes:
...

def readline(self, __size: int = ...) -> bytes:
...

def readlines(self, __hint: int = ...) -> list[bytes]:
...

def __iter__(self) -> Iterator[bytes]:
...

class ErrorStream(Protocol):
def flush(self) -> object:
...

def write(self, __s: str) -> object:
...

def writelines(self, __seq: list[str]) -> object:
...

class _Readable(Protocol):
def read(self, __size: int = ...) -> bytes:
...

# Optional: def close(self) -> object: ...

class FileWrapper(Protocol):
def __call__(
self, __file: _Readable, __block_size: int = ...
) -> Iterable[bytes]:
...
5 changes: 4 additions & 1 deletion src/whitenoise/responders.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ class Response:
__slots__ = ("status", "headers", "file")

def __init__(
self, status: int, headers: Sequence[tuple[str, str]], file: BinaryIO | None
self,
status: HTTPStatus,
headers: Sequence[tuple[str, str]],
file: BinaryIO | None,
) -> None:
self.status = status
self.headers = headers
Expand Down
5 changes: 3 additions & 2 deletions tests/test_django_whitenoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import shutil
import tempfile
from contextlib import closing
from typing import Any
from urllib.parse import urljoin, urlparse

import pytest
Expand All @@ -18,11 +19,11 @@
from .utils import AppServer, Files


def reset_lazy_object(obj):
def reset_lazy_object(obj: Any) -> None:
obj._wrapped = empty


def get_url_path(base, url):
def get_url_path(base: str, url: str) -> str:
return urlparse(urljoin(base, url)).path


Expand Down
12 changes: 8 additions & 4 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@
import os
import threading
import warnings
from typing import Any
from typing import Any, Iterable
from wsgiref.simple_server import WSGIRequestHandler, make_server
from wsgiref.util import shift_path_info

import requests

from whitenoise.compat import StartResponse, WSGIApplication, WSGIEnvironment

warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="requests")


TEST_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_files")


class SilentWSGIHandler(WSGIRequestHandler):
def log_message(*args):
def log_message(self, format: str, *args: Any) -> None:
pass


Expand All @@ -28,13 +30,15 @@ class AppServer:

PREFIX = "subdir"

def __init__(self, application):
def __init__(self, application: WSGIApplication) -> None:
self.application = application
self.server = make_server(
"127.0.0.1", 0, self.serve_under_prefix, handler_class=SilentWSGIHandler
)

def serve_under_prefix(self, environ, start_response):
def serve_under_prefix(
self, environ: WSGIEnvironment, start_response: StartResponse
) -> Iterable[bytes]:
prefix = shift_path_info(environ)
if prefix != self.PREFIX:
start_response("404 Not Found", [])
Expand Down

0 comments on commit f6dafc9

Please sign in to comment.