diff --git a/.dockerignore b/.dockerignore deleted file mode 100644 index 2d2ecd6..0000000 --- a/.dockerignore +++ /dev/null @@ -1 +0,0 @@ -.git/ diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml new file mode 100644 index 0000000..64af757 --- /dev/null +++ b/.github/workflows/test-suite.yml @@ -0,0 +1,41 @@ +name: Test NServer + +on: + push: + branches: + - main + + pull_request: + branches: + - main + +jobs: + lint: + name: "Python Lint" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: astral-sh/setup-uv@v3 + + - name: Lint with tox + run: uvx --with tox-uv tox -e lint + + test: + name: "Python Test ${{ matrix.os }}" + needs: [lint] + runs-on: "${{ matrix.os }}" + strategy: + fail-fast: false # allow tests to run on all platforms + matrix: + os: + - ubuntu-latest + - windows-latest + - macos-latest + + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/setup-uv@v3 + + - name: Test with tox + run: uvx --with tox-uv tox diff --git a/.gitignore b/.gitignore index 0f424ef..a8247d8 100644 --- a/.gitignore +++ b/.gitignore @@ -194,3 +194,4 @@ dmypy.json ### PROJECT ### ============================================================================ # Project specific stuff goes here +uv.lock diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 120000 index 0000000..d0fcfe9 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1 @@ +docs/contributing.md \ No newline at end of file diff --git a/SECURITY.md b/SECURITY.md new file mode 120000 index 0000000..6ac9ff1 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1 @@ +docs/security.md \ No newline at end of file diff --git a/dev.sh b/dev.sh index 0eb753f..d95f317 100755 --- a/dev.sh +++ b/dev.sh @@ -15,13 +15,15 @@ set -e # Bail at the first sign of trouble # Notation Reference: https://unix.stackexchange.com/questions/122845/using-a-b-for-variable-assignment-in-scripts#comment685330_122848 : ${DEBUG:=0} : ${CI:=0} # Flag for if we are in CI - default to not. -: ${SKIP_BUILD:=0} # Allow some commands to forcibly skip compose-build -: ${PORT:=8000} # allows for some commands to change the port if ! command -v toml &> /dev/null; then pip install --user toml-cli fi +if ! command -v uv &> /dev/null; then + pip install --user uv +fi + ### CONTANTS ### ============================================================================ SOURCE_UID=$(id -u) @@ -45,8 +47,7 @@ PACKAGE_VERSION=$(toml get --toml-path pyproject.toml project.version) # You may want to customise these for your project # TODO: this potentially should be moved to manifest.env so that projects can easily # customise the main dev.sh -SOURCE_FILES="src tests" -PYTHON_MIN_VERSION="py37" +PYTHON_MIN_VERSION="py38" ## Build related ## ----------------------------------------------------------------------------- @@ -117,77 +118,6 @@ cp .tmp/env .env ### FUNCTIONS ### ============================================================================ -## Docker Functions -## ----------------------------------------------------------------------------- -function compose_build { - heading2 "🐋 Building $1" - if [[ "$CI" = 1 ]]; then - docker compose build --progress plain $1 - - elif [[ "$DEBUG" -gt 0 ]]; then - docker compose build --progress plain $1 - - else - docker compose build $1 - fi - echo -} - -function compose_run { - heading2 "🐋 running $@" - docker compose -f docker-compose.yml run --rm "$@" - echo -} - -function docker_clean { - heading2 "🐋 Removing $PACKAGE_NAME images" - IMAGES=$(docker images --filter "reference=${PACKAGE_NAME}-asdf*" | tail -n +2) - COUNT_IMAGES=$(echo -n "$IMAGES" | wc -l) - if [[ "$DEBUG" -gt 0 ]]; then - echo "IMAGES=$IMAGES" - echo "COUNT_IMAGES=$COUNT_IMAGES" - fi - - if [[ "$COUNT_IMAGES" -gt 0 ]]; then - docker images | grep "$PACKAGE_NAME" | awk '{OFS=":"} {print $1, $2}' | xargs -t docker rmi - fi -} - - -function docker_clean_unused { - docker images --filter "reference=${PACKAGE_NAME}-*" -a | \ - tail -n +2 | \ - grep -v "$GIT_COMMIT" | \ - awk '{OFS=":"} {print $1, $2}' | \ - xargs -t docker rmi -} - -function docker_autoclean { - if [[ "$CI" = 0 ]]; then - if [[ "$DEBUG" -gt 0 ]]; then - heading2 "🐋 determining if need to clean" - fi - - IMAGES=$( - docker images --filter "reference=${PACKAGE_NAME}-*" -a |\ - tail -n +2 |\ - grep -v "$GIT_COMMIT" ;\ - /bin/true - ) - COUNT_IMAGES=$(echo "$IMAGES" | wc -l) - - if [[ "$DEBUG" -gt 0 ]]; then - echo "IMAGES=${IMAGES}" - echo "COUNT_IMAGES=${COUNT_IMAGES}" - fi - - if [[ $COUNT_IMAGES -gt $AUTOCLEAN_LIMIT ]]; then - heading2 "Removing unused ${PACKAGE_NAME} images 🐋" - docker_clean_unused - fi - fi -} - ## Utility ## ----------------------------------------------------------------------------- function heading { @@ -228,32 +158,6 @@ function check_pyproject_toml { ## Command Functions ## ----------------------------------------------------------------------------- -function command_build { - if [[ -z "$1" || "$1" == "dist" ]]; then - BUILD_DIR="dist" - elif [[ "$1" == "tmp" ]]; then - BUILD_DIR=".tmp/dist" - else - return 1 - fi - - # TODO: unstashed changed guard - - if [[ ! -d "$BUILD_DIR" ]]; then - heading "setup 📜" - mkdir $BUILD_DIR - fi - - echo "BUILD_DIR=${BUILD_DIR}" >> .env - echo "BUILD_DIR=${BUILD_DIR}" >> .tmp/env - - heading "build 🐍" - # Note: we always run compose_build because we copy the package source code to - # the container so we can modify it without affecting local source code. - compose_build python-build - compose_run python-build -} - function display_usage { echo "dev.sh - development utility" @@ -306,70 +210,30 @@ case $1 in echo "ERROR! Do not run format in CI!" exit 250 fi - heading "black 🐍" - if [[ "$SKIP_BUILD" = 0 ]]; then - compose_build python-common - fi - - compose_run python-common \ - black --line-length 100 --target-version ${PYTHON_MIN_VERSION} $SOURCE_FILES + heading "tox 🐍 - format" + uvx tox -e format || true ;; "lint") - if [[ "$SKIP_BUILD" = 0 ]]; then - compose_build python-common - fi - - if [[ "$DEBUG" -gt 0 ]]; then - heading2 "🤔 Debugging" - compose_run python-common ls -lah - compose_run python-common pip list - fi - - heading "validate-pyproject 🐍" - compose_run python-common validate-pyproject pyproject.toml - - heading "black - check only 🐍" - compose_run python-common \ - black --line-length 100 --target-version ${PYTHON_MIN_VERSION} --check --diff $SOURCE_FILES - - heading "pylint 🐍" - compose_run python-common pylint -j 4 --output-format=colorized $SOURCE_FILES - - heading "mypy 🐍" - compose_run python-common mypy $SOURCE_FILES - + heading "tox 🐍 - lint" + uvx tox -e lint || true ;; "test") - command_build tmp - - heading "tox 🐍" - if [[ "$SKIP_BUILD" = 0 ]]; then - compose_build python-tox - fi - compose_run python-tox tox -e ${PYTHON_MIN_VERSION} || true - - rm -rf .tmp/dist/* + heading "tox 🐍 - single" + uvx tox -e py312 || true ;; "test-full") - command_build tmp - - heading "tox 🐍" - if [[ "$SKIP_BUILD" = 0 ]]; then - compose_build python-tox - fi - compose_run python-tox tox || true - - rm -rf .tmp/dist/* + heading "tox 🐍 - all" + uvx tox || true ;; "build") - command_build dist + source ./lib/python/build.sh ;; @@ -408,44 +272,20 @@ print('Your package is already imported 🎉\nPress ctrl+d to exit') EOF fi - if [[ "$SKIP_BUILD" = 0 ]]; then - compose_build python-common - fi - compose_run python-common bpython --config bpython.ini -i .tmp/repl.py - - ;; - - "run") - heading "Running 🐍" - if [[ "$SKIP_BUILD" = 0 ]]; then - compose_build python-common - fi - compose_run python-common "${@:2}" + uv run python -i .tmp/repl.py ;; "docs") heading "Preview Docs 🐍" - if [[ "$SKIP_BUILD" = 0 ]]; then - compose_build python-common - fi - compose_run -p 127.0.0.1:${PORT}:8080 python-common mkdocs serve -a 0.0.0.0:8080 -w docs + uv run --extra dev mkdocs serve -w docs ;; "build-docs") heading "Building Docs 🐍" - if [[ -z "$VIRTUAL_ENV" ]]; then - echo "This command should be run in a virtual environment to avoid poluting" - exit 1 - fi - - if [[ -z $(pip3 list | grep mike) ]]; then - pip install -e.[docs] - fi - - mike deploy "$PACKAGE_VERSION" "latest" \ + uv run --extra dev mike deploy "$PACKAGE_VERSION" "latest" \ --update-aliases \ --prop-set-string "git_branch=${GIT_BRANCH}" \ --prop-set-string "git_commit=${GIT_COMMIT}" \ @@ -462,7 +302,6 @@ EOF "clean") heading "Cleaning 📜" - docker_clean echo "🐍 pyclean" if ! command -v pyclean &> /dev/null; then @@ -471,6 +310,9 @@ EOF pyclean src pyclean tests + echo "🐍 clear .tox" + rm -rf .tox + echo "🐍 remove build artifacts" rm -rf build dist "src/${PACKAGE_PYTHON_NAME}.egg-info" @@ -523,5 +365,3 @@ EOF ;; esac - -docker_autoclean diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index 86a3c9e..0000000 --- a/docker-compose.yml +++ /dev/null @@ -1,42 +0,0 @@ -version: "3.1" -services: - python-common: &pythonBase - image: "${PACKAGE_NAME}-python-general:${GIT_COMMIT}" - build: - context: . - dockerfile: lib/python/common.Dockerfile - args: &pythonBaseBuildArgs - - "SOURCE_UID=${SOURCE_UID}" - - "SOURCE_GID=${SOURCE_GID}" - - "SOURCE_UID_GID=${SOURCE_UID_GID}" - user: devuser - working_dir: /code - env_file: - - .tmp/env - environment: - - "PATH=/home/devuser/.local/bin:/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games" - volumes: - - .:/code - - python-build: - <<: *pythonBase - image: "${PACKAGE_NAME}-python-build:${GIT_COMMIT}" - build: - context: . - dockerfile: lib/python/build.Dockerfile - args: *pythonBaseBuildArgs - command: "/code/lib/python/build.sh" - volumes: - - ./${BUILD_DIR}:/code/dist - - python-tox: - <<: *pythonBase - image: "${PACKAGE_NAME}-python-tox:${GIT_COMMIT}" - build: - context: . - dockerfile: lib/python/tox.Dockerfile - args: *pythonBaseBuildArgs - volumes: - - ./${BUILD_DIR}:/code/dist - - ./tests:/code/tests - - ./tox.ini:/code/tox.ini diff --git a/docs/architecture.md b/docs/architecture.md new file mode 100644 index 0000000..1299381 --- /dev/null +++ b/docs/architecture.md @@ -0,0 +1,9 @@ +# Architecture + +## Request Flow + +![NServer Request Flow](assets/images/nserver-architecture-request-flow.drawio.png) + +## Server Middleware + +![NServer Server Middleware Flow](assets/images/nserver-architecture-server-flow.drawio.png) diff --git a/docs/assets/images/nserver-architecture-request-flow.drawio.png b/docs/assets/images/nserver-architecture-request-flow.drawio.png new file mode 100644 index 0000000..358a06c Binary files /dev/null and b/docs/assets/images/nserver-architecture-request-flow.drawio.png differ diff --git a/docs/assets/images/nserver-architecture-server-flow.drawio.png b/docs/assets/images/nserver-architecture-server-flow.drawio.png new file mode 100644 index 0000000..bcc6b14 Binary files /dev/null and b/docs/assets/images/nserver-architecture-server-flow.drawio.png differ diff --git a/docs/blueprints.md b/docs/blueprints.md deleted file mode 100644 index 90462ec..0000000 --- a/docs/blueprints.md +++ /dev/null @@ -1,54 +0,0 @@ -# Blueprints - -[`Blueprint`][nserver.server.Blueprint]s provide a way for you to compose your application. They support most of the same functionality as a `NameServer`. - -Use cases: - -- Split up your application across different blueprints for maintainability / composability. -- Reuse a blueprint registered under different rules. -- Allow custom packages to define their own rules that you can add to your own server. - -Blueprints require `nserver>=2.0` - -## Using Blueprints - -```python -from nserver import Blueprint, NameServer, ZoneRule, ALL_CTYPES, A - -# First Blueprint -mysite = Blueprint("mysite") - -@mysite.rule("nicholashairs.com", ["A"]) -@mysite.rule("www.nicholashairs.com", ["A"]) -def nicholashairs_website(query: Query) -> A: - return A(query.name, "159.65.13.73") - -@mysite.rule(ZoneRule, "", ALL_CTYPES) -def nicholashairs_catchall(query: Query) -> None: - # Return empty response for all other queries - return None - -# Second Blueprint -en_blueprint = Blueprint("english-speaking-blueprint") - -@en_blueprint.rule("hello.{base_domain}", ["A"]) -def en_hello(query: Query) -> A: - return A(query.name, "1.1.1.1") - -# Register to NameServer -server = NameServer("server") -server.register_blueprint(mysite, ZoneRule, "nicholashairs.com", ALL_CTYPES) -server.register_blueprint(en_blueprint, ZoneRule, "au", ALL_CTYPES) -server.register_blueprint(en_blueprint, ZoneRule, "nz", ALL_CTYPES) -server.register_blueprint(en_blueprint, ZoneRule, "uk", ALL_CTYPES) -``` - -### Middleware, Hooks, and Error Handling - -Blueprints maintain their own `QueryMiddleware` stack which will run before any rule function is run. Included in this stack is the `HookMiddleware` and `ExceptionHandlerMiddleware`. - -## Key differences with `NameServer` - -- Does not use settings (`Setting`). -- Does not have a `Transport`. -- Does not have a `RawRecordMiddleware` stack. diff --git a/docs/changelog.md b/docs/changelog.md index 9f9f034..a374b55 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,6 +1,47 @@ # Change Log +All notable changes to this project will be documented in this file. -## 2.0.0 +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + + +## [3.0.0](https://github.com/nhairs/nserver/compare/v2.0.0...dev) - UNRELEASED + +!!! tip + Version `3.0.0` represents a large incompatible refactor of `nserver` with version `2.0.0` considered a ["misfire"](https://github.com/nhairs/nserver/pull/4#issuecomment-2254354192). If you have been using functionality from `2.0.0` or the development branch you should expect a large number of breaking changes. + +### Added +- Add Python 3.13 support +- Generalised CLI interface for running applications; see `nserver --help`. + - Implemented in `nserver.cli`. +- `nserver.application` classes that focus on running a given server instance. + - This lays the ground work for different ways of running servers in the future; e.g. using threads. +- `nserver.server.RawNameServer` that handles `RawMiddleware` including exception handling. + +### Removed +- Drop Python 3.7 support +- `nserver.server.SubServer` has been removed. + - `NameServer` instances can now be registered to other `NameServer` instances. + +### Changed +- Refactored `nserver.server.NameServer` + - "Raw" functionality has been removed. This has been moved to the `nserver.server.RawNameServer`. + - "Transport" and other related "Application" functionality has been removed from `NameServer` instances. This has moved to the `nserver.application` classes. + - `NameServer` instances can now be registered to other instances. This replaces `SubServer` functionality that was in development. +- Refactoring of `nserver.server` and `nserver.middleware` classes. +- `NameServer` `name` argument / attribute is no longer used when creating the logger. + +### Fixed +- Uncaught errors from dropped connections in `nserver.transport.TCPv4Transport` [#6](https://github.com/nhairs/nserver/issues/6) + +### Development Changes +- Development tooling has moved to `uv`. + - The tooling remains wrapped in `dev.sh`. + - This remove the requirement for `docker` in local development. +- Test suite added to GitHub Actions. +- Added contributing guidelies. + +## [2.0.0](https://github.com/nhairs/nserver/compare/v1.0.0...v2.0.0) - 2023-12-20 - Implement [Middleware][middleware] - This includes adding error handling middleware that facilitates [error handling][error-handling]. @@ -10,6 +51,6 @@ - Add [Blueprints][blueprints] - Include refactoring `NameServer` into a new shared based `Scaffold` class. -## 1.0.0 +## [1.0.0](https://github.com/nhairs/nserver/commit/628db055848c6543641d514b4186f8d953b6af7d) - 2023-11-03 - Beta release diff --git a/docs/contributing.md b/docs/contributing.md new file mode 100644 index 0000000..22c8de4 --- /dev/null +++ b/docs/contributing.md @@ -0,0 +1,107 @@ +# Contributing + +Contributions are welcome! + +## Code of Conduct + +In general we follow the [Python Software Foundation Code of Conduct](https://policies.python.org/python.org/code-of-conduct/). Please note that we are not affiliated with the PSF. + +## Pull Request Process + +**0. Before you begin** + +If you're not familiar with contributing to open source software, [start by reading this guide](https://opensource.guide/how-to-contribute/). + +Be aware that anything you contribute will be licenced under [the project's licence](https://github.com/nhairs/nserver/blob/main/LICENSE). If you are making a change as a part of your job, be aware that your employer might own your work and you'll need their permission in order to licence the code. + +### 1. Find something to work on + +Where possible it's best to stick to established issues where discussion has already taken place. Contributions that haven't come from a discussed issue are less likely to be accepted. + +The following are things that can be worked on without an existing issue: + +- Updating documentation. This includes fixing in-code documentation / comments, and the overall docs. +- Small changes that don't change functionality such as refactoring or adding / updating tests. + +### 2. Fork the repository and make your changes + +We don't have styling documentation, so where possible try to match existing code. This includes the use of "headings" and "dividers" (this will make sense when you look at the code). + +Common devleopment tooling has been wrapped in `dev.sh` (which uses `uv` under the hood). + +Before creating your pull request you'll want to format your code and run the linters and tests: + +```shell +# Format +./dev.sh format + +# Lint +./dev.sh lint + +# Tests +./dev.sh test +``` + +If making changes to the documentation you can preview the changes locally using `./dev.sh docs`. Changes to the README can be previewed using [`grip`](https://github.com/joeyespo/grip) (not included in `dev` dependencies). + +!!! note + In general we will always squash merge pull requests so you do not need to worry about a "clean" commit history. + +### 3. Checklist + +Before pushing and creating your pull request, you should make sure you've done the following: + +- Updated any relevant tests. +- Formatted your code and run the linters and tests. +- Updated the version number in `pyproject.toml`. In general using a `.devN` suffix is acceptable. + This is not required for changes that do no affect the code such as documentation. +- Add details of the changes to the change log (`docs/changelog.md`), creating a new section if needed. +- Add notes for new / changed features in the relevant docstring. + +**4. Create your pull request** + +When creating your pull request be aware that the title and description will be used for the final commit so pay attention to them. + +Your pull request description should include the following: + +- Why the pull request is being made +- Summary of changes +- How the pull request was tested - especially if not covered by unit testing. + +Once you've submitted your pull request make sure that all CI jobs are passing. Pull requests with failing jobs will not be reviewed. + +### 5. Code review + +Your code will be reviewed by a maintainer. + +If you're not familiar with code review start by reading [this guide](https://google.github.io/eng-practices/review/). + +!!! tip "Remember you are not your work" + + You might be asked to explain or justify your choices. This is not a criticism of your value as a person! + + Often this is because there are multiple ways to solve the same problem and the reviewer would like to understand more about the way you solved. + +## Common Topics + +### Versioning and breaking compatability + +This project uses semantic versioning. + +In general backwards compatability is always preferred. + +Feature changes MUST be compatible with all [security supported versions of Python](https://endoflife.date/python) and SHOULD be compatible with all unsupported versions of Python where [recent downloads over the last 90 days exceeds 10% of all downloads](https://pypistats.org/packages/nserver). + +In general, only the latest `major.minor` version of NServer is supported. Bug fixes and feature backports requiring a version branch may be considered but must be discussed with the maintainers first. + +See also [Security Policy](security.md). + +### Spelling + +The original implementation of this project used Australian spelling so it will continue to use Australian spelling for all code. + +Documentation is more flexible and may use a variety of English spellings. + +### Contacting the Maintainers + +In general it is preferred to keep communication to GitHub, e.g. through comments on issues and pull requests. If you do need to contact the maintainers privately, please do so using the email addresses in the maintainers section of the `pyproject.toml`. diff --git a/docs/error-handling.md b/docs/error-handling.md index a27c2e5..a4c1d37 100644 --- a/docs/error-handling.md +++ b/docs/error-handling.md @@ -1,10 +1,8 @@ # Error Handling -Custom exception handling is handled through the [`ExceptionHandlerMiddleware`][nserver.middleware.ExceptionHandlerMiddleware] and [`RawRecordExceptionHandlerMiddleware`][nserver.middleware.RawRecordExceptionHandlerMiddleware] [Middleware][middleware]. These middleware will catch any `Exception`s raised by their respective middleware stacks. +Custom exception handling is handled through the [`QueryExceptionHandlerMiddleware`][nserver.middleware.QueryExceptionHandlerMiddleware] and [`RawExceptionHandlerMiddleware`][nserver.middleware.RawExceptionHandlerMiddleware] [Middleware][middleware]. These middleware will catch any `Exception`s raised by their respective middleware stacks. -Error handling requires `nserver>=2.0` - -In general you are probably able to use the `ExceptionHandlerMiddleware` as the `RawRecordExceptionHandlerMiddleware` is only needed to catch exceptions resulting from `RawRecordMiddleware` or broken exception handlers in the `ExceptionHandlerMiddleware`. If you only write `QueryMiddleware` and your `ExceptionHandlerMiddleware` handlers never raise exceptions then you'll be good to go with just the `ExceptionHandlerMiddleware`. +In general you are probably able to use the `QueryExceptionHandlerMiddleware` as the `RawExceptionHandlerMiddleware` is only needed to catch exceptions resulting from `RawMiddleware` or broken exception handlers in the `QueryExceptionHandlerMiddleware`. If you only write `QueryMiddleware` and your `QueryExceptionHandlerMiddleware` handlers never raise exceptions then you'll be good to go with just the `QueryExceptionHandlerMiddleware`. Both of these middleware have a default exception handler that will be used for anything not matching a registered handler. The default handler can be overwritten by registering a handler for the `Exception` class. @@ -15,6 +13,8 @@ Handlers are chosen by finding a handler for the most specific parent class of t ## Registering Exception Handlers +Exception handlers can be registered to `NameServer` and `RawNameSeerver` instances using either their `@exception_handler` decorators or their `register_exception_handler` methods. + ```python import dnslib from nserver import NameServer, Query, Response diff --git a/docs/index.md b/docs/index.md index 50c5ab4..ad02296 100644 --- a/docs/index.md +++ b/docs/index.md @@ -14,9 +14,6 @@ NServer has been built upon [dnslib](https://github.com/paulc/dnslib) however us NServer has been inspired by easy to use high level frameworks such as [Flask](https://github.com/pallets/flask) or [Requests](https://github.com/psf/requests). -!!! warning - NServer is currently Beta software and does not have complete documentation, testing, or implementation of certain features. - ## Features @@ -30,7 +27,7 @@ NServer has been inspired by easy to use high level frameworks such as [Flask](h Follow our [Quickstart Guide](quickstart.md). -```python title="TLDR" +```python title="tldr.py" from nserver import NameServer, Query, A server = NameServer("example") @@ -43,6 +40,9 @@ if __name__ == "__main__": server.run() ``` +```bash +nserver --server tldr.py:server +``` ## Bugs, Feature Requests etc Please [submit an issue on github](https://github.com/nhairs/nserver/issues). @@ -51,9 +51,6 @@ In the case of bug reports, please help us help you by following best practices In the case of feature requests, please provide background to the problem you are trying to solve so to help find a solution that makes the most sense for the library as well as your usecase. Before making a feature request consider looking at my (roughly written) [design notes](https://github.com/nhairs/nserver/blob/main/DESIGN_NOTES.md). -## Contributing -I am still working through open source licencing and contributing, so not taking PRs at this point in time. Instead raise and issue and I'll try get to it as soon a feasible. - ## Licence This project is licenced under the MIT Licence - see [`LICENCE`](https://github.com/nhairs/nserver/blob/main/LICENCE). diff --git a/docs/middleware.md b/docs/middleware.md index c54fc85..d5fb819 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -2,11 +2,9 @@ Middleware can be used to modify the behaviour of a server seperate to the individual rules that are registered to the server. Middleware is run on all requests and can modify both the input and response of a request. -Middleware requires `nserver>=2.0` - ## Middleware Stacks -Middleware operates in a stack with each middleware calling the middleware below it until one returns and the result is propagated back up the chain. NServer uses two stacks, the outmost stack deals with raw DNS records (`RawRecordMiddleware`), which will eventually convert the record to a `Query` which will then be passed to the main `QueryMiddleware` stack. +Middleware operates in a stack with each middleware calling the middleware below it until one returns and the result is propagated back up the chain. NServer uses two stacks, the outmost stack deals with raw DNS records (`RawMiddleware`), which will eventually convert the record to a `Query` which will then be passed to the main `QueryMiddleware` stack. Middleware can be added to the application until it is run. Once the server begins running the middleware cannot be modified. The ordering of middleware is kept in the order in which it is added to the server; that is the first middleware registered will be called before the second and so on. @@ -18,6 +16,8 @@ For most use cases you likely want to use [`QueryMiddleware`][nserver.middleware ### Registering `QueryMiddleware` +`QueryMiddleware` can be registered to `NameServer` instances using their `register_middleware` methods. + ```python from nserver import NameServer from nserver.middleware import QueryMiddleware @@ -38,7 +38,7 @@ from nserver import Query, Response class MyLoggingMiddleware(QueryMiddleware): def __init__(self, logging_name: str): super().__init__() - self.logger = logging.getLogger(f"my-awesome-app.{name}") + self.logger = logging.getLogger(f"my-awesome-app.{logging_name}") return def process_query( @@ -57,36 +57,39 @@ server.register_middleware(MyLoggingMiddleware("bar")) Once processed the `QueryMiddleware` stack will look as follows: -- [`ExceptionHandlerMiddleware`][nserver.middleware.ExceptionHandlerMiddleware] +- [`QueryExceptionHandlerMiddleware`][nserver.middleware.QueryExceptionHandlerMiddleware] - Customisable error handler for `Exception`s originating from within the stack. - `` - [`HookMiddleware`][nserver.middleware.HookMiddleware] - Runs hooks registered to the server. This can be considered a simplified version of middleware. -- [`RuleProcessor`][nserver.middleware.RuleProcessor] - - The entry point into our rule processing. -## `RawRecordMiddleware` +## `RawMiddleware` -[`RawRecordMiddleware`][nserver.middleware.RawRecordMiddleware] allows for modifying the raw `dnslib.DNSRecord`s that are recevied and sent by the server. +[`RawMiddleware`][nserver.middleware.RawMiddleware] allows for modifying the raw `dnslib.DNSRecord`s that are recevied and sent by the server. -### Registering `RawRecordMiddleware` +### Registering `RawMiddleware` + +`RawMiddleware` can be registered to `RawNameServer` instances using their `register_middleware` method. ```python # ... -from nserver.middleware import RawRecordMiddleware +from nserver import RawNameServer +from nserver.middleware import RawMiddleware + +raw_server = RawNameServer(server) -server.register_raw_middleware(RawRecordMiddleware()) +server.register_middleware(RawMiddleware()) ``` -### Creating your own `RawRecordMiddleware` +### Creating your own `RawMiddleware` -Using an unmodified `RawRecordMiddleware` isn't very interesting as it just passes the request onto the next middleware. To add your own middleware you should subclass `RawRecordMiddleware` and override the `process_record` method. +Using an unmodified `RawMiddleware` isn't very interesting as it just passes the request onto the next middleware. To add your own middleware you should subclass `RawMiddleware` and override the `process_record` method. ```python # ... -class SizeLimiterMiddleware(RawRecordMiddleware): +class SizeLimiterMiddleware(RawMiddleware): def __init__(self, max_size: int): super().__init__() self.max_size = max_size @@ -109,15 +112,13 @@ class SizeLimiterMiddleware(RawRecordMiddleware): return response -server.register_raw_middleware(SizeLimiterMiddleware(1400)) +server.register_middleware(SizeLimiterMiddleware(1400)) ``` -### Default `RawRecordMiddleware` stack +### Default `RawMiddleware` stack -Once processed the `RawRecordMiddleware` stack will look as follows: +Once processed the `RawMiddleware` stack will look as follows: -- [`RawRecordExceptionHandlerMiddleware`][nserver.middleware.RawRecordExceptionHandlerMiddleware] +- [`RawExceptionHandlerMiddleware`][nserver.middleware.RawExceptionHandlerMiddleware] - Customisable error handler for `Exception`s originating from within the stack. - `` -- [`QueryMiddlewareProcessor`][nserver.middleware.QueryMiddlewareProcessor] - - entry point into the `QueryMiddleware` stack. diff --git a/docs/quickstart.md b/docs/quickstart.md index 25d25fc..fa78f54 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -19,9 +19,6 @@ server = NameServer("example") @server.rule("example.com", ["A"]) def example_a_records(query: Query): return A(query.name, "1.2.3.4") - -if __name__ == "__main__": - server.run() ``` Here's what this code does: @@ -37,28 +34,25 @@ Here's what this code does: 4. When triggered our function will then return a single `A` record as a response. -5. Finally we add code so that we can run our server. - ### Running our server -With our server written we can now run it: +With our server written we can now run it using the `nserver` CLI: -```shell -python3 example_server.py +```bash +nserver --server path/to/minimal_server.py ``` - ```{.none .no-copy} -[INFO] Starting UDPv4Transport(address='localhost', port=9953) +[INFO] Starting UDPv4Transport(address='localhost', port=5300) ``` We can access it using `dig`. ```shell -dig -p 9953 @localhost A example.com +dig -p 5300 @localhost A example.com ``` ```{.none .no-copy} -; <<>> DiG 9.18.12-0ubuntu0.22.04.3-Ubuntu <<>> -p 9953 @localhost A example.com +; <<>> DiG 9.18.12-0ubuntu0.22.04.3-Ubuntu <<>> -p 5300 @localhost A example.com ; (1 server found) ;; global options: +cmd ;; Got answer: @@ -72,7 +66,7 @@ dig -p 9953 @localhost A example.com example.com. 300 IN A 1.2.3.4 ;; Query time: 324 msec -;; SERVER: 127.0.0.1#9953(localhost) (UDP) +;; SERVER: 127.0.0.1#5300(localhost) (UDP) ;; WHEN: Thu Nov 02 21:27:12 AEDT 2023 ;; MSG SIZE rcvd: 45 ``` diff --git a/docs/security.md b/docs/security.md new file mode 100644 index 0000000..d2974bb --- /dev/null +++ b/docs/security.md @@ -0,0 +1,14 @@ +# Security Policy + +## Supported Versions + +Security support for Python JSON Logger is provided for all [security supported versions of Python](https://endoflife.date/python) and for unsupported versions of Python where [recent downloads over the last 90 days exceeds 10% of all downloads](https://pypistats.org/packages/nserver). + + +As of 2024-11-22 security support is provided for Python versions `3.8+`. + + +## Reporting a Vulnerability + +Please report vulnerabilties [using GitHub](https://github.com/nhairs/nserver/security/advisories/new). + diff --git a/docs/subserver-blueprint.md b/docs/subserver-blueprint.md new file mode 100644 index 0000000..2323222 --- /dev/null +++ b/docs/subserver-blueprint.md @@ -0,0 +1,87 @@ +# Sub-Servers and Blueprints + +## Sub-Servers + +To allow for composing an application into different parts, a [`NameServer`][nserver.server.NameServer] can be included in another `NameServer`. + +Use cases: + +- Split up your application across different servers for maintainability / composability. +- Reuse a server registered under different rules. +- Allow custom packages to define their own rules that you can add to your own server. + +### Using Sub-Servers + +```python +from nserver import NameServer, ZoneRule, ALL_CTYPES, A, TXT + +# First child NameServer +mysite = NameServer("mysite") + +@mysite.rule("nicholashairs.com", ["A"]) +@mysite.rule("www.nicholashairs.com", ["A"]) +def nicholashairs_website(query: Query) -> A: + return A(query.name, "159.65.13.73") + +@mysite.rule(ZoneRule, "", ALL_CTYPES) +def nicholashairs_catchall(query: Query) -> None: + # Return empty response for all other queries + return None + +# Second child NameServer +en_subserver = NameServer("english-speaking-blueprint") + +@en_subserver.rule("hello.{base_domain}", ["TXT"]) +def en_hello(query: Query) -> TXT: + return TXT(query.name, "Hello There!") + +# Register to main NameServer +server = NameServer("server") +server.register_subserver(mysite, ZoneRule, "nicholashairs.com", ALL_CTYPES) +server.register_subserver(en_subserver, ZoneRule, "au", ALL_CTYPES) +server.register_subserver(en_subserver, ZoneRule, "nz", ALL_CTYPES) +server.register_subserver(en_subserver, ZoneRule, "uk", ALL_CTYPES) +``` + +#### Middleware, Hooks, and Exception Handling + +Don't forget that each `NameServer` maintains it's own middleware stack, exception handlers, and hooks. + +In particular errors will not propagate up from a child server to it's parent as the child's exception handler will catch any exception and return a response. + +## Blueprints + +[`Blueprint`][nserver.server.Blueprint]s act as a container for rules. They are an efficient way to compose your application if you do not want or need to use functionality provided by a `QueryMiddleware` stack. + +### Using Blueprints + +```python +# ... +from nserver import Blueprint, MX + +no_email_blueprint = Blueprint("noemail") + +@no_email_blueprint.rule("{base_domain}", ["MX"]) +@no_email_blueprint.rule("**.{base_domain}", ["MX"]) +def no_email(query: Query) -> MX: + "Indicate that we do not have a mail exchange" + return MX(query.name, ".", 0) + + +## Add it to our sub-servers +en_subserver.register_rule(no_email_blueprint) + +# Problem! Because we have already registered the nicholashairs_catchall rule, +# it will prevent our blueprint from being called. So instead let's manually +# insert it as the first rule. +mysite.rules.insert(0, no_email_blueprint) +``` + +### Key differences with `NameServer` + +- Only provides the `@rule` decorator and `register_rule` method. + - It does not have a `QueryMiddleware` stack which means it does not support hooks or error-handling. +- Is used directly in `register_rule` (e.g. `some_server.register_rule(my_blueprint)`). +- If rule does not match an internal rule will continue to the next rule in the parent server. + + In comparison `NameServer` instances will return `NXDOMAIN` if a rule doesn't match their internal rules. diff --git a/lib/python/build.Dockerfile b/lib/python/build.Dockerfile deleted file mode 100644 index c25011a..0000000 --- a/lib/python/build.Dockerfile +++ /dev/null @@ -1,24 +0,0 @@ -FROM python:3.7 - -ARG SOURCE_UID -ARG SOURCE_GID -ARG SOURCE_UID_GID - -RUN mkdir -p /code/src \ - && groupadd --gid ${SOURCE_GID} devuser \ - && useradd --uid ${SOURCE_GID} -g devuser --create-home --shel /bin/bash devuser \ - && chown -R ${SOURCE_UID_GID} /code \ - && su - devuser -c "pip install --user --upgrade pip" - -## ^^ copied from common.Dockerfile - try to keep in sync fo caching - -# Base stuff -ADD . /code - -RUN chown -R ${SOURCE_UID_GID} /code # needed twice because added files - -RUN ls -lah /code - -RUN su - devuser -c "cd /code && pip install --user build" - -CMD echo "docker-compose build python-build complete 🎉" diff --git a/lib/python/build.sh b/lib/python/build.sh index c4d7bcc..48c0405 100755 --- a/lib/python/build.sh +++ b/lib/python/build.sh @@ -47,22 +47,7 @@ replace_version_var BUILD_DATETIME "${BUILD_DATETIME}" 0 head -n 22 "src/${PACKAGE_PYTHON_NAME}/_version.py" | tail -n 7 -if [ "$PYTHON_PACKAGE_REPOSITORY" == "testpypi" ]; then - echo "MODIFYING PACKAGE_NAME" - # Replace name suitable for test.pypi.org - # https://packaging.python.org/tutorials/packaging-projects/#creating-setup-py - sed -i "s/^PACKAGE_NAME = .*/PACKAGE_NAME = \"${PACKAGE_NAME}-${TESTPYPI_USERNAME}\"/" setup.py - grep "^PACKAGE_NAME = " setup.py - - mv "src/${PACKAGE_PYTHON_NAME}" "src/${PACKAGE_PYTHON_NAME}_$(echo -n $TESTPYPI_USERNAME | tr '-' '_')" -fi - -if [[ "$GIT_BRANCH" != "master" && "$GIT_BRANCH" != "main" ]]; then - sed -i "s/^PACKAGE_VERSION = .*/PACKAGE_VERSION = \"${BUILD_VERSION}\"/" setup.py - grep "^PACKAGE_VERSION = " setup.py -fi - ## Build ## ----------------------------------------------------------------------------- -#python3 setup.py bdist_wheel -python3 -m build --wheel +uv build +git restore src/${PACKAGE_PYTHON_NAME}/_version.py diff --git a/lib/python/common.Dockerfile b/lib/python/common.Dockerfile deleted file mode 100644 index 668c098..0000000 --- a/lib/python/common.Dockerfile +++ /dev/null @@ -1,26 +0,0 @@ -# syntax = docker/dockerfile:1.2 -FROM python:3.7 - - -ARG SOURCE_UID -ARG SOURCE_GID -ARG SOURCE_UID_GID - -RUN apt update && apt install -y \ - less - -RUN mkdir -p /code/src \ - && groupadd --gid ${SOURCE_GID} devuser \ - && useradd --uid ${SOURCE_GID} -g devuser --create-home --shell /bin/bash devuser \ - && chown -R ${SOURCE_UID_GID} /code \ - && su -l devuser -c "pip install --user --upgrade pip" - -ADD pyproject.toml /code -RUN chown -R ${SOURCE_UID_GID} /code # needed twice because added files - -RUN ls -lah /code /home /home/devuser /home/devuser/.cache /home/devuser/.cache/pip - -RUN --mount=type=cache,target=/home/devuser/.cache,uid=1000,gid=1000 \ - su -l devuser -c "cd /code && pip install --user -e .[dev,docs]" - -CMD echo "docker-compose build python-common complete 🎉" diff --git a/lib/python/install_pypy.sh b/lib/python/install_pypy.sh deleted file mode 100755 index c3547f0..0000000 --- a/lib/python/install_pypy.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/bin/bash -set -e - -PYPY_VERSION="7.3.9" -PYTHON_VERSIONS="3.7 3.8 3.9" - -# Note: pypy-7.3.9 is last version to support python3.7 - -if [ ! -d /tmp/pypy ]; then - mkdir /tmp/pypy -fi - -cd /tmp/pypy - -for PYTHON_VERSION in $PYTHON_VERSIONS; do - FULLNAME="pypy${PYTHON_VERSION}-v${PYPY_VERSION}-linux64" - FILENAME="${FULLNAME}.tar.bz2" - - if [ ! -f "${FILENAME}" ]; then - # not cached - fetch - echo "Fetching ${FILENAME}" - wget -q "https://downloads.python.org/pypy/${FILENAME}" - fi - - echo "Extracting ${FILENAME} to /opt/${FULLNAME}" - tar xf ${FILENAME} --directory=/opt - - echo "Removing temp file" - rm -f ${FILENAME} - - echo "sanity check" - ls /opt - - echo "Linking ${FULLNAME}/bin/pypy${PYTHON_VERSION} to /usr/bin" - ln -s "/opt/${FULLNAME}/bin/pypy${PYTHON_VERSION}" /usr/bin/ - - echo "" - -done diff --git a/lib/python/tox.Dockerfile b/lib/python/tox.Dockerfile deleted file mode 100644 index e8ad7c9..0000000 --- a/lib/python/tox.Dockerfile +++ /dev/null @@ -1,50 +0,0 @@ -FROM ubuntu:20.04 - -# We use deadsnakes ppa to install -# https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa -# -# As noted in the readme, 22.04 supports only 3.7+, so use 20.04 to support some older versions -# This also means we don't install 3.8 as it is already provided - -# TZ https://serverfault.com/a/1016972 -ARG DEBIAN_FRONTEND=noninteractive -ENV TZ=Etc/UTC - -RUN --mount=target=/var/lib/apt/lists,type=cache,sharing=locked \ - --mount=target=/var/cache/apt,type=cache,sharing=locked \ - rm -f /etc/apt/apt.conf.d/docker-clean \ - && apt update \ - && apt upgrade --yes \ - && apt install --yes software-properties-common wget python3-pip\ - && add-apt-repository ppa:deadsnakes/ppa \ - && apt update --yes - -RUN --mount=target=/var/lib/apt/lists,type=cache,sharing=locked \ - --mount=target=/var/cache/apt,type=cache,sharing=locked \ - apt install --yes \ - python3.6 python3.6-dev python3.6-distutils \ - python3.7 python3.7-dev python3.7-distutils \ - python3.9 python3.9-dev python3.9-distutils \ - python3.10 python3.10-dev python3.10-distutils \ - python3.11 python3.11-dev python3.11-distutils \ - python3.12 python3.12-dev python3.12-distutils - -## pypy -ADD lib/python/install_pypy.sh /tmp -RUN --mount=target=/tmp/pypy,type=cache,sharing=locked \ - /tmp/install_pypy.sh - - -ARG SOURCE_UID -ARG SOURCE_GID -ARG SOURCE_UID_GID - -RUN mkdir -p /code/dist /code/tests \ - && groupadd --gid ${SOURCE_GID} devuser \ - && useradd --uid ${SOURCE_GID} -g devuser --create-home --shell /bin/bash devuser \ - && chown -R ${SOURCE_UID_GID} /code \ - && su - devuser -c "pip install --user --upgrade pip" - -RUN su - devuser -c "pip install --user tox" - -CMD echo "docker-compose build python-tox complete 🎉" diff --git a/mkdocs.yml b/mkdocs.yml index fa88c1d..af68cf4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -11,12 +11,15 @@ watch: nav: - "Home": index.md - quickstart.md + - architecture.md - middleware.md - error-handling.md - - blueprints.md + - subserver-blueprint.md - production-deployment.md - - changelog.md - external-resources.md + - changelog.md + - security.md + - contributing.md - API Reference: - ... | reference/nserver/* @@ -84,10 +87,10 @@ plugins: python: paths: - src - #import: - # - https://docs.python.org/3/objects.inv - # - https://mkdocstrings.github.io/objects.inv - # - https://mkdocstrings.github.io/griffe/objects.inv + import: + - https://docs.python.org/3/objects.inv + - https://mkdocstrings.github.io/objects.inv + - https://mkdocstrings.github.io/griffe/objects.inv options: filters: - "!^_" diff --git a/pylintrc b/pylintrc index 776bfe6..2dd7180 100644 --- a/pylintrc +++ b/pylintrc @@ -456,7 +456,7 @@ preferred-modules= # List of method names used to declare (i.e. assign) instance attributes. defining-attr-methods=__init__, __new__, - setUp, + setup, __post_init__ # List of member names, which should be excluded from the protected access @@ -479,6 +479,9 @@ valid-metaclass-classmethod-first-arg=cls # Maximum number of arguments for function / method. max-args=10 +# Max number of positional arguments for a function / method +max-positional-arguments=8 + # Maximum number of attributes for a class (see R0902). max-attributes=15 diff --git a/pyproject.toml b/pyproject.toml index 7c9a83b..6dd432d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,17 +4,19 @@ build-backend = "setuptools.build_meta" [project] name = "nserver" -version = "2.0.0" +version = "3.0.0.dev1" description = "DNS Name Server Framework" authors = [ {name = "Nicholas Hairs", email = "info+nserver@nicholashairs.com"}, ] # Dependency Information -requires-python = ">=3.7" +requires-python = ">=3.8" dependencies = [ "dnslib", + "pillar~=0.3", "tldextract", + "typing-extensions;python_version<'3.10'", ] # Extra information @@ -23,12 +25,12 @@ license = {text = "MIT"} classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "License :: OSI Approved :: MIT License", "Typing :: Typed", "Topic :: Internet", @@ -36,17 +38,13 @@ classifiers = [ ] [project.urls] -homepage = "https://nhairs.github.io/nserver/latest/" -github = "https://github.com/nhairs/nserver" +HomePage = "https://nhairs.github.io/nserver" +GitHub = "https://github.com/nhairs/nserver" [project.optional-dependencies] -build = [ - "setuptools", - "wheel", -] - dev = [ - ### dev.sh dependencies + "tox", + "tox-uv", ## Formatting / Linting "validate-pyproject[all]", "black", @@ -54,11 +52,10 @@ dev = [ "mypy", ## Testing "pytest", - ## REPL - "bpython", -] - -docs = [ + ## Build + "setuptools", + "wheel", + ## Docs "black", "mkdocs", "mkdocs-material>=8.5", @@ -70,5 +67,11 @@ docs = [ "mike", ] +[project.scripts] +nserver = "nserver.__main__:main" + [tool.setuptools.package-data] nserver = ["py.typed"] + +[tool.black] +line-length = 100 diff --git a/src/nserver/__init__.py b/src/nserver/__init__.py index 54e8c75..dde0054 100644 --- a/src/nserver/__init__.py +++ b/src/nserver/__init__.py @@ -1,5 +1,6 @@ +### IMPORTS +### ============================================================================ from .models import Query, Response from .rules import ALL_QTYPES, StaticRule, ZoneRule, RegexRule, WildcardStringRule from .records import A, AAAA, NS, CNAME, PTR, SOA, MX, TXT, CAA -from .server import NameServer, Blueprint -from .settings import Settings +from .server import NameServer, RawNameServer, Blueprint diff --git a/src/nserver/__main__.py b/src/nserver/__main__.py new file mode 100644 index 0000000..4039844 --- /dev/null +++ b/src/nserver/__main__.py @@ -0,0 +1,18 @@ +### IMPORTS +### ============================================================================ +from .cli import CliApplication + + +### FUNCTIONS +### ============================================================================ +def main(): + "CLI Entrypoint" + app = CliApplication() + app.run() + return app + + +### MAIN +### ============================================================================ +if __name__ == "__main__": + main() diff --git a/src/nserver/_version.py b/src/nserver/_version.py index 24869af..722c41c 100644 --- a/src/nserver/_version.py +++ b/src/nserver/_version.py @@ -1,4 +1,5 @@ """Version information for this package.""" + ### IMPORTS ### ============================================================================ ## Standard Library diff --git a/src/nserver/application.py b/src/nserver/application.py new file mode 100644 index 0000000..195bc32 --- /dev/null +++ b/src/nserver/application.py @@ -0,0 +1,109 @@ +### IMPORTS +### ============================================================================ +## Future +from __future__ import annotations + +## Standard Library + +## Installed +from pillar.logging import LoggingMixin + +## Application +from .exceptions import InvalidMessageError +from .server import NameServer, RawNameServer +from .transport import TransportBase + + +### CLASSES +### ============================================================================ +class BaseApplication(LoggingMixin): + """Base class for all application classes. + + New in `3.0`. + """ + + def __init__(self, server: NameServer | RawNameServer) -> None: + if isinstance(server, NameServer): + server = RawNameServer(server) + self.server: RawNameServer = server + self.logger = self.get_logger() + return + + def run(self) -> int | None: + """Run this application. + + Child classes must override this method. + + Returns: + Integer status code to be returned. `None` will be treated as `0`. + """ + raise NotImplementedError() + + +class DirectApplication(BaseApplication): + """Application that directly runs the server. + + New in `3.0`. + """ + + MAX_ERRORS: int = 10 + + exit_code: int + + def __init__(self, server: NameServer | RawNameServer, transport: TransportBase) -> None: + super().__init__(server) + self.transport = transport + self.exit_code = 0 + self.shutdown_server = False + return + + def run(self) -> int: + """Start running the server + + Returns: + `exit_code`, `0` if exited normally + """ + # Start Server + # TODO: Do we want to recreate the transport instance or do we assume that + # transport.shutdown_server puts it back into a ready state? + # We could make this configurable? :thonking: + + self.info(f"Starting {self.transport}") + try: + self.transport.start_server() + except Exception as e: # pylint: disable=broad-except + self.critical(f"Failed to start server. {e}", exc_info=e) + self.exit_code = 1 + return self.exit_code + + # Process Requests + error_count = 0 + while True: + if self.shutdown_server: + break + + try: + message = self.transport.receive_message() + message.response = self.server.process_request(message.message) + self.transport.send_message_response(message) + + except InvalidMessageError as e: + self.warning(f"{e}") + + except Exception as e: # pylint: disable=broad-except + self.error(f"Uncaught error occured. {e}", exc_info=e) + error_count += 1 + if self.MAX_ERRORS and error_count >= self.MAX_ERRORS: + self.critical(f"Max errors hit ({error_count})") + self.shutdown_server = True + self.exit_code = 1 + + except KeyboardInterrupt: + self.info("KeyboardInterrupt received.") + self.shutdown_server = True + + # Stop Server + self.info("Shutting down server") + self.transport.stop_server() + + return self.exit_code diff --git a/src/nserver/cli.py b/src/nserver/cli.py new file mode 100644 index 0000000..88da49b --- /dev/null +++ b/src/nserver/cli.py @@ -0,0 +1,134 @@ +### IMPORTS +### ============================================================================ +## Future +from __future__ import annotations + +## Standard Library +import argparse +import importlib +import os +import pydoc + +## Installed +import pillar.application + +## Application +from . import transport +from . import _version + +from .application import BaseApplication, DirectApplication +from .server import NameServer, RawNameServer + + +### CLASSES +### ============================================================================ +class CliApplication(pillar.application.Application): + """NServer CLI tool for running servers""" + + application_name = "nserver" + name = "nserver" + version = _version.VERSION_INFO_FULL + epilog = "For full information including licence see https://github.com/nhairs/nserver" + + config_args_enabled = False + + def get_argument_parser(self) -> argparse.ArgumentParser: + parser = super().get_argument_parser() + + ## Server + ## --------------------------------------------------------------------- + parser.add_argument( + "--server", + action="store", + required=True, + help=( + "Import path of server / factory to run in the form of " + "package.module.path:attribute" + ), + ) + + ## Transport + ## --------------------------------------------------------------------- + parser.add_argument( + "--host", + action="store", + default="localhost", + help="Host (IP) to bind to. Defaults to localhost.", + ) + + parser.add_argument( + "--port", + action="store", + default=5300, + type=int, + help="Port to bind to. Defaults to 5300.", + ) + + transport_group = parser.add_mutually_exclusive_group() + transport_group.add_argument( + "--udp", + action="store_const", + const=transport.UDPv4Transport, + dest="transport", + help="Use UDPv4 socket for transport. (default)", + ) + transport_group.add_argument( + "--udp6", + action="store_const", + const=transport.UDPv6Transport, + dest="transport", + help="Use UDPv6 socket for transport.", + ) + transport_group.add_argument( + "--tcp", + action="store_const", + const=transport.TCPv4Transport, + dest="transport", + help="Use TCPv4 socket for transport.", + ) + + parser.set_defaults(transport=transport.UDPv4Transport) + return parser + + def setup(self, *args, **kwargs) -> None: + super().setup(*args, **kwargs) + + self.server = self.get_server() + self.application = self.get_application() + return + + def main(self) -> int | None: + return self.application.run() + + def get_server(self) -> NameServer | RawNameServer: + """Factory for getting the server to run based on current settings""" + module_path, attribute_path = self.args.server.split(":") + + obj: object + if os.path.isfile(module_path): + # Ref: https://stackoverflow.com/a/68361215/12281814 + obj = pydoc.importfile(module_path) + else: + obj = importlib.import_module(module_path) + + for attribute_name in attribute_path.split("."): + obj = getattr(obj, attribute_name) + + if isinstance(obj, (NameServer, RawNameServer)): + return obj + + # Assume callable (will throw error if not) + server = obj() # type: ignore[operator] + + if isinstance(server, (NameServer, RawNameServer)): + return server + + raise TypeError(f"Imported factory ({obj}) did not return a server ({server})") + + def get_application(self) -> BaseApplication: + """Factory for getting the application based on current settings""" + application = DirectApplication( + self.server, + self.args.transport(self.args.host, self.args.port), + ) + return application diff --git a/src/nserver/exceptions.py b/src/nserver/exceptions.py index c72946d..89cc2ca 100644 --- a/src/nserver/exceptions.py +++ b/src/nserver/exceptions.py @@ -1,11 +1,11 @@ ### IMPORTS ### ============================================================================ +## Future +from __future__ import annotations + ## Standard Library import base64 -# Note: Union can only be replaced with `X | Y` in 3.10+ -from typing import Tuple, Union - ## Installed ## Application @@ -17,7 +17,7 @@ class InvalidMessageError(ValueError): """An invalid DNS message""" def __init__( - self, error: Exception, raw_data: bytes, remote_address: Union[str, Tuple[str, int]] + self, error: Exception, raw_data: bytes, remote_address: str | tuple[str, int] ) -> None: """ Args: diff --git a/src/nserver/middleware.py b/src/nserver/middleware.py index 2662bd7..e27df48 100644 --- a/src/nserver/middleware.py +++ b/src/nserver/middleware.py @@ -1,163 +1,220 @@ ### IMPORTS ### ============================================================================ +## Future +from __future__ import annotations + ## Standard Library import inspect import threading -from typing import Callable, Dict, List, Type, Optional +from typing import TYPE_CHECKING, Callable, Generic, TypeVar +import sys ## Installed import dnslib +from pillar.logging import LoggingMixin ## Application from .models import Query, Response -from .records import RecordBase -from .rules import RuleBase, RuleResult +from .rules import coerce_to_response, RuleResult + +## Special +if sys.version_info < (3, 10): + from typing_extensions import TypeAlias +else: + from typing import TypeAlias ### CONSTANTS ### ============================================================================ +# pylint: disable=invalid-name +T_request = TypeVar("T_request") +T_response = TypeVar("T_response") +# pylint: enable=invalid-name + ## Query Middleware -QueryMiddlewareCallable = Callable[[Query], Response] +## ----------------------------------------------------------------------------- +QueryCallable: TypeAlias = Callable[[Query], Response] """Type alias for functions that can be used with `QueryMiddleware.next_function`""" -ExceptionHandler = Callable[[Query, Exception], Response] +QueryExceptionHandler: TypeAlias = Callable[[Query, Exception], Response] """Type alias for `ExceptionHandlerMiddleware` exception handler functions""" # Hooks -BeforeFirstQueryHook = Callable[[], None] +BeforeFirstQueryHook: TypeAlias = Callable[[], None] """Type alias for `HookMiddleware.before_first_query` functions.""" -BeforeQueryHook = Callable[[Query], RuleResult] +BeforeQueryHook: TypeAlias = Callable[[Query], RuleResult] """Type alias for `HookMiddleware.before_query` functions.""" -AfterQueryHook = Callable[[Response], Response] +AfterQueryHook: TypeAlias = Callable[[Response], Response] """Type alias for `HookMiddleware.after_query` functions.""" ## RawRecordMiddleware -RawRecordMiddlewareCallable = Callable[[dnslib.DNSRecord], dnslib.DNSRecord] -"""Type alias for functions that can be used with `RawRecordMiddleware.next_function`""" - -RawRecordExceptionHandler = Callable[[dnslib.DNSRecord, Exception], dnslib.DNSRecord] -"""Type alias for `RawRecordExceptionHandlerMiddleware` exception handler functions""" - - -### FUNCTIONS -### ============================================================================ -def coerce_to_response(result: RuleResult) -> Response: - """Convert some `RuleResult` to a `Response` - - New in `2.0`. - - Args: - result: the results to convert - - Raises: - TypeError: unsupported result type - """ - if isinstance(result, Response): - return result +## ----------------------------------------------------------------------------- +if TYPE_CHECKING: - if result is None: - return Response() + class RawRecord(dnslib.DNSRecord): + "Dummy class for type checking as dnslib is not typed" - if isinstance(result, RecordBase) and result.__class__ is not RecordBase: - return Response(answers=result) +else: + RawRecord: TypeAlias = dnslib.DNSRecord + """Type alias for raw records to allow easy changing of implementation details""" - if isinstance(result, list) and all(isinstance(item, RecordBase) for item in result): - return Response(answers=result) +RawMiddlewareCallable: TypeAlias = Callable[[RawRecord], RawRecord] +"""Type alias for functions that can be used with `RawRecordMiddleware.next_function`""" - raise TypeError(f"Cannot process result: {result!r}") +RawExceptionHandler: TypeAlias = Callable[[RawRecord, Exception], RawRecord] +"""Type alias for `RawRecordExceptionHandlerMiddleware` exception handler functions""" ### CLASSES ### ============================================================================ -## Request Middleware +## Generic Base Classes ## ----------------------------------------------------------------------------- -class QueryMiddleware: - """Middleware for interacting with `Query` objects +class MiddlewareBase(Generic[T_request, T_response], LoggingMixin): + """Generic base class for middleware classes. - New in `2.0`. + New in `3.0`. """ def __init__(self) -> None: - self.next_function: Optional[QueryMiddlewareCallable] = None + self.next_function: Callable[[T_request], T_response] | None = None + self.logger = self.get_logger() return - def __call__(self, query: Query) -> Response: + def __call__(self, request: T_request) -> T_response: + """Call this middleware + + Args: + request: request to process + + Raises: + RuntimeError: If `next_function` is not set. + """ + if self.next_function is None: - raise RuntimeError("next_function is not set") - return self.process_query(query, self.next_function) + raise RuntimeError("next_function is not set. Need to call register_next_function.") + return self.process_request(request, self.next_function) + + def set_next_function(self, next_function: Callable[[T_request], T_response]) -> None: + """Set the `next_function` of this middleware - def register_next_function(self, next_function: QueryMiddlewareCallable) -> None: - """Set the `next_function` of this middleware""" + Args: + next_function: Callable that this middleware should call next. + """ if self.next_function is not None: - raise RuntimeError("next_function is already set") + raise RuntimeError(f"next_function is already set to {self.next_function}") self.next_function = next_function return - def process_query(self, query: Query, call_next: QueryMiddlewareCallable) -> Response: - """Handle an incoming query. + def process_request( + self, request: T_request, call_next: Callable[[T_request], T_response] + ) -> T_response: + """Process a given request - Child classes should override this function (if they do not this middleware will - simply pass the query onto the next function). - - Args: - query: the incoming query - call_next: the next function in the chain + Child classes should override this method with their own logic. """ - return call_next(query) + return call_next(request) -class ExceptionHandlerMiddleware(QueryMiddleware): - """Middleware for handling exceptions originating from a `QueryMiddleware` stack. - - Allows registering handlers for individual `Exception` types. Only one handler can - exist for a given `Exception` type. - - When an exception is encountered, the middleware will search for the first handler that - matches the class or parent class of the exception in method resolution order. If no handler - is registered will use this classes `self.default_exception_handler`. - - New in `2.0`. +class ExceptionHandlerBase(MiddlewareBase[T_request, T_response]): + """Generic base class for middleware exception handlers Attributes: - exception_handlers: registered exception handlers + handlers: registered exception handlers + + New in `3.0`. """ def __init__( - self, exception_handlers: Optional[Dict[Type[Exception], ExceptionHandler]] = None + self, + handlers: dict[type[Exception], Callable[[T_request, Exception], T_response]] | None = None, ) -> None: - """ - Args: - exception_handlers: exception handlers to assign - """ super().__init__() - self.exception_handlers = exception_handlers if exception_handlers is not None else {} + self.handlers: dict[type[Exception], Callable[[T_request, Exception], T_response]] = ( + handlers if handlers is not None else {} + ) return - def process_query(self, query: Query, call_next: QueryMiddlewareCallable) -> Response: - """Call the next function catching any handling any errors""" + def process_request(self, request, call_next): + """Call the next function handling any exceptions that arise""" try: - response = call_next(query) + response = call_next(request) except Exception as e: # pylint: disable=broad-except - handler = self.get_exception_handler(e) - response = handler(query, e) + handler = self.get_handler(e) + response = handler(request, e) return response - def get_exception_handler(self, exception: Exception) -> ExceptionHandler: - """Get the exception handler for an `Exception`. + def set_handler( + self, + exception_class: type[Exception], + handler: Callable[[T_request, Exception], T_response], + *, + allow_overwrite: bool = False, + ) -> None: + """Add an exception handler for the given exception class + + Args: + exception_class: Exceptions to associate with this handler. + handler: The handler to add. + allow_overwrite: Allow overwriting existing handlers. + + Raises: + ValueError: If a handler already exists for the given exception and + `allow_overwrite` is `False`. + """ + if exception_class in self.handlers and not allow_overwrite: + raise ValueError( + f"Exception handler already exists for {exception_class} and allow_overwrite is False" + ) + self.handlers[exception_class] = handler + return + + def get_handler(self, exception: Exception) -> Callable[[T_request, Exception], T_response]: + """Get the exception handler for the given exception Args: exception: the exception we wish to handle """ for class_ in inspect.getmro(exception.__class__): - if class_ in self.exception_handlers: - return self.exception_handlers[class_] + if class_ in self.handlers: + return self.handlers[class_] # No exception handler found - use default handler - return self.default_exception_handler + return self.default_handler + + @staticmethod + def default_handler(request: T_request, exception: Exception) -> T_response: + """Default exception handler + + Child classes MUST override this method. + """ + raise NotImplementedError("Must overide this method") + + +## Request Middleware +## ----------------------------------------------------------------------------- +class QueryMiddleware(MiddlewareBase[Query, Response]): + """Middleware for interacting with `Query` objects + + New in `3.0`. + """ + + +class QueryExceptionHandlerMiddleware(ExceptionHandlerBase[Query, Response], QueryMiddleware): + """Middleware for handling exceptions originating from a `QueryMiddleware` stack. + + Allows registering handlers for individual `Exception` types. Only one handler can + exist for a given `Exception` type. + + When an exception is encountered, the middleware will search for the first handler that + matches the class or parent class of the exception in method resolution order. If no handler + is registered will use this classes `self.default_exception_handler`. + + New in `3.0`. + """ @staticmethod - def default_exception_handler(query: Query, exception: Exception) -> Response: + def default_handler(request: Query, exception: Exception) -> Response: """The default exception handler""" # pylint: disable=unused-argument return Response(error_code=dnslib.RCODE.SERVFAIL) @@ -182,21 +239,21 @@ class HookMiddleware(QueryMiddleware): hook or from the next function in the middleware chain. They take a `Response` input and must return a `Response`. - New in `2.0`. - Attributes: before_first_query: `before_first_query` hooks before_query: `before_query` hooks after_query: `after_query` hooks before_first_query_run: have we run the `before_first_query` hooks before_first_query_failed: did any `before_first_query` hooks fail + + New in `3.0`. """ def __init__( self, - before_first_query: Optional[List[BeforeFirstQueryHook]] = None, - before_query: Optional[List[BeforeQueryHook]] = None, - after_query: Optional[List[AfterQueryHook]] = None, + before_first_query: list[BeforeFirstQueryHook] | None = None, + before_query: list[BeforeQueryHook] | None = None, + after_query: list[AfterQueryHook] | None = None, ) -> None: """ Args: @@ -205,26 +262,24 @@ def __init__( after_query: initial `after_query` hooks to register """ super().__init__() - self.before_first_query: List[BeforeFirstQueryHook] = ( + self.before_first_query: list[BeforeFirstQueryHook] = ( before_first_query if before_first_query is not None else [] ) - self.before_query: List[BeforeQueryHook] = before_query if before_query is not None else [] - self.after_query: List[AfterQueryHook] = after_query if after_query is not None else [] + self.before_query: list[BeforeQueryHook] = before_query if before_query is not None else [] + self.after_query: list[AfterQueryHook] = after_query if after_query is not None else [] self.before_first_query_run: bool = False self.before_first_query_failed: bool = False self._before_first_query_lock = threading.Lock() return - def process_query(self, query: Query, call_next: QueryMiddlewareCallable) -> Response: - """Process a query running relevant hooks.""" + def process_request(self, request: Query, call_next: QueryCallable) -> Response: with self._before_first_query_lock: if not self.before_first_query_run: - # self._debug("Running before_first_query") self.before_first_query_run = True try: for before_first_query_hook in self.before_first_query: - # self._vdebug(f"Running before_first_query func: {hook}") + self.vdebug(f"Running before_first_query_hook: {before_first_query_hook}") before_first_query_hook() except Exception: self.before_first_query_failed = True @@ -233,92 +288,34 @@ def process_query(self, query: Query, call_next: QueryMiddlewareCallable) -> Res result: RuleResult for before_query_hook in self.before_query: - result = before_query_hook(query) + self.vdebug(f"Running before_query_hook: {before_query_hook}") + result = before_query_hook(request) if result is not None: - # self._debug(f"Got result from before_hook: {hook}") + self.debug(f"Got result from before_query_hook: {before_query_hook}") break else: # No before query hooks returned a response - keep going - result = call_next(query) + result = call_next(request) response = coerce_to_response(result) for after_query_hook in self.after_query: + self.vdebug(f"Running after_query_hook: {after_query_hook}") response = after_query_hook(response) return response -# Final callable -# .............................................................................. -# This is not a QueryMiddleware - it is however the end of the line for all QueryMiddleware -class RuleProcessor: - """Find and run a matching rule function. - - This class serves as the bottom of the `QueryMiddleware` stack. - - New in `2.0`. - """ - - def __init__(self, rules: List[RuleBase]) -> None: - """ - Args: - rules: rules to run against - """ - self.rules = rules - return - - def __call__(self, query: Query) -> Response: - for rule in self.rules: - rule_func = rule.get_func(query) - if rule_func is not None: - # self._info(f"Matched Rule: {rule}") - return coerce_to_response(rule_func(query)) - - # self._info("Did not match any rule") - return Response(error_code=dnslib.RCODE.NXDOMAIN) - - ## Raw Middleware ## ----------------------------------------------------------------------------- -class RawRecordMiddleware: +class RawMiddleware(MiddlewareBase[RawRecord, RawRecord]): """Middleware to be run against raw `dnslib.DNSRecord`s. - New in `2.0`. + New in `3.0`. """ - def __init__(self) -> None: - self.next_function: Optional[RawRecordMiddlewareCallable] = None - return - - def __call__(self, record: dnslib.DNSRecord) -> None: - if self.next_function is None: - raise RuntimeError("next_function is not set") - return self.process_record(record, self.next_function) - - def register_next_function(self, next_function: RawRecordMiddlewareCallable) -> None: - """Set the `next_function` of this middleware""" - if self.next_function is not None: - raise RuntimeError("next_function is already set") - self.next_function = next_function - return - - def process_record( - self, record: dnslib.DNSRecord, call_next: RawRecordMiddlewareCallable - ) -> dnslib.DNSRecord: - """Handle an incoming record. - - Child classes should override this function (if they do not this middleware will - simply pass the record onto the next function). - Args: - record: the incoming record - call_next: the next function in the chain - """ - return call_next(record) - - -class RawRecordExceptionHandlerMiddleware(RawRecordMiddleware): +class RawExceptionHandlerMiddleware(ExceptionHandlerBase[RawRecord, RawRecord]): """Middleware for handling exceptions originating from a `RawRecordMiddleware` stack. Allows registering handlers for individual `Exception` types. Only one handler can @@ -326,109 +323,36 @@ class RawRecordExceptionHandlerMiddleware(RawRecordMiddleware): When an exception is encountered, the middleware will search for the first handler that matches the class or parent class of the exception in method resolution order. If no handler - is registered will use this classes `self.default_exception_handler`. + is registered will use this classes `self.default_handler`. Danger: Important Exception handlers are expected to be robust - that is, they must always return correctly even if they internally encounter an `Exception`. - New in `2.0`. - Attributes: - exception_handlers: registered exception handlers - """ - - def __init__( - self, exception_handlers: Optional[Dict[Type[Exception], RawRecordExceptionHandler]] = None - ) -> None: - super().__init__() - self.exception_handlers: Dict[Type[Exception], RawRecordExceptionHandler] = ( - exception_handlers if exception_handlers is not None else {} - ) - return - - def process_record( - self, record: dnslib.DNSRecord, call_next: RawRecordMiddlewareCallable - ) -> dnslib.DNSRecord: - """Call the next function handling any exceptions that arise""" - try: - response = call_next(record) - except Exception as e: # pylint: disable=broad-except - handler = self.get_exception_handler(e) - response = handler(record, e) - return response - - def get_exception_handler(self, exception: Exception) -> RawRecordExceptionHandler: - """Get the exception handler for the given exception + handlers: registered exception handlers - Args: - exception: the exception we wish to handle - """ - for class_ in inspect.getmro(exception.__class__): - if class_ in self.exception_handlers: - return self.exception_handlers[class_] - # No exception handler found - use default handler - return self.default_exception_handler + New in `3.0`. + """ @staticmethod - def default_exception_handler( - record: dnslib.DNSRecord, exception: Exception - ) -> dnslib.DNSRecord: + def default_handler(request: RawRecord, exception: Exception) -> RawRecord: """Default exception handler""" # pylint: disable=unused-argument - response = record.reply() + response = request.reply() response.header.rcode = dnslib.RCODE.SERVFAIL return response -# Final Callable -# .............................................................................. -# This is not a RawRcordMiddleware - it is however the end of the line for all RawRecordMiddleware -class QueryMiddlewareProcessor: - """Convert an incoming DNS record and pass it to a `QueryMiddleware` stack. - - This class serves as the bottom of the `RawRcordMiddleware` stack. - - New in `2.0`. - """ - - def __init__(self, query_middleware: QueryMiddlewareCallable) -> None: - """ - Args: - query_middleware: the top of the middleware stack - """ - self.query_middleware = query_middleware - return - - def __call__(self, record: dnslib.DNSRecord) -> dnslib.DNSRecord: - response = record.reply() - - if record.header.opcode != dnslib.OPCODE.QUERY: - # self._info(f"Received non-query opcode: {record.header.opcode}") - # This server only response to DNS queries - response.header.rcode = dnslib.RCODE.NOTIMP - return response - - if len(record.questions) != 1: - # self._info(f"Received len(questions_ != 1 ({record.questions})") - # To simplify things we only respond if there is 1 question. - # This is apparently common amongst DNS server implementations. - # For more information see the responses to this SO question: - # https://stackoverflow.com/q/4082081 - response.header.rcode = dnslib.RCODE.REFUSED - return response - - try: - query = Query.from_dns_question(record.questions[0]) - except ValueError: - # self._warning(e) - response.header.rcode = dnslib.RCODE.FORMERR - return response - - result = self.query_middleware(query) - - response.add_answer(*result.get_answer_records()) - response.add_ar(*result.get_additional_records()) - response.add_auth(*result.get_authority_records()) - response.header.rcode = result.error_code - return response +### TYPE_CHECKING +### ============================================================================ +if TYPE_CHECKING and False: # pylint: disable=condition-evals-to-constant + # pylint: disable=undefined-variable + q1 = QueryExceptionHandlerMiddleware() + reveal_type(q1) + reveal_type(q1.handlers) + reveal_type(q1.default_handler) + r1 = RawExceptionHandlerMiddleware() + reveal_type(r1) + reveal_type(r1.handlers) + reveal_type(r1.default_handler) diff --git a/src/nserver/models.py b/src/nserver/models.py index a4626a3..dee7117 100644 --- a/src/nserver/models.py +++ b/src/nserver/models.py @@ -1,5 +1,8 @@ ### IMPORTS ### ============================================================================ +## Future +from __future__ import annotations + ## Standard Library from typing import Optional, Union, List @@ -37,7 +40,7 @@ def __init__(self, qtype: str, name: str) -> None: return @classmethod - def from_dns_question(cls, question: dnslib.DNSQuestion) -> "Query": + def from_dns_question(cls, question: dnslib.DNSQuestion) -> Query: """Create a new query from a `dnslib.DNSQuestion`""" if question.qtype not in dnslib.QTYPE.forward: raise ValueError(f"Invalid QTYPE: {question.qtype}") @@ -106,14 +109,14 @@ def __repr__(self) -> str: def __str__(self) -> str: return self.__repr__() - def get_answer_records(self) -> List[dnslib.RD]: + def get_answer_records(self) -> list[dnslib.RD]: """Prepare resource records for answer section""" return [record.to_resource_record() for record in self.answers] - def get_additional_records(self) -> List[dnslib.RD]: + def get_additional_records(self) -> list[dnslib.RD]: """Prepare resource records for additional section""" return [record.to_resource_record() for record in self.additional] - def get_authority_records(self) -> List[dnslib.RD]: + def get_authority_records(self) -> list[dnslib.RD]: """Prepare resource records for authority section""" return [record.to_resource_record() for record in self.authority] diff --git a/src/nserver/records.py b/src/nserver/records.py index 9915f67..3510409 100644 --- a/src/nserver/records.py +++ b/src/nserver/records.py @@ -2,10 +2,13 @@ ### IMPORTS ### ============================================================================ +## Future +from __future__ import annotations + ## Standard Library from ipaddress import IPv4Address, IPv6Address import re -from typing import Any, Union, Dict +from typing import Any ## Installed import dnslib @@ -36,7 +39,7 @@ def __init__(self, resource_name: str, ttl: int) -> None: type_name = self.__class__.__name__ self._qtype = getattr(dnslib.QTYPE, type_name) self._class = getattr(dnslib, type_name) # class means python class not RR CLASS - self._record_kwargs: Dict[str, Any] + self._record_kwargs: dict[str, Any] is_unsigned_int_size(ttl, 32, throw_error=True, value_name="ttl") self.ttl = ttl self.resource_name = resource_name @@ -56,7 +59,7 @@ def to_resource_record(self) -> dnslib.RR: class A(RecordBase): # pylint: disable=invalid-name """Ipv4 Address (`A`) Record.""" - def __init__(self, resource_name: str, ip: Union[str, IPv4Address], ttl: int = 300) -> None: + def __init__(self, resource_name: str, ip: str | IPv4Address, ttl: int = 300) -> None: """ Args: resource_name: DNS resource name @@ -77,7 +80,7 @@ def __init__(self, resource_name: str, ip: Union[str, IPv4Address], ttl: int = 3 class AAAA(RecordBase): """Ipv6 Address (`AAAA`) Record.""" - def __init__(self, resource_name: str, ip: Union[str, IPv6Address], ttl: int = 300) -> None: + def __init__(self, resource_name: str, ip: str | IPv6Address, ttl: int = 300) -> None: """ Args: resource_name: DNS resource name @@ -222,7 +225,7 @@ class SOA(RecordBase): - https://en.wikipedia.org/wiki/SOA_record """ - def __init__( # pylint: disable=too-many-arguments + def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments self, zone_name: str, primary_name_server: str, diff --git a/src/nserver/rules.py b/src/nserver/rules.py index 3abf5bc..a3b111d 100644 --- a/src/nserver/rules.py +++ b/src/nserver/rules.py @@ -2,9 +2,12 @@ ### IMPORTS ### ============================================================================ +## Future +from __future__ import annotations + ## Standard Library import re -from typing import Callable, List, Optional, Pattern, Union, Type +from typing import Callable, Pattern, Union, Type, List ## Installed import dnslib @@ -16,7 +19,7 @@ ### CONSTANTS ### ============================================================================ -ALL_QTYPES: List[str] = list(dnslib.QTYPE.reverse.keys()) +ALL_QTYPES: list[str] = list(dnslib.QTYPE.reverse.keys()) """All supported Query Types New in `2.0`. @@ -27,7 +30,33 @@ ### FUNCTIONS ### ============================================================================ -def smart_make_rule(rule: "Union[Type[RuleBase], str, Pattern]", *args, **kwargs) -> "RuleBase": +def coerce_to_response(result: RuleResult) -> Response: + """Convert some `RuleResult` to a `Response` + + Args: + result: the results to convert + + Raises: + TypeError: unsupported result type + + New in `3.0`. + """ + if isinstance(result, Response): + return result + + if result is None: + return Response() + + if isinstance(result, RecordBase) and result.__class__ is not RecordBase: + return Response(answers=result) + + if isinstance(result, list) and all(isinstance(item, RecordBase) for item in result): + return Response(answers=result) + + raise TypeError(f"Cannot process result: {result!r}") + + +def smart_make_rule(rule: Union[Type[RuleBase], str, Pattern], *args, **kwargs) -> RuleBase: """Create a rule using shorthand notation. The exact type of rule returned depends on what is povided by `rule`. @@ -76,7 +105,7 @@ def smart_make_rule(rule: "Union[Type[RuleBase], str, Pattern]", *args, **kwargs class RuleBase: """Base class for all Rules to inherit from.""" - def get_func(self, query: Query) -> Optional[ResponseFunction]: + def get_func(self, query: Query) -> ResponseFunction | None: """From the given query return the function to run, if any. If no function should be run (i.e. because it does not match the rule), @@ -99,7 +128,7 @@ class StaticRule(RuleBase): def __init__( self, match_string: str, - allowed_qtypes: List[str], + allowed_qtypes: list[str], func: ResponseFunction, case_sensitive: bool = False, ) -> None: @@ -116,7 +145,7 @@ def __init__( self.case_sensitive = case_sensitive return - def get_func(self, query: Query) -> Optional[ResponseFunction]: + def get_func(self, query: Query) -> ResponseFunction | None: """Same as parent class""" if query.type not in self.allowed_qtypes: return None @@ -147,7 +176,7 @@ class ZoneRule(RuleBase): def __init__( self, zone: str, - allowed_qtypes: List[str], + allowed_qtypes: list[str], func: ResponseFunction, case_sensitive: bool = False, ) -> None: @@ -165,7 +194,7 @@ def __init__( self.case_sensitive = case_sensitive return - def get_func(self, query: Query) -> Optional[ResponseFunction]: + def get_func(self, query: Query) -> ResponseFunction | None: """Same as parent class""" if self.allowed_qtypes is not None and query.type not in self.allowed_qtypes: return None @@ -194,7 +223,7 @@ class RegexRule(RuleBase): def __init__( self, regex: Pattern, - allowed_qtypes: List[str], + allowed_qtypes: list[str], func: ResponseFunction, case_sensitive: bool = False, ) -> None: @@ -219,7 +248,7 @@ def __init__( self.case_sensitive = case_sensitive return - def get_func(self, query: Query) -> Optional[ResponseFunction]: + def get_func(self, query: Query) -> ResponseFunction | None: """Same as parent class""" if query.type not in self.allowed_qtypes: return None @@ -257,7 +286,7 @@ class WildcardStringRule(RuleBase): def __init__( self, wildcard_string: str, - allowed_qtypes: List, + allowed_qtypes: list, func: ResponseFunction, case_sensitive: bool = False, ) -> None: @@ -274,7 +303,7 @@ def __init__( self.case_sensitive = case_sensitive return - def get_func(self, query: Query) -> Optional[ResponseFunction]: + def get_func(self, query: Query) -> ResponseFunction | None: """Same as parent class""" if query.type not in self.allowed_qtypes: return None diff --git a/src/nserver/server.py b/src/nserver/server.py index 8c69b61..480d896 100644 --- a/src/nserver/server.py +++ b/src/nserver/server.py @@ -1,164 +1,158 @@ ### IMPORTS ### ============================================================================ -## Standard Library -import logging +## Future +from __future__ import annotations -# Note: Optional can only be replaced with `| None` in 3.10+ -from typing import List, Dict, Optional, Union, Type, Pattern +## Standard Library +from typing import TypeVar, Generic, Pattern ## Installed import dnslib +from pillar.logging import LoggingMixin ## Application -from .exceptions import InvalidMessageError from .models import Query, Response -from .rules import smart_make_rule, RuleBase, ResponseFunction -from .settings import Settings -from .transport import TransportBase, UDPv4Transport, UDPv6Transport, TCPv4Transport +from .rules import coerce_to_response, smart_make_rule, RuleBase, ResponseFunction -from . import middleware +from . import middleware as m ### CONSTANTS ### ============================================================================ -TRANSPORT_MAP: Dict[str, Type[TransportBase]] = { - "UDPv4": UDPv4Transport, - "UDPv6": UDPv6Transport, - "TCPv4": TCPv4Transport, -} +# pylint: disable=invalid-name +T_middleware = TypeVar("T_middleware", bound=m.MiddlewareBase) +T_exception_handler = TypeVar("T_exception_handler", bound=m.ExceptionHandlerBase) +# pylint: enable=invalid-name ### Classes ### ============================================================================ -class Scaffold: - """Base class for shared functionality between `NameServer` and `Blueprint` +class MiddlewareMixin(Generic[T_middleware, T_exception_handler]): + """Generic mixin for building a middleware stack in a server. - New in `2.0`. + Should not be used directly, instead use the servers that implement it: + `NameServer`, `RawNameServer`. - Attributes: - rules: registered rules - hook_middleware: hook middleware - exception_handler_middleware: Query exception handler middleware + New in `3.0`. """ - _logger: logging.Logger - - def __init__(self, name: str) -> None: - """ - Args: - name: The name of the server. This is used for internal logging. - """ - self.name = name - - self.rules: List[RuleBase] = [] - self.hook_middleware = middleware.HookMiddleware() - self.exception_handler_middleware = middleware.ExceptionHandlerMiddleware() + _exception_handler: T_exception_handler - self._user_query_middleware: List[middleware.QueryMiddleware] = [] - self._query_middleware_stack: List[ - Union[middleware.QueryMiddleware, middleware.QueryMiddlewareCallable] - ] = [] + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._middleware_stack_final: list[T_middleware] | None = None + self._middleware_stack_user: list[T_middleware] = [] return - ## Register Methods + ## Middleware ## ------------------------------------------------------------------------- - def register_rule(self, rule: RuleBase) -> None: - """Register the given rule + def middleware_is_prepared(self) -> bool: + """Check if the middleware has been prepared.""" + return self._middleware_stack_final is not None + + def append_middleware(self, middleware: T_middleware) -> None: + """Append this middleware to the middleware stack Args: - rule: the rule to register + middleware: middleware to append """ - self._debug(f"Registered rule: {rule!r}") - self.rules.append(rule) + if self.middleware_is_prepared(): + raise RuntimeError("Cannot append middleware once prepared") + self._middleware_stack_user.append(middleware) return - def register_blueprint( - self, blueprint: "Blueprint", rule_: Union[Type[RuleBase], str, Pattern], *args, **kwargs - ) -> None: - """Register a blueprint using [`smart_make_rule`][nserver.rules.smart_make_rule]. + def prepare_middleware(self) -> None: + """Prepare middleware for consumption - New in `2.0`. + Child classes should wrap this method to set the `next_function` on the + final middleware in the stack. + """ + if self.middleware_is_prepared(): + raise RuntimeError("Middleware is already prepared") - Args: - blueprint: the `Blueprint` to attach - rule_: rule as per `nserver.rules.smart_make_rule` - args: extra arguments to provide `smart_make_rule` - kwargs: extra keyword arguments to provide `smart_make_rule` + middleware_stack = self._prepare_middleware_stack() - Raises: - ValueError: if `func` is provided in `kwargs`. - """ + next_middleware: T_middleware | None = None - if "func" in kwargs: - raise ValueError("Must not provide `func` in kwargs") - self.register_rule(smart_make_rule(rule_, *args, func=blueprint.entrypoint, **kwargs)) + for middleware in middleware_stack[::-1]: + if next_middleware is not None: + middleware.set_next_function(next_middleware) + next_middleware = middleware + + self._middleware_stack_final = middleware_stack return - def register_before_first_query(self, func: middleware.BeforeFirstQueryHook) -> None: - """Register a function to be run before the first query. + def _prepare_middleware_stack(self) -> list[T_middleware]: + """Create final stack of middleware. - Args: - func: the function to register + Child classes may override this method to customise the final middleware stack. """ - self.hook_middleware.before_first_query.append(func) - return + return [self._exception_handler, *self._middleware_stack_user] # type: ignore[list-item] - def register_before_query(self, func: middleware.BeforeQueryHook) -> None: - """Register a function to be run before every query. + @property + def middleware(self) -> list[T_middleware]: + """Accssor for this servers middleware. - Args: - func: the function to register - If `func` returns anything other than `None` will stop processing the - incoming `Query` and continue to result processing with the return value. + If the server has been prepared then returns a copy of the prepared middleware. + Otherwise returns a mutable list of the registered middleware. """ - self.hook_middleware.before_query.append(func) + if self.middleware_is_prepared(): + return self._middleware_stack_final.copy() # type: ignore[union-attr] + return self._middleware_stack_user + + ## Exception Handler + ## ------------------------------------------------------------------------- + def register_exception_handler(self, *args, **kwargs) -> None: + """Shortcut for `self.exception_handler.set_handler`""" + self.exception_handler_middleware.set_handler(*args, **kwargs) return - def register_after_query(self, func: middleware.AfterQueryHook) -> None: - """Register a function to be run on the result of a query. + @property + def exception_handler_middleware(self) -> T_exception_handler: + """Read only accessor for this server's middleware exception handler""" + return self._exception_handler + + def exception_handler(self, exception_class: type[Exception]): + """Decorator for registering a function as an raw exception handler Args: - func: the function to register + exception_class: The `Exception` class to register this handler for """ - self.hook_middleware.after_query.append(func) - return - def register_middleware(self, query_middleware: middleware.QueryMiddleware) -> None: - """Add a `QueryMiddleware` to this server. + def decorator(func): + nonlocal exception_class + self.register_raw_exception_handler(exception_class, func) + return func - New in `2.0`. + return decorator - Args: - query_middleware: the middleware to add - """ - if self._query_middleware_stack: - # Note: we can use truthy expression as once processed there will always be at - # least one item in the stack - raise RuntimeError("Cannot register middleware after stack is created") - self._user_query_middleware.append(query_middleware) - return - def register_exception_handler( - self, exception_class: Type[Exception], handler: middleware.ExceptionHandler - ) -> None: - """Register an exception handler for the `QueryMiddleware` +## Mixins +## ----------------------------------------------------------------------------- +class RulesMixin(LoggingMixin): + """Base class for rules based functionality` + + Attributes: + rules: reistered rules - Only one handler can exist for a given exception type. + New in `3.0`. + """ - New in `2.0`. + def __init__(self) -> None: + super().__init__() + self.rules: list[RuleBase] = [] + return + + def register_rule(self, rule: RuleBase) -> None: + """Register the given rule Args: - exception_class: the type of exception to handle - handler: the function to call when handling an exception + rule: the rule to register """ - if exception_class in self.exception_handler_middleware.exception_handlers: - raise ValueError("Exception handler already exists for {exception_class}") - - self.exception_handler_middleware.exception_handlers[exception_class] = handler + self.vdebug(f"Registered rule: {rule!r}") + self.rules.append(rule) return - # Decorators - # .......................................................................... - def rule(self, rule_: Union[Type[RuleBase], str, Pattern], *args, **kwargs): + def rule(self, rule_: type[RuleBase] | str | Pattern, *args, **kwargs): """Decorator for registering a function using [`smart_make_rule`][nserver.rules.smart_make_rule]. Changed in `2.0`: This method now uses `smart_make_rule`. @@ -184,342 +178,241 @@ def decorator(func: ResponseFunction): return decorator - def before_first_query(self): - """Decorator for registering before_first_query hook. - - These functions are called when the server receives it's first query, but - before any further processesing. - """ - - def decorator(func: middleware.BeforeFirstQueryHook): - self.register_before_first_query(func) - return func - return decorator +## Servers +## ----------------------------------------------------------------------------- +class RawNameServer( + MiddlewareMixin[m.RawMiddleware, m.RawExceptionHandlerMiddleware], LoggingMixin +): + """Server that handles raw `dnslib.DNSRecord` queries. - def before_query(self): - """Decorator for registering before_query hook. + This allows interacting with the underlying DNS messages from our dns library. + As such this server is implementation dependent and may change from time to time. - These functions are called before processing each query. - """ + In general you should use `NameServer` as it is implementation independent. - def decorator(func: middleware.BeforeQueryHook): - self.register_before_query(func) - return func + New in `3.0`. + """ - return decorator + def __init__(self, nameserver: NameServer) -> None: + self._exception_handler = m.RawExceptionHandlerMiddleware() + super().__init__() + self.nameserver: NameServer = nameserver + self.logger = self.get_logger() + return - def after_query(self): - """Decorator for registering after_query hook. + def process_request(self, request: m.RawRecord) -> m.RawRecord: + """Process a request using this server. - These functions are after the rule function is run and may modify the - response. + This will pass the request through the middleware stack. """ + if not self.middleware_is_prepared(): + self.prepare_middleware() + return self.middleware[0](request) - def decorator(func: middleware.AfterQueryHook): - self.register_after_query(func) - return func + def send_request_to_nameserver(self, record: m.RawRecord) -> m.RawRecord: + """Send a request to the `NameServer` of this instance. - return decorator - - def exception_handler(self, exception_class: Type[Exception]): - """Decorator for registering a function as an exception handler - - New in `2.0`. - - Args: - exception_class: The `Exception` class to register this handler for + Although this is the final step after passing a request through all middleware, + it can be called directly to avoid using middleware such as when testing. """ + response = record.reply() + + if record.header.opcode != dnslib.OPCODE.QUERY: + self.debug(f"Received non-query opcode: {record.header.opcode}") + # This server only response to DNS queries + response.header.rcode = dnslib.RCODE.NOTIMP + return response + + if len(record.questions) != 1: + self.debug(f"Received len(questions_ != 1 ({record.questions})") + # To simplify things we only respond if there is 1 question. + # This is apparently common amongst DNS server implementations. + # For more information see the responses to this SO question: + # https://stackoverflow.com/q/4082081 + response.header.rcode = dnslib.RCODE.REFUSED + return response - def decorator(func: middleware.ExceptionHandler): - nonlocal exception_class - self.register_exception_handler(exception_class, func) - return func - - return decorator - - ## Internal Functions - ## ------------------------------------------------------------------------- - def _prepare_query_middleware_stack(self) -> None: - """Prepare the `QueryMiddleware` for this server.""" - if self._query_middleware_stack: - # Note: we can use truthy expression as once processed there will always be at - # least one item in the stack - raise RuntimeError("QueryMiddleware stack already exists") - - middleware_stack: List[middleware.QueryMiddleware] = [ - self.exception_handler_middleware, - *self._user_query_middleware, - self.hook_middleware, - ] - rule_processor = middleware.RuleProcessor(self.rules) - - next_middleware: Optional[middleware.QueryMiddleware] = None - for query_middleware in middleware_stack[::-1]: - if next_middleware is None: - query_middleware.register_next_function(rule_processor) - else: - query_middleware.register_next_function(next_middleware) - next_middleware = query_middleware - - self._query_middleware_stack.extend(middleware_stack) - self._query_middleware_stack.append(rule_processor) + try: + query = Query.from_dns_question(record.questions[0]) + except ValueError: + # TODO: should we embed raw DNS query? Maybe this should be configurable. + self.warning("Failed to parse Query from request", exc_info=True) + response.header.rcode = dnslib.RCODE.FORMERR + return response + + result = self.nameserver.process_request(query) + + response.add_answer(*result.get_answer_records()) + response.add_ar(*result.get_additional_records()) + response.add_auth(*result.get_authority_records()) + response.header.rcode = result.error_code + return response + + def prepare_middleware(self) -> None: + super().prepare_middleware() + self.middleware[-1].set_next_function(self.send_request_to_nameserver) return - ## Logging - ## ------------------------------------------------------------------------- - def _vvdebug(self, *args, **kwargs): - """Log very verbose debug message.""" - - return self._logger.log(6, *args, **kwargs) - - def _vdebug(self, *args, **kwargs): - """Log verbose debug message.""" - - return self._logger.log(8, *args, **kwargs) - - def _debug(self, *args, **kwargs): - """Log debug message.""" - return self._logger.debug(*args, **kwargs) +class NameServer( + MiddlewareMixin[m.QueryMiddleware, m.QueryExceptionHandlerMiddleware], RulesMixin, LoggingMixin +): + """High level DNS Name Server for responding to DNS queries. - def _info(self, *args, **kwargs): - """Log very verbose debug message.""" + *Changed in `3.0`*: - return self._logger.info(*args, **kwargs) + - "Raw" functionality removed and moved to `RawNameServer`. + - "Transport" and "Application" functionality removed. + """ - def _warning(self, *args, **kwargs): - """Log warning message.""" + def __init__(self, name: str) -> None: + """ + Args: + name: The name of the server. This is used for internal logging. + """ + self.name = name + self._exception_handler = m.QueryExceptionHandlerMiddleware() + super().__init__() + self.hooks = m.HookMiddleware() + self.logger = self.get_logger() + return - return self._logger.warning(*args, **kwargs) + def _prepare_middleware_stack(self) -> list[m.QueryMiddleware]: + stack = super()._prepare_middleware_stack() + stack.append(self.hooks) + return stack - def _error(self, *args, **kwargs): - """Log an error message.""" + ## Register Methods + ## ------------------------------------------------------------------------- + def register_subserver( + self, nameserver: NameServer, rule_: type[RuleBase] | str | Pattern, *args, **kwargs + ) -> None: + """Register a `NameServer` using [`smart_make_rule`][nserver.rules.smart_make_rule]. - return self._logger.error(*args, **kwargs) + This allows for composing larger applications. - def _critical(self, *args, **kwargs): - """Log a critical message.""" + Args: + subserver: the `SubServer` to attach + rule_: rule as per `nserver.rules.smart_make_rule` + args: extra arguments to provide `smart_make_rule` + kwargs: extra keyword arguments to provide `smart_make_rule` - return self._logger.critical(*args, **kwargs) + Raises: + ValueError: if `func` is provided in `kwargs`. + New in `3.0`. + """ -class NameServer(Scaffold): - """NameServer for responding to requests.""" + if "func" in kwargs: + raise ValueError("Must not provide `func` in kwargs") + self.register_rule(smart_make_rule(rule_, *args, func=nameserver.process_request, **kwargs)) + return - # pylint: disable=too-many-instance-attributes + def register_before_first_query(self, func: m.BeforeFirstQueryHook) -> None: + """Register a function to be run before the first query. - def __init__(self, name: str, settings: Optional[Settings] = None) -> None: - """ Args: - name: The name of the server. This is used for internal logging. - settings: settings to use with this `NameServer` instance + func: the function to register """ - super().__init__(name) - self._logger = logging.getLogger(f"nserver.i.{self.name}") - - self.raw_exception_handler_middleware = middleware.RawRecordExceptionHandlerMiddleware() - self._user_raw_record_middleware: List[middleware.RawRecordMiddleware] = [] - self._raw_record_middleware_stack: List[ - Union[middleware.RawRecordMiddleware, middleware.RawRecordMiddlewareCallable] - ] = [] - - self.settings = settings if settings is not None else Settings() - - transport = TRANSPORT_MAP.get(self.settings.server_transport) - if transport is None: - raise ValueError( - f"Invalid settings.server_transport {self.settings.server_transport!r}" - ) - self.transport = transport(self.settings) - - self.shutdown_server = False - self.exit_code = 0 + self.hooks.before_first_query.append(func) return - ## Register Methods - ## ------------------------------------------------------------------------- - def register_raw_middleware(self, raw_middleware: middleware.RawRecordMiddleware) -> None: - """Add a `RawRecordMiddleware` to this server. - - New in `2.0`. + def register_before_query(self, func: m.BeforeQueryHook) -> None: + """Register a function to be run before every query. Args: - raw_middleware: the middleware to add + func: the function to register + If `func` returns anything other than `None` will stop processing the + incoming `Query` and continue to result processing with the return value. """ - if self._raw_record_middleware_stack: - # Note: we can use truthy expression as once processed there will always be at - # least one item in the stack - raise RuntimeError("Cannot register middleware after stack is created") - self._user_raw_record_middleware.append(raw_middleware) + self.hooks.before_query.append(func) return - def register_raw_exception_handler( - self, exception_class: Type[Exception], handler: middleware.RawRecordExceptionHandler - ) -> None: - """Register a raw exception handler for the `RawRecordMiddleware`. - - Only one handler can exist for a given exception type. - - New in `2.0`. + def register_after_query(self, func: m.AfterQueryHook) -> None: + """Register a function to be run on the result of a query. Args: - exception_class: the type of exception to handle - handler: the function to call when handling an exception + func: the function to register """ - if exception_class in self.raw_exception_handler_middleware.exception_handlers: - raise ValueError("Exception handler already exists for {exception_class}") - - self.raw_exception_handler_middleware.exception_handlers[exception_class] = handler + self.hooks.after_query.append(func) return # Decorators # .......................................................................... - def raw_exception_handler(self, exception_class: Type[Exception]): - """Decorator for registering a function as an raw exception handler - - New in `2.0`. + def before_first_query(self): + """Decorator for registering before_first_query hook. - Args: - exception_class: The `Exception` class to register this handler for + These functions are called when the server receives it's first query, but + before any further processesing. """ - def decorator(func: middleware.RawRecordExceptionHandler): - nonlocal exception_class - self.register_raw_exception_handler(exception_class, func) + def decorator(func: m.BeforeFirstQueryHook): + self.register_before_first_query(func) return func return decorator - ## Public Methods - ## ------------------------------------------------------------------------- - def run(self) -> int: - """Start running the server + def before_query(self): + """Decorator for registering before_query hook. - Returns: - `exit_code`, `0` if exited normally + These functions are called before processing each query. """ - # Setup Logging - console_logger = logging.StreamHandler() - console_logger.setLevel(self.settings.console_log_level) - console_formatter = logging.Formatter( - "[{asctime}][{levelname}][{name}] {message}", style="{" - ) + def decorator(func: m.BeforeQueryHook): + self.register_before_query(func) + return func - console_logger.setFormatter(console_formatter) + return decorator - self._logger.addHandler(console_logger) - self._logger.setLevel(min(self.settings.console_log_level, self.settings.file_log_level)) + def after_query(self): + """Decorator for registering after_query hook. - # Start Server - # TODO: Do we want to recreate the transport instance or do we assume that - # transport.shutdown_server puts it back into a ready state? - # We could make this configurable? :thonking: + These functions are after the rule function is run and may modify the + response. + """ - self._info(f"Starting {self.transport}") - try: - self._prepare_middleware_stacks() - self.transport.start_server() - except Exception as e: # pylint: disable=broad-except - self._critical(e) - self.exit_code = 1 - return self.exit_code - - # Process Requests - error_count = 0 - while True: - if self.shutdown_server: - break - try: - message = self.transport.receive_message() - response = self._process_dns_record(message.message) - message.response = response - self.transport.send_message_response(message) - except InvalidMessageError as e: - self._warning(f"{e}") - except Exception as e: # pylint: disable=broad-except - self._error(f"Uncaught error occured. {e}", exc_info=True) - error_count += 1 - if error_count >= self.settings.max_errors: - self._critical(f"Max errors hit ({error_count})") - self.shutdown_server = True - self.exit_code = 1 - except KeyboardInterrupt: - self._info("KeyboardInterrupt received.") - self.shutdown_server = True - - # Stop Server - self._info("Shutting down server") - self.transport.stop_server() - - # Teardown Logging - self._logger.removeHandler(console_logger) - return self.exit_code + def decorator(func: m.AfterQueryHook): + self.register_after_query(func) + return func + + return decorator ## Internal Functions ## ------------------------------------------------------------------------- - def _process_dns_record(self, message: dnslib.DNSRecord) -> dnslib.DNSRecord: - """Process the given DNSRecord by sending it into the `RawRecordMiddleware` stack. + def process_request(self, query: Query) -> Response: + """Process a query passing it through all middleware.""" + if not self.middleware_is_prepared(): + self.prepare_middleware() + return self.middleware[0](query) + + def prepare_middleware(self) -> None: + super().prepare_middleware() + self.middleware[-1].set_next_function(self.send_query_to_rules) + return - Args: - message: the DNS query to process + def send_query_to_rules(self, query: Query) -> Response: + """Send a query to be processed by the rules of this instance. - Returns: - the DNS response + Although intended to be the final step after passing a query through all middleware, + this method can be used to bypass the middleware of this server such as for testing. """ - if self._raw_record_middleware_stack is None: - raise RuntimeError( - "RawRecordMiddleware stack does not exist. Have you called _prepare_middleware?" - ) - return self._raw_record_middleware_stack[0](message) - - def _prepare_middleware_stacks(self) -> None: - """Prepare all middleware for this server.""" - self._prepare_query_middleware_stack() - self._prepare_raw_record_middleware_stack() - return + for rule in self.rules: + rule_func = rule.get_func(query) + if rule_func is not None: + self.debug(f"Matched Rule: {rule}") + return coerce_to_response(rule_func(query)) - def _prepare_raw_record_middleware_stack(self) -> None: - """Prepare the `RawRecordMiddleware` for this server.""" - if not self._query_middleware_stack: - # Note: we can use truthy expression as once processed there will always be at - # least one item in the stack - raise RuntimeError("Must prepare QueryMiddleware stack first") - - if self._raw_record_middleware_stack: - # Note: we can use truthy expression as once processed there will always be at - # least one item in the stack - raise RuntimeError("RawRecordMiddleware stack already exists") - - middleware_stack: List[middleware.RawRecordMiddleware] = [ - self.raw_exception_handler_middleware, - *self._user_raw_record_middleware, - ] - - query_middleware_processor = middleware.QueryMiddlewareProcessor( - self._query_middleware_stack[0] - ) - - next_middleware: Optional[middleware.RawRecordMiddleware] = None - for raw_middleware in middleware_stack[::-1]: - if next_middleware is None: - raw_middleware.register_next_function(query_middleware_processor) - else: - raw_middleware.register_next_function(next_middleware) - next_middleware = raw_middleware - - self._raw_record_middleware_stack.extend(middleware_stack) - self._raw_record_middleware_stack.append(query_middleware_processor) - return + self.debug("Did not match any rule") + return Response(error_code=dnslib.RCODE.NXDOMAIN) -class Blueprint(Scaffold): - """Class that can replicate many of the functions of a `NameServer`. +class Blueprint(RulesMixin, RuleBase, LoggingMixin): + """A container for rules that can be registered onto a server - They can be used to construct or extend applications. + It can be registered as normal rule: `server.register_rule(blueprint_rule)` - New in `2.0`. + New in `3.0`. """ def __init__(self, name: str) -> None: @@ -527,15 +420,16 @@ def __init__(self, name: str) -> None: Args: name: The name of the server. This is used for internal logging. """ - super().__init__(name) - self._logger = logging.getLogger(f"nserver.b.{self.name}") + super().__init__() + self.name = name + self.logger = self.get_logger() return - def entrypoint(self, query: Query) -> Response: - """Entrypoint into this `Blueprint`. - - This method should be passed to rules as the function to run. - """ - if not self._query_middleware_stack: - self._prepare_query_middleware_stack() - return self._query_middleware_stack[0](query) + def get_func(self, query: Query) -> ResponseFunction | None: + for rule in self.rules: + func = rule.get_func(query) + if func is not None: + self.debug(f"matched {rule}") + return func + self.debug("did not match any rule") + return None diff --git a/src/nserver/settings.py b/src/nserver/settings.py deleted file mode 100644 index bd439cf..0000000 --- a/src/nserver/settings.py +++ /dev/null @@ -1,35 +0,0 @@ -### IMPORTS -### ============================================================================ -## Standard Library -from dataclasses import dataclass -import logging - -## Installed - -## Application - - -### CLASSES -### ============================================================================ -@dataclass -class Settings: - """Dataclass for NameServer settings - - Attributes: - server_transport: What `Transport` to use. See `nserver.server.TRANSPORT_MAP` for options. - server_address: What address `server_transport` will bind to. - server_port: what port `server_port` will bind to. - """ - - server_transport: str = "UDPv4" - server_address: str = "localhost" - server_port: int = 9953 - console_log_level: int = logging.INFO - file_log_level: int = logging.INFO - max_errors: int = 5 - - # Not implemented, ideas for useful things - # debug: bool = False # Put server into "debug mode" (e.g. hot reload) - # health_check: bool = False # provde route for health check - # stats: bool = False # provide route for retrieving operational stats - # remote_admin: bool = False # allow remote shutdown restart etc? diff --git a/src/nserver/transport.py b/src/nserver/transport.py index f431a36..010a6a5 100644 --- a/src/nserver/transport.py +++ b/src/nserver/transport.py @@ -1,5 +1,8 @@ ### IMPORTS ### ============================================================================ +## Future +from __future__ import annotations + ## Standard Library from collections import deque from dataclasses import dataclass @@ -8,16 +11,15 @@ import socket import struct import time +from typing import Deque, NewType, Any, cast -# Note: Union can only be replaced with `X | Y` in 3.10+ -from typing import Tuple, Optional, Dict, List, Deque, NewType, Any, Union, cast ## Installed import dnslib +from pillar.logging import LoggingMixin ## Application from .exceptions import InvalidMessageError -from .settings import Settings ### CONSTANTS @@ -48,7 +50,7 @@ class TcpState(enum.IntEnum): ### FUNCTIONS ### ============================================================================ -def get_tcp_info(connection: socket.socket) -> Tuple: +def get_tcp_info(connection: socket.socket) -> tuple: """Get `socket.TCP_INFO` from socket Args: @@ -111,9 +113,9 @@ class MessageContainer: # pylint: disable=too-few-public-methods def __init__( self, raw_data: bytes, - transport: "TransportBase", + transport: TransportBase, transport_data: Any, - remote_client: Union[str, Tuple[str, int]], + remote_client: str | tuple[str, int], ): """Create new message container @@ -148,7 +150,7 @@ def __init__( self.transport = transport self.transport_data = transport_data self.remote_client = remote_client - self.response: Optional[dnslib.DNSRecord] = None + self.response: dnslib.DNSRecord | None = None return def get_response_bytes(self): @@ -160,16 +162,11 @@ def get_response_bytes(self): ## Transport Classes ## ----------------------------------------------------------------------------- -class TransportBase: +class TransportBase(LoggingMixin): """Base class for all transports""" - def __init__(self, settings: Settings) -> None: - """ - Args: - settings: settings of the server this transport is attached to - """ - self.settings = settings - # TODO: setup logging + def __init__(self) -> None: + self.logger = self.get_logger() return def start_server(self, timeout: int = 60) -> None: @@ -199,7 +196,7 @@ class UDPMessageData: remote_address: UDP peername that this message was received from """ - remote_address: Tuple[str, int] + remote_address: tuple[str, int] class UDPv4Transport(TransportBase): @@ -207,10 +204,10 @@ class UDPv4Transport(TransportBase): _SOCKET_AF = socket.AF_INET - def __init__(self, settings: Settings): - super().__init__(settings) - self.address = self.settings.server_address - self.port = self.settings.server_port + def __init__(self, address: str, port: int): + super().__init__() + self.address = address + self.port = port self.socket = socket.socket(self._SOCKET_AF, socket.SOCK_DGRAM) return @@ -284,7 +281,7 @@ class CachedConnection: """ connection: socket.socket - remote_address: Tuple[str, int] + remote_address: tuple[str, int] last_data_time: float selector_key: selectors.SelectorKey cache_key: CacheKey @@ -306,17 +303,17 @@ class TCPv4Transport(TransportBase): CONNECTION_CACHE_TARGET = int(CONNECTION_CACHE_LIMIT * CONNECTION_CACHE_VACUUM_PERCENT) CONNECTION_CACHE_CLEAN_INTERVAL = 10 # seconds - def __init__(self, settings: Settings) -> None: - super().__init__(settings) - self.address = self.settings.server_address - self.port = self.settings.server_port + def __init__(self, address: str, port: int) -> None: + super().__init__() + self.address = address + self.port = port self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket.setblocking(False) # Allow taking over of socket when in TIME_WAIT (i.e. previously released) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.selector = selectors.DefaultSelector() - self.cached_connections: Dict[CacheKey, CachedConnection] = {} + self.cached_connections: dict[CacheKey, CachedConnection] = {} self.last_cache_clean = 0.0 self.connection_queue: Deque[socket.socket] = deque() @@ -380,8 +377,41 @@ def stop_server(self) -> None: def __repr__(self): return f"{self.__class__.__name__}(address={self.address!r}, port={self.port!r})" - def _get_next_connection(self) -> Tuple[socket.socket, Tuple[str, int]]: - """Get the next connection that is ready to receive data on.""" + def _get_next_connection(self) -> tuple[socket.socket, tuple[str, int]]: + """Get the next connection that is ready to receive data on. + + Blocks until a good connection is found + """ + while True: + if not self.connection_queue: + self._populate_connection_queue() + + # There is something in the queue - attempt to get it + connection = self.connection_queue.popleft() + + if not self._connection_viable(connection): + self._remove_connection(connection) + continue + + # Connection is probably viable + try: + remote_address = connection.getpeername() + except OSError as e: + if e.errno == 107: # Transport endpoint is not connected + self._remove_connection(connection) + continue + + raise # Unknown OSError - raise it. + + break # we have a valid connection + + return connection, remote_address + + def _populate_connection_queue(self) -> None: + """Populate self.connection_queue + + Blocks until there is at least on connection + """ while not self.connection_queue: # loop until connection is ready for execution events = self.selector.select(self.SELECT_TIMEOUT) @@ -413,13 +443,7 @@ def _get_next_connection(self) -> Tuple[socket.socket, Tuple[str, int]]: # No connections ready, take advantage to do cleanup elif time.time() - self.last_cache_clean > self.CONNECTION_CACHE_CLEAN_INTERVAL: self._cleanup_cached_connections() - - # We have a connection in the queue - # print(f"connection_queue: {self.connection_queue}") - connection = self.connection_queue.popleft() - remote_address = connection.getpeername() - - return connection, remote_address + return def _accept_connection(self) -> socket.socket: """Accept a connection, cache it, and add it to the selector""" @@ -471,7 +495,7 @@ def _connection_viable(connection: socket.socket) -> bool: def _cleanup_cached_connections(self) -> None: "Cleanup cached connections" now = time.time() - cache_clear: List[CacheKey] = [] + cache_clear: list[CacheKey] = [] for cache_key, cache in self.cached_connections.items(): if now - cache.last_data_time > self.CONNECTION_KEEPALIVE_LIMIT: if cache.connection not in self.connection_queue: @@ -485,7 +509,7 @@ def _cleanup_cached_connections(self) -> None: for cache_key in cache_clear: self._remove_connection(cache_key=cache_key) - quiet_connections: List[CachedConnection] = [] + quiet_connections: list[CachedConnection] = [] cached_connections_len = len(self.cached_connections) cache_clear = [] @@ -516,7 +540,7 @@ def _cleanup_cached_connections(self) -> None: return def _remove_connection( - self, connection: Optional[socket.socket] = None, cache_key: Optional[CacheKey] = None + self, connection: socket.socket | None = None, cache_key: CacheKey | None = None ) -> None: """Remove a connection from the server (closing it in the process) diff --git a/src/nserver/util.py b/src/nserver/util.py index 261add3..191fad7 100644 --- a/src/nserver/util.py +++ b/src/nserver/util.py @@ -1,5 +1,8 @@ ### IMPORTS ### ============================================================================ +## Future +from __future__ import annotations + ## Standard Library ## Installed diff --git a/tests/test_blueprint.py b/tests/test_blueprint.py index e801dc8..30a6fa6 100644 --- a/tests/test_blueprint.py +++ b/tests/test_blueprint.py @@ -3,15 +3,11 @@ ### IMPORTS ### ============================================================================ ## Standard Library -from typing import no_type_check, List -import unittest.mock - ## Installed import dnslib import pytest -from nserver import NameServer, Blueprint, Query, Response, ALL_QTYPES, ZoneRule, A -from nserver.server import Scaffold +from nserver import NameServer, RawNameServer, Blueprint, Query, A ## Application @@ -22,6 +18,7 @@ blueprint_1 = Blueprint("blueprint_1") blueprint_2 = Blueprint("blueprint_2") blueprint_3 = Blueprint("blueprint_3") +raw_server = RawNameServer(server) ## Rules @@ -34,100 +31,11 @@ def dummy_rule(query: Query) -> A: return A(query.name, IP) -## Hooks -## ----------------------------------------------------------------------------- -def register_hooks(scaff: Scaffold) -> None: - scaff.register_before_first_query(unittest.mock.MagicMock(wraps=lambda: None)) - scaff.register_before_query(unittest.mock.MagicMock(wraps=lambda q: None)) - scaff.register_after_query(unittest.mock.MagicMock(wraps=lambda r: r)) - return - - -@no_type_check -def reset_hooks(scaff: Scaffold) -> None: - scaff.hook_middleware.before_first_query_run = False - scaff.hook_middleware.before_first_query[0].reset_mock() - scaff.hook_middleware.before_query[0].reset_mock() - scaff.hook_middleware.after_query[0].reset_mock() - return - - -def reset_all_hooks() -> None: - reset_hooks(server) - reset_hooks(blueprint_1) - reset_hooks(blueprint_2) - reset_hooks(blueprint_3) - return - - -@no_type_check -def check_hook_call_count(scaff: Scaffold, bfq_count: int, bq_count: int, aq_count: int) -> None: - assert scaff.hook_middleware.before_first_query[0].call_count == bfq_count - assert scaff.hook_middleware.before_query[0].call_count == bq_count - assert scaff.hook_middleware.after_query[0].call_count == aq_count - return - - -register_hooks(server) -register_hooks(blueprint_1) -register_hooks(blueprint_2) -register_hooks(blueprint_3) - - -## Exception handling -## ----------------------------------------------------------------------------- -class ErrorForTesting(Exception): - pass - - -@server.rule("throw-error.com", ["A"]) -def throw_error(query: Query) -> None: - raise ErrorForTesting() - - -def _query_error_handler(query: Query, exception: Exception) -> Response: - # pylint: disable=unused-argument - return Response(error_code=dnslib.RCODE.SERVFAIL) - - -query_error_handler = unittest.mock.MagicMock(wraps=_query_error_handler) -server.register_exception_handler(ErrorForTesting, query_error_handler) - - -class ThrowAnotherError(Exception): - pass - - -@server.rule("throw-another-error.com", ["A"]) -def throw_another_error(query: Query) -> None: - raise ThrowAnotherError() - - -def bad_error_handler(query: Query, exception: Exception) -> Response: - # pylint: disable=unused-argument - raise ErrorForTesting() - - -server.register_exception_handler(ThrowAnotherError, bad_error_handler) - - -def _raw_record_error_handler(record: dnslib.DNSRecord, exception: Exception) -> dnslib.DNSRecord: - # pylint: disable=unused-argument - response = record.reply() - response.header.rcode = dnslib.RCODE.SERVFAIL - return response - - -raw_record_error_handler = unittest.mock.MagicMock(wraps=_raw_record_error_handler) -server.register_raw_exception_handler(ErrorForTesting, raw_record_error_handler) - ## Get server ready ## ----------------------------------------------------------------------------- -server.register_blueprint(blueprint_1, ZoneRule, "b1.com", ALL_QTYPES) -server.register_blueprint(blueprint_2, ZoneRule, "b2.com", ALL_QTYPES) -blueprint_2.register_blueprint(blueprint_3, ZoneRule, "b3.b2.com", ALL_QTYPES) - -server._prepare_middleware_stacks() +server.register_rule(blueprint_1) +server.register_rule(blueprint_2) +blueprint_2.register_rule(blueprint_3) ### TESTS @@ -136,7 +44,7 @@ def _raw_record_error_handler(record: dnslib.DNSRecord, exception: Exception) -> ## ----------------------------------------------------------------------------- @pytest.mark.parametrize("question", ["s.com", "b1.com", "b2.com", "b3.b2.com"]) def test_response(question: str): - response = server._process_dns_record(dnslib.DNSRecord.question(question)) + response = raw_server.process_request(dnslib.DNSRecord.question(question)) assert len(response.rr) == 1 assert response.rr[0].rtype == 1 assert response.rr[0].rname == question @@ -145,40 +53,7 @@ def test_response(question: str): @pytest.mark.parametrize("question", ["miss.s.com", "miss.b1.com", "miss.b2.com", "miss.b3.b2.com"]) def test_nxdomain(question: str): - response = server._process_dns_record(dnslib.DNSRecord.question(question)) + response = raw_server.process_request(dnslib.DNSRecord.question(question)) assert len(response.rr) == 0 assert response.header.rcode == dnslib.RCODE.NXDOMAIN return - - -## Hooks -## ----------------------------------------------------------------------------- -@pytest.mark.parametrize( - "question,hook_counts", - [ - ("s.com", [1, 5, 5]), - ("b1.com", [1, 5, 5, 1, 5, 5]), - ("b2.com", [1, 5, 5, 0, 0, 0, 1, 5, 5]), - ("b3.b2.com", [1, 5, 5, 0, 0, 0, 1, 5, 5, 1, 5, 5]), - ], -) -def test_hooks(question: str, hook_counts: List[int]): - ## Setup - # fill unset hook_counts - hook_counts += [0] * (12 - len(hook_counts)) - assert len(hook_counts) == 12 - # reset hooks - reset_all_hooks() - - ## Test - for _ in range(5): - response = server._process_dns_record(dnslib.DNSRecord.question(question)) - assert len(response.rr) == 1 - assert response.rr[0].rtype == 1 - assert response.rr[0].rname == question - - check_hook_call_count(server, *hook_counts[:3]) - check_hook_call_count(blueprint_1, *hook_counts[3:6]) - check_hook_call_count(blueprint_2, *hook_counts[6:9]) - check_hook_call_count(blueprint_3, *hook_counts[9:]) - return diff --git a/tests/test_server.py b/tests/test_server.py index f8561f4..9a12db0 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -10,7 +10,7 @@ import dnslib import pytest -from nserver import NameServer, Query, Response, A +from nserver import NameServer, RawNameServer, Query, Response, A ## Application @@ -18,6 +18,7 @@ ### ============================================================================ IP = "127.0.0.1" server = NameServer("tests") +raw_server = RawNameServer(server) ## Rules @@ -106,11 +107,10 @@ def _raw_record_error_handler(record: dnslib.DNSRecord, exception: Exception) -> raw_record_error_handler = unittest.mock.MagicMock(wraps=_raw_record_error_handler) -server.register_raw_exception_handler(ErrorForTesting, raw_record_error_handler) +raw_server.register_exception_handler(ErrorForTesting, raw_record_error_handler) ## Get server ready ## ----------------------------------------------------------------------------- -server._prepare_middleware_stacks() ### TESTS @@ -118,13 +118,13 @@ def _raw_record_error_handler(record: dnslib.DNSRecord, exception: Exception) -> ## NameServer._process_dns_record ## ----------------------------------------------------------------------------- def test_none_response(): - response = server._process_dns_record(dnslib.DNSRecord.question("none-response.com")) + response = raw_server.process_request(dnslib.DNSRecord.question("none-response.com")) assert len(response.rr) == 0 return def test_response_response(): - response = server._process_dns_record(dnslib.DNSRecord.question("response-response.com")) + response = raw_server.process_request(dnslib.DNSRecord.question("response-response.com")) assert len(response.rr) == 1 assert response.rr[0].rtype == 1 assert response.rr[0].rname == "response-response.com." @@ -132,7 +132,7 @@ def test_response_response(): def test_record_response(): - response = server._process_dns_record(dnslib.DNSRecord.question("record-response.com")) + response = raw_server.process_request(dnslib.DNSRecord.question("record-response.com")) assert len(response.rr) == 1 assert response.rr[0].rtype == 1 assert response.rr[0].rname == "record-response.com." @@ -140,7 +140,7 @@ def test_record_response(): def test_multi_record_response(): - response = server._process_dns_record(dnslib.DNSRecord.question("multi-record-response.com")) + response = raw_server.process_request(dnslib.DNSRecord.question("multi-record-response.com")) assert len(response.rr) == 2 for record in response.rr: assert record.rtype == 1 @@ -160,12 +160,12 @@ def test_multi_record_response(): ) def test_hook_call_count(hook, call_count): # Setup - server.hook_middleware.before_first_query_run = False + server.hooks.before_first_query_run = False hook.reset_mock() # Test for _ in range(5): - response = server._process_dns_record(dnslib.DNSRecord.question("dummy.com")) + response = raw_server.process_request(dnslib.DNSRecord.question("dummy.com")) # Ensure respone returns and unchanged assert len(response.rr) == 1 assert response.rr[0].rtype == 1 @@ -183,7 +183,7 @@ def test_query_error_handler(): raw_record_error_handler.reset_mock() # Test - response = server._process_dns_record(dnslib.DNSRecord.question("throw-error.com")) + response = raw_server.process_request(dnslib.DNSRecord.question("throw-error.com")) assert len(response.rr) == 0 assert response.header.get_rcode() == dnslib.RCODE.SERVFAIL @@ -199,7 +199,7 @@ def test_raw_record_error_handler(): raw_record_error_handler.reset_mock() # Test - response = server._process_dns_record(dnslib.DNSRecord.question("throw-another-error.com")) + response = raw_server.process_request(dnslib.DNSRecord.question("throw-another-error.com")) assert len(response.rr) == 0 assert response.header.get_rcode() == dnslib.RCODE.SERVFAIL diff --git a/tests/test_subserver.py b/tests/test_subserver.py new file mode 100644 index 0000000..ee5c9b9 --- /dev/null +++ b/tests/test_subserver.py @@ -0,0 +1,174 @@ +# pylint: disable=missing-class-docstring,missing-function-docstring,protected-access + +### IMPORTS +### ============================================================================ +## Standard Library +from typing import no_type_check, List +import unittest.mock + +## Installed +import dnslib +import pytest + +from nserver import NameServer, RawNameServer, Query, Response, ALL_QTYPES, ZoneRule, A + +## Application + +### SETUP +### ============================================================================ +IP = "127.0.0.1" +nameserver = NameServer("test_subserver") +subserver_1 = NameServer("subserver_1") +subserver_2 = NameServer("subserver_2") +subserver_3 = NameServer("subserver_3") +raw_nameserver = RawNameServer(nameserver) + + +## Rules +## ----------------------------------------------------------------------------- +@nameserver.rule("s.com", ["A"]) +@subserver_1.rule("sub1.com", ["A"]) +@subserver_2.rule("sub2.com", ["A"]) +@subserver_3.rule("sub3.sub2.com", ["A"]) +def dummy_rule(query: Query) -> A: + return A(query.name, IP) + + +## Hooks +## ----------------------------------------------------------------------------- +def register_hooks(server: NameServer) -> None: + server.register_before_first_query(unittest.mock.MagicMock(wraps=lambda: None)) + server.register_before_query(unittest.mock.MagicMock(wraps=lambda q: None)) + server.register_after_query(unittest.mock.MagicMock(wraps=lambda r: r)) + return + + +@no_type_check +def reset_hooks(server: NameServer) -> None: + server.hooks.before_first_query_run = False + server.hooks.before_first_query[0].reset_mock() + server.hooks.before_query[0].reset_mock() + server.hooks.after_query[0].reset_mock() + return + + +def reset_all_hooks() -> None: + reset_hooks(nameserver) + reset_hooks(subserver_1) + reset_hooks(subserver_2) + reset_hooks(subserver_3) + return + + +@no_type_check +def check_hook_call_count(server: NameServer, bfq_count: int, bq_count: int, aq_count: int) -> None: + assert server.hooks.before_first_query[0].call_count == bfq_count + assert server.hooks.before_query[0].call_count == bq_count + assert server.hooks.after_query[0].call_count == aq_count + return + + +register_hooks(nameserver) +register_hooks(subserver_1) +register_hooks(subserver_2) +register_hooks(subserver_3) + + +## Exception handling +## ----------------------------------------------------------------------------- +class ErrorForTesting(Exception): + pass + + +@nameserver.rule("throw-error.com", ["A"]) +def throw_error(query: Query) -> None: + raise ErrorForTesting() + + +def _query_error_handler(query: Query, exception: Exception) -> Response: + # pylint: disable=unused-argument + return Response(error_code=dnslib.RCODE.SERVFAIL) + + +query_error_handler = unittest.mock.MagicMock(wraps=_query_error_handler) +nameserver.register_exception_handler(ErrorForTesting, query_error_handler) + + +class ThrowAnotherError(Exception): + pass + + +@nameserver.rule("throw-another-error.com", ["A"]) +def throw_another_error(query: Query) -> None: + raise ThrowAnotherError() + + +def bad_error_handler(query: Query, exception: Exception) -> Response: + # pylint: disable=unused-argument + raise ErrorForTesting() + + +nameserver.register_exception_handler(ThrowAnotherError, bad_error_handler) + + +## Get server ready +## ----------------------------------------------------------------------------- +nameserver.register_subserver(subserver_1, ZoneRule, "sub1.com", ALL_QTYPES) +nameserver.register_subserver(subserver_2, ZoneRule, "sub2.com", ALL_QTYPES) +subserver_2.register_subserver(subserver_3, ZoneRule, "sub3.sub2.com", ALL_QTYPES) + + +### TESTS +### ============================================================================ +## Responses +## ----------------------------------------------------------------------------- +@pytest.mark.parametrize("question", ["s.com", "sub1.com", "sub2.com", "sub3.sub2.com"]) +def test_response(question: str): + response = raw_nameserver.process_request(dnslib.DNSRecord.question(question)) + assert len(response.rr) == 1 + assert response.rr[0].rtype == 1 + assert response.rr[0].rname == question + return + + +@pytest.mark.parametrize( + "question", ["miss.s.com", "miss.sub1.com", "miss.sub2.com", "miss.sub3.sub2.com"] +) +def test_nxdomain(question: str): + response = raw_nameserver.process_request(dnslib.DNSRecord.question(question)) + assert len(response.rr) == 0 + assert response.header.rcode == dnslib.RCODE.NXDOMAIN + return + + +## Hooks +## ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + "question,hook_counts", + [ + ("s.com", [1, 5, 5]), + ("sub1.com", [1, 5, 5, 1, 5, 5]), + ("sub2.com", [1, 5, 5, 0, 0, 0, 1, 5, 5]), + ("sub3.sub2.com", [1, 5, 5, 0, 0, 0, 1, 5, 5, 1, 5, 5]), + ], +) +def test_hooks(question: str, hook_counts: List[int]): + ## Setup + # fill unset hook_counts + hook_counts += [0] * (12 - len(hook_counts)) + assert len(hook_counts) == 12 + # reset hooks + reset_all_hooks() + + ## Test + for _ in range(5): + response = raw_nameserver.process_request(dnslib.DNSRecord.question(question)) + assert len(response.rr) == 1 + assert response.rr[0].rtype == 1 + assert response.rr[0].rname == question + + check_hook_call_count(nameserver, *hook_counts[:3]) + check_hook_call_count(subserver_1, *hook_counts[3:6]) + check_hook_call_count(subserver_2, *hook_counts[6:9]) + check_hook_call_count(subserver_3, *hook_counts[9:]) + return diff --git a/tox.ini b/tox.ini index 5e7f556..0aaa4fb 100644 --- a/tox.ini +++ b/tox.ini @@ -1,10 +1,24 @@ [tox] -envlist = py37,py38,py39,py310,py311,py312,pypy37,pypy38,pypy39 +requires = tox>=3,tox-uv +envlist = pypy{38,39,310}, py{38,39,310,311,312,313} [testenv] -package = external -deps = pytest -commands = {posargs:pytest -ra tests} +description = run unit tests +extras = dev +commands = + pytest tests -[testenv:.pkg_external] -package_glob = /code/dist/* +[testenv:format] +description = run formatters +extras = dev +commands = + black src tests + +[testenv:lint] +description = run linters +extras = dev +commands = + validate-pyproject pyproject.toml + black --check --diff src tests + pylint src + mypy src tests