From 75edfb4dc95fdc81e25372281da31382461f9285 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Wed, 20 Dec 2023 17:57:54 +1100 Subject: [PATCH] [server] Add blueprints, also make this v2.0.0 --- docs/blueprints.md | 54 +++++ docs/changelog.md | 4 +- docs/error-handling.md | 2 +- docs/middleware.md | 2 +- mkdocs.yml | 1 + pyproject.toml | 2 +- src/nserver/__init__.py | 4 +- src/nserver/middleware.py | 16 +- src/nserver/rules.py | 10 +- src/nserver/server.py | 485 +++++++++++++++++++++----------------- tests/test_blueprint.py | 184 +++++++++++++++ tests/test_rules.py | 96 ++++---- 12 files changed, 580 insertions(+), 280 deletions(-) create mode 100644 docs/blueprints.md create mode 100644 tests/test_blueprint.py diff --git a/docs/blueprints.md b/docs/blueprints.md new file mode 100644 index 0000000..90462ec --- /dev/null +++ b/docs/blueprints.md @@ -0,0 +1,54 @@ +# Blueprints + +[`Blueprint`][nserver.server.Blueprint]s provide a way for you to compose your application. They support most of the same functionality as a `NameServer`. + +Use cases: + +- Split up your application across different blueprints for maintainability / composability. +- Reuse a blueprint registered under different rules. +- Allow custom packages to define their own rules that you can add to your own server. + +Blueprints require `nserver>=2.0` + +## Using Blueprints + +```python +from nserver import Blueprint, NameServer, ZoneRule, ALL_CTYPES, A + +# First Blueprint +mysite = Blueprint("mysite") + +@mysite.rule("nicholashairs.com", ["A"]) +@mysite.rule("www.nicholashairs.com", ["A"]) +def nicholashairs_website(query: Query) -> A: + return A(query.name, "159.65.13.73") + +@mysite.rule(ZoneRule, "", ALL_CTYPES) +def nicholashairs_catchall(query: Query) -> None: + # Return empty response for all other queries + return None + +# Second Blueprint +en_blueprint = Blueprint("english-speaking-blueprint") + +@en_blueprint.rule("hello.{base_domain}", ["A"]) +def en_hello(query: Query) -> A: + return A(query.name, "1.1.1.1") + +# Register to NameServer +server = NameServer("server") +server.register_blueprint(mysite, ZoneRule, "nicholashairs.com", ALL_CTYPES) +server.register_blueprint(en_blueprint, ZoneRule, "au", ALL_CTYPES) +server.register_blueprint(en_blueprint, ZoneRule, "nz", ALL_CTYPES) +server.register_blueprint(en_blueprint, ZoneRule, "uk", ALL_CTYPES) +``` + +### Middleware, Hooks, and Error Handling + +Blueprints maintain their own `QueryMiddleware` stack which will run before any rule function is run. Included in this stack is the `HookMiddleware` and `ExceptionHandlerMiddleware`. + +## Key differences with `NameServer` + +- Does not use settings (`Setting`). +- Does not have a `Transport`. +- Does not have a `RawRecordMiddleware` stack. diff --git a/docs/changelog.md b/docs/changelog.md index fd35d2d..9f9f034 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,12 +1,14 @@ # Change Log -## 1.1.0 +## 2.0.0 - Implement [Middleware][middleware] - This includes adding error handling middleware that facilitates [error handling][error-handling]. - Add [`StaticRule`][nserver.rules.StaticRule] and [`ZoneRule`][nserver.rules.ZoneRule]. - Refector [`NameServer.rule`][nserver.server.NameServer.rule] to use expanded [`smart_make_rule`][nserver.rules.smart_make_rule] function. - Although this change should not affect rules using this decorator from being called correctly, it may change the precise rule type being used. Specifically it may use `StaticRule` instead of `WildcardStringRule` for strings with no substitutions. +- Add [Blueprints][blueprints] + - Include refactoring `NameServer` into a new shared based `Scaffold` class. ## 1.0.0 diff --git a/docs/error-handling.md b/docs/error-handling.md index fdf60dc..a27c2e5 100644 --- a/docs/error-handling.md +++ b/docs/error-handling.md @@ -2,7 +2,7 @@ Custom exception handling is handled through the [`ExceptionHandlerMiddleware`][nserver.middleware.ExceptionHandlerMiddleware] and [`RawRecordExceptionHandlerMiddleware`][nserver.middleware.RawRecordExceptionHandlerMiddleware] [Middleware][middleware]. These middleware will catch any `Exception`s raised by their respective middleware stacks. -Error handling requires `nserver>=1.1.0` +Error handling requires `nserver>=2.0` In general you are probably able to use the `ExceptionHandlerMiddleware` as the `RawRecordExceptionHandlerMiddleware` is only needed to catch exceptions resulting from `RawRecordMiddleware` or broken exception handlers in the `ExceptionHandlerMiddleware`. If you only write `QueryMiddleware` and your `ExceptionHandlerMiddleware` handlers never raise exceptions then you'll be good to go with just the `ExceptionHandlerMiddleware`. diff --git a/docs/middleware.md b/docs/middleware.md index 1e74d79..c54fc85 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -2,7 +2,7 @@ Middleware can be used to modify the behaviour of a server seperate to the individual rules that are registered to the server. Middleware is run on all requests and can modify both the input and response of a request. -Middleware requires `nserver>=1.1.0` +Middleware requires `nserver>=2.0` ## Middleware Stacks diff --git a/mkdocs.yml b/mkdocs.yml index beb4f1a..fa88c1d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -13,6 +13,7 @@ nav: - quickstart.md - middleware.md - error-handling.md + - blueprints.md - production-deployment.md - changelog.md - external-resources.md diff --git a/pyproject.toml b/pyproject.toml index dccb166..7c9a83b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "nserver" -version = "1.1.0.dev1" +version = "2.0.0" description = "DNS Name Server Framework" authors = [ {name = "Nicholas Hairs", email = "info+nserver@nicholashairs.com"}, diff --git a/src/nserver/__init__.py b/src/nserver/__init__.py index 7844bde..54e8c75 100644 --- a/src/nserver/__init__.py +++ b/src/nserver/__init__.py @@ -1,5 +1,5 @@ from .models import Query, Response -from .rules import StaticRule, ZoneRule, RegexRule, WildcardStringRule +from .rules import ALL_QTYPES, StaticRule, ZoneRule, RegexRule, WildcardStringRule from .records import A, AAAA, NS, CNAME, PTR, SOA, MX, TXT, CAA -from .server import NameServer +from .server import NameServer, Blueprint from .settings import Settings diff --git a/src/nserver/middleware.py b/src/nserver/middleware.py index 722f9cf..2662bd7 100644 --- a/src/nserver/middleware.py +++ b/src/nserver/middleware.py @@ -46,7 +46,7 @@ def coerce_to_response(result: RuleResult) -> Response: """Convert some `RuleResult` to a `Response` - New in `1.1.0`. + New in `2.0`. Args: result: the results to convert @@ -76,7 +76,7 @@ def coerce_to_response(result: RuleResult) -> Response: class QueryMiddleware: """Middleware for interacting with `Query` objects - New in `1.1.0`. + New in `2.0`. """ def __init__(self) -> None: @@ -118,7 +118,7 @@ class ExceptionHandlerMiddleware(QueryMiddleware): matches the class or parent class of the exception in method resolution order. If no handler is registered will use this classes `self.default_exception_handler`. - New in `1.1.0`. + New in `2.0`. Attributes: exception_handlers: registered exception handlers @@ -182,7 +182,7 @@ class HookMiddleware(QueryMiddleware): hook or from the next function in the middleware chain. They take a `Response` input and must return a `Response`. - New in `1.1.0`. + New in `2.0`. Attributes: before_first_query: `before_first_query` hooks @@ -257,7 +257,7 @@ class RuleProcessor: This class serves as the bottom of the `QueryMiddleware` stack. - New in `1.1.0`. + New in `2.0`. """ def __init__(self, rules: List[RuleBase]) -> None: @@ -284,7 +284,7 @@ def __call__(self, query: Query) -> Response: class RawRecordMiddleware: """Middleware to be run against raw `dnslib.DNSRecord`s. - New in `1.1.0`. + New in `2.0`. """ def __init__(self) -> None: @@ -332,7 +332,7 @@ class RawRecordExceptionHandlerMiddleware(RawRecordMiddleware): Exception handlers are expected to be robust - that is, they must always return correctly even if they internally encounter an `Exception`. - New in `1.1.0`. + New in `2.0`. Attributes: exception_handlers: registered exception handlers @@ -389,7 +389,7 @@ class QueryMiddlewareProcessor: This class serves as the bottom of the `RawRcordMiddleware` stack. - New in `1.1.0`. + New in `2.0`. """ def __init__(self, query_middleware: QueryMiddlewareCallable) -> None: diff --git a/src/nserver/rules.py b/src/nserver/rules.py index 4b521eb..3abf5bc 100644 --- a/src/nserver/rules.py +++ b/src/nserver/rules.py @@ -19,7 +19,7 @@ ALL_QTYPES: List[str] = list(dnslib.QTYPE.reverse.keys()) """All supported Query Types -New in `1.1.0`. +New in `2.0`. """ _wildcard_string_regex = re.compile(r"[*]|\{base_domain\}") @@ -39,7 +39,7 @@ def smart_make_rule(rule: "Union[Type[RuleBase], str, Pattern]", *args, **kwargs it will be a `WildcardStringRule`, else a `StaticRule`. - `Pattern` then a `RegexRule`. - New in `1.1.0` + New in `2.0` Args: rule: input to process @@ -57,7 +57,7 @@ def smart_make_rule(rule: "Union[Type[RuleBase], str, Pattern]", *args, **kwargs # vary between versions of python and other bugs. # see also: https://stackoverflow.com/questions/6102019/type-of-compiled-regex-object-in-python return RegexRule(rule, *args, **kwargs) - raise ValueError(f"Could not handle rule: {rule!r}") + return rule(*args, **kwargs) ### CLASSES @@ -93,7 +93,7 @@ class StaticRule(RuleBase): `StaticRule` is more efficient than using a `WildcardStringRule` for static strings. - New in `1.1.0`. + New in `2.0`. """ def __init__( @@ -141,7 +141,7 @@ class ZoneRule(RuleBase): An empty zone (`""`) will match any domain as this refers to the domain root (`.`). - New in `1.1.0`. + New in `2.0`. """ def __init__( diff --git a/src/nserver/server.py b/src/nserver/server.py index ed9bc0c..8c69b61 100644 --- a/src/nserver/server.py +++ b/src/nserver/server.py @@ -4,13 +4,14 @@ import logging # Note: Optional can only be replaced with `| None` in 3.10+ -from typing import List, Dict, Pattern, Optional, Union, Type +from typing import List, Dict, Optional, Union, Type, Pattern ## Installed import dnslib ## Application from .exceptions import InvalidMessageError +from .models import Query, Response from .rules import smart_make_rule, RuleBase, ResponseFunction from .settings import Settings from .transport import TransportBase, UDPv4Transport, UDPv6Transport, TCPv4Transport @@ -28,49 +29,38 @@ ### Classes ### ============================================================================ -class NameServer: - """NameServer for responding to requests.""" +class Scaffold: + """Base class for shared functionality between `NameServer` and `Blueprint` - # pylint: disable=too-many-instance-attributes + New in `2.0`. - def __init__(self, name: str, settings: Optional[Settings] = None) -> None: - """Initialise NameServer + Attributes: + rules: registered rules + hook_middleware: hook middleware + exception_handler_middleware: Query exception handler middleware + """ + + _logger: logging.Logger + def __init__(self, name: str) -> None: + """ Args: name: The name of the server. This is used for internal logging. - settings: settings to use with this `NameServer` instance """ self.name = name - self.rules: List[RuleBase] = [] - self._logger = logging.getLogger(f"nserver.i.{self.name}") + self.rules: List[RuleBase] = [] self.hook_middleware = middleware.HookMiddleware() self.exception_handler_middleware = middleware.ExceptionHandlerMiddleware() - self.raw_exception_handler_middleware = middleware.RawRecordExceptionHandlerMiddleware() self._user_query_middleware: List[middleware.QueryMiddleware] = [] self._query_middleware_stack: List[ Union[middleware.QueryMiddleware, middleware.QueryMiddlewareCallable] ] = [] - - self._user_raw_record_middleware: List[middleware.RawRecordMiddleware] = [] - self._raw_record_middleware_stack: List[ - Union[middleware.RawRecordMiddleware, middleware.RawRecordMiddlewareCallable] - ] = [] - - self.settings = settings if settings is not None else Settings() - - transport = TRANSPORT_MAP.get(self.settings.server_transport) - if transport is None: - raise ValueError( - f"Invalid settings.server_transport {self.settings.server_transport!r}" - ) - self.transport = transport(self.settings) - - self.shutdown_server = False - self.exit_code = 0 return + ## Register Methods + ## ------------------------------------------------------------------------- def register_rule(self, rule: RuleBase) -> None: """Register the given rule @@ -81,24 +71,27 @@ def register_rule(self, rule: RuleBase) -> None: self.rules.append(rule) return - # def register_blueprint(self, blueprint, rule: Union[str, Pattern, RuleBase]) -> None: - # """Register a blueprint using the given rule. - # - # - # If the rule triggers, the query is passed to the Blueprint to determine - # if a rule matches. Just because a rule matches the blueprint does not - # mean that the rule will match any rule in the blueprint. - # - # Args: - # blueprint: the `Blueprint` to attach - # rule: The rule to use to match to this `blueprint` - # - If rule is a `str` is interpreted as a the input for a WildcardStringRule. - # - If rule is a `Pattern` is interpreted as the input for a RegexRule. - # - If rule is a instance of `RuleBase` is used as is. - # - # Note: that all rules are internally converted to a `BlueprintRule`. - # """ - # raise NotImplementedError() + def register_blueprint( + self, blueprint: "Blueprint", rule_: Union[Type[RuleBase], str, Pattern], *args, **kwargs + ) -> None: + """Register a blueprint using [`smart_make_rule`][nserver.rules.smart_make_rule]. + + New in `2.0`. + + Args: + blueprint: the `Blueprint` to attach + rule_: rule as per `nserver.rules.smart_make_rule` + args: extra arguments to provide `smart_make_rule` + kwargs: extra keyword arguments to provide `smart_make_rule` + + Raises: + ValueError: if `func` is provided in `kwargs`. + """ + + if "func" in kwargs: + raise ValueError("Must not provide `func` in kwargs") + self.register_rule(smart_make_rule(rule_, *args, func=blueprint.entrypoint, **kwargs)) + return def register_before_first_query(self, func: middleware.BeforeFirstQueryHook) -> None: """Register a function to be run before the first query. @@ -132,7 +125,7 @@ def register_after_query(self, func: middleware.AfterQueryHook) -> None: def register_middleware(self, query_middleware: middleware.QueryMiddleware) -> None: """Add a `QueryMiddleware` to this server. - New in `1.1.0`. + New in `2.0`. Args: query_middleware: the middleware to add @@ -144,21 +137,6 @@ def register_middleware(self, query_middleware: middleware.QueryMiddleware) -> N self._user_query_middleware.append(query_middleware) return - def register_raw_middleware(self, raw_middleware: middleware.RawRecordMiddleware) -> None: - """Add a `RawRecordMiddleware` to this server. - - New in `1.1.0`. - - Args: - raw_middleware: the middleware to add - """ - if self._raw_record_middleware_stack: - # Note: we can use truthy expression as once processed there will always be at - # least one item in the stack - raise RuntimeError("Cannot register middleware after stack is created") - self._user_raw_record_middleware.append(raw_middleware) - return - def register_exception_handler( self, exception_class: Type[Exception], handler: middleware.ExceptionHandler ) -> None: @@ -166,7 +144,7 @@ def register_exception_handler( Only one handler can exist for a given exception type. - New in `1.1.0`. + New in `2.0`. Args: exception_class: the type of exception to handle @@ -178,103 +156,17 @@ def register_exception_handler( self.exception_handler_middleware.exception_handlers[exception_class] = handler return - def register_raw_exception_handler( - self, exception_class: Type[Exception], handler: middleware.RawRecordExceptionHandler - ) -> None: - """Register a raw exception handler for the `RawRecordMiddleware`. - - Only one handler can exist for a given exception type. - - New in `1.1.0`. - - Args: - exception_class: the type of exception to handle - handler: the function to call when handling an exception - """ - if exception_class in self.raw_exception_handler_middleware.exception_handlers: - raise ValueError("Exception handler already exists for {exception_class}") - - self.raw_exception_handler_middleware.exception_handlers[exception_class] = handler - return - - def run(self) -> int: - """Start running the server - - Returns: - `exit_code`, `0` if exited normally - """ - # Setup Logging - console_logger = logging.StreamHandler() - console_logger.setLevel(self.settings.console_log_level) - - console_formatter = logging.Formatter( - "[{asctime}][{levelname}][{name}] {message}", style="{" - ) - - console_logger.setFormatter(console_formatter) - - self._logger.addHandler(console_logger) - self._logger.setLevel(min(self.settings.console_log_level, self.settings.file_log_level)) - - # Start Server - # TODO: Do we want to recreate the transport instance or do we assume that - # transport.shutdown_server puts it back into a ready state? - # We could make this configurable? :thonking: - - self._info(f"Starting {self.transport}") - try: - self._prepare_middleware_stacks() - self.transport.start_server() - except Exception as e: # pylint: disable=broad-except - self._critical(e) - self.exit_code = 1 - return self.exit_code - - # Process Requests - error_count = 0 - while True: - if self.shutdown_server: - break - try: - message = self.transport.receive_message() - response = self._process_dns_record(message.message) - message.response = response - self.transport.send_message_response(message) - except InvalidMessageError as e: - self._warning(f"{e}") - except Exception as e: # pylint: disable=broad-except - self._error(f"Uncaught error occured. {e}", exc_info=True) - error_count += 1 - if error_count >= self.settings.max_errors: - self._critical(f"Max errors hit ({error_count})") - self.shutdown_server = True - self.exit_code = 1 - except KeyboardInterrupt: - self._info("KeyboardInterrupt received.") - self.shutdown_server = True - - # Stop Server - self._info("Shutting down server") - self.transport.stop_server() - - # Teardown Logging - self._logger.removeHandler(console_logger) - return self.exit_code - - ## Decorators - ## ------------------------------------------------------------------------- + # Decorators + # .......................................................................... def rule(self, rule_: Union[Type[RuleBase], str, Pattern], *args, **kwargs): - """Decorator for registering a function using an appropriate rule. + """Decorator for registering a function using [`smart_make_rule`][nserver.rules.smart_make_rule]. - Changed in `1.1.0`: This function now uses [`smart_make_rule`][nserver.rules.smart_make_rule]. - At the time of writing this allows for Rule classes be used directly, - and `str` inputs may result in a `StaticRule`. Future changes to - `smart_make_rule` will not be documented here. + Changed in `2.0`: This method now uses `smart_make_rule`. Args: rule_: rule as per `nserver.rules.smart_make_rule` - args: extra arguments to provide - kwargs: extra keyword arguments to provide + args: extra arguments to provide `smart_make_rule` + kwargs: extra keyword arguments to provide `smart_make_rule` Raises: ValueError: if `func` is provided in `kwargs`. @@ -333,7 +225,7 @@ def decorator(func: middleware.AfterQueryHook): def exception_handler(self, exception_class: Type[Exception]): """Decorator for registering a function as an exception handler - New in `1.1.0`. + New in `2.0`. Args: exception_class: The `Exception` class to register this handler for @@ -346,10 +238,147 @@ def decorator(func: middleware.ExceptionHandler): return decorator + ## Internal Functions + ## ------------------------------------------------------------------------- + def _prepare_query_middleware_stack(self) -> None: + """Prepare the `QueryMiddleware` for this server.""" + if self._query_middleware_stack: + # Note: we can use truthy expression as once processed there will always be at + # least one item in the stack + raise RuntimeError("QueryMiddleware stack already exists") + + middleware_stack: List[middleware.QueryMiddleware] = [ + self.exception_handler_middleware, + *self._user_query_middleware, + self.hook_middleware, + ] + rule_processor = middleware.RuleProcessor(self.rules) + + next_middleware: Optional[middleware.QueryMiddleware] = None + for query_middleware in middleware_stack[::-1]: + if next_middleware is None: + query_middleware.register_next_function(rule_processor) + else: + query_middleware.register_next_function(next_middleware) + next_middleware = query_middleware + + self._query_middleware_stack.extend(middleware_stack) + self._query_middleware_stack.append(rule_processor) + return + + ## Logging + ## ------------------------------------------------------------------------- + def _vvdebug(self, *args, **kwargs): + """Log very verbose debug message.""" + + return self._logger.log(6, *args, **kwargs) + + def _vdebug(self, *args, **kwargs): + """Log verbose debug message.""" + + return self._logger.log(8, *args, **kwargs) + + def _debug(self, *args, **kwargs): + """Log debug message.""" + + return self._logger.debug(*args, **kwargs) + + def _info(self, *args, **kwargs): + """Log very verbose debug message.""" + + return self._logger.info(*args, **kwargs) + + def _warning(self, *args, **kwargs): + """Log warning message.""" + + return self._logger.warning(*args, **kwargs) + + def _error(self, *args, **kwargs): + """Log an error message.""" + + return self._logger.error(*args, **kwargs) + + def _critical(self, *args, **kwargs): + """Log a critical message.""" + + return self._logger.critical(*args, **kwargs) + + +class NameServer(Scaffold): + """NameServer for responding to requests.""" + + # pylint: disable=too-many-instance-attributes + + def __init__(self, name: str, settings: Optional[Settings] = None) -> None: + """ + Args: + name: The name of the server. This is used for internal logging. + settings: settings to use with this `NameServer` instance + """ + super().__init__(name) + self._logger = logging.getLogger(f"nserver.i.{self.name}") + + self.raw_exception_handler_middleware = middleware.RawRecordExceptionHandlerMiddleware() + self._user_raw_record_middleware: List[middleware.RawRecordMiddleware] = [] + self._raw_record_middleware_stack: List[ + Union[middleware.RawRecordMiddleware, middleware.RawRecordMiddlewareCallable] + ] = [] + + self.settings = settings if settings is not None else Settings() + + transport = TRANSPORT_MAP.get(self.settings.server_transport) + if transport is None: + raise ValueError( + f"Invalid settings.server_transport {self.settings.server_transport!r}" + ) + self.transport = transport(self.settings) + + self.shutdown_server = False + self.exit_code = 0 + return + + ## Register Methods + ## ------------------------------------------------------------------------- + def register_raw_middleware(self, raw_middleware: middleware.RawRecordMiddleware) -> None: + """Add a `RawRecordMiddleware` to this server. + + New in `2.0`. + + Args: + raw_middleware: the middleware to add + """ + if self._raw_record_middleware_stack: + # Note: we can use truthy expression as once processed there will always be at + # least one item in the stack + raise RuntimeError("Cannot register middleware after stack is created") + self._user_raw_record_middleware.append(raw_middleware) + return + + def register_raw_exception_handler( + self, exception_class: Type[Exception], handler: middleware.RawRecordExceptionHandler + ) -> None: + """Register a raw exception handler for the `RawRecordMiddleware`. + + Only one handler can exist for a given exception type. + + New in `2.0`. + + Args: + exception_class: the type of exception to handle + handler: the function to call when handling an exception + """ + if exception_class in self.raw_exception_handler_middleware.exception_handlers: + raise ValueError("Exception handler already exists for {exception_class}") + + self.raw_exception_handler_middleware.exception_handlers[exception_class] = handler + return + + # Decorators + # .......................................................................... def raw_exception_handler(self, exception_class: Type[Exception]): """Decorator for registering a function as an raw exception handler - New in `1.1.0`. + New in `2.0`. Args: exception_class: The `Exception` class to register this handler for @@ -362,6 +391,72 @@ def decorator(func: middleware.RawRecordExceptionHandler): return decorator + ## Public Methods + ## ------------------------------------------------------------------------- + def run(self) -> int: + """Start running the server + + Returns: + `exit_code`, `0` if exited normally + """ + # Setup Logging + console_logger = logging.StreamHandler() + console_logger.setLevel(self.settings.console_log_level) + + console_formatter = logging.Formatter( + "[{asctime}][{levelname}][{name}] {message}", style="{" + ) + + console_logger.setFormatter(console_formatter) + + self._logger.addHandler(console_logger) + self._logger.setLevel(min(self.settings.console_log_level, self.settings.file_log_level)) + + # Start Server + # TODO: Do we want to recreate the transport instance or do we assume that + # transport.shutdown_server puts it back into a ready state? + # We could make this configurable? :thonking: + + self._info(f"Starting {self.transport}") + try: + self._prepare_middleware_stacks() + self.transport.start_server() + except Exception as e: # pylint: disable=broad-except + self._critical(e) + self.exit_code = 1 + return self.exit_code + + # Process Requests + error_count = 0 + while True: + if self.shutdown_server: + break + try: + message = self.transport.receive_message() + response = self._process_dns_record(message.message) + message.response = response + self.transport.send_message_response(message) + except InvalidMessageError as e: + self._warning(f"{e}") + except Exception as e: # pylint: disable=broad-except + self._error(f"Uncaught error occured. {e}", exc_info=True) + error_count += 1 + if error_count >= self.settings.max_errors: + self._critical(f"Max errors hit ({error_count})") + self.shutdown_server = True + self.exit_code = 1 + except KeyboardInterrupt: + self._info("KeyboardInterrupt received.") + self.shutdown_server = True + + # Stop Server + self._info("Shutting down server") + self.transport.stop_server() + + # Teardown Logging + self._logger.removeHandler(console_logger) + return self.exit_code + ## Internal Functions ## ------------------------------------------------------------------------- def _process_dns_record(self, message: dnslib.DNSRecord) -> dnslib.DNSRecord: @@ -381,36 +476,10 @@ def _process_dns_record(self, message: dnslib.DNSRecord) -> dnslib.DNSRecord: def _prepare_middleware_stacks(self) -> None: """Prepare all middleware for this server.""" - self._prepare_request_middleware_stack() + self._prepare_query_middleware_stack() self._prepare_raw_record_middleware_stack() return - def _prepare_request_middleware_stack(self) -> None: - """Prepare the `QueryMiddleware` for this server.""" - if self._query_middleware_stack: - # Note: we can use truthy expression as once processed there will always be at - # least one item in the stack - raise RuntimeError("QueryMiddleware stack already exists") - - middleware_stack: List[middleware.QueryMiddleware] = [ - self.exception_handler_middleware, - *self._user_query_middleware, - self.hook_middleware, - ] - rule_processor = middleware.RuleProcessor(self.rules) - - next_middleware: Optional[middleware.QueryMiddleware] = None - for query_middleware in middleware_stack[::-1]: - if next_middleware is None: - query_middleware.register_next_function(rule_processor) - else: - query_middleware.register_next_function(next_middleware) - next_middleware = query_middleware - - self._query_middleware_stack.extend(middleware_stack) - self._query_middleware_stack.append(rule_processor) - return - def _prepare_raw_record_middleware_stack(self) -> None: """Prepare the `RawRecordMiddleware` for this server.""" if not self._query_middleware_stack: @@ -444,39 +513,29 @@ def _prepare_raw_record_middleware_stack(self) -> None: self._raw_record_middleware_stack.append(query_middleware_processor) return - ## Logging - ## ------------------------------------------------------------------------- - def _vvdebug(self, *args, **kwargs): - """Log very verbose debug message.""" - - return self._logger.log(6, *args, **kwargs) - - def _vdebug(self, *args, **kwargs): - """Log verbose debug message.""" - return self._logger.log(8, *args, **kwargs) +class Blueprint(Scaffold): + """Class that can replicate many of the functions of a `NameServer`. - def _debug(self, *args, **kwargs): - """Log debug message.""" + They can be used to construct or extend applications. - return self._logger.debug(*args, **kwargs) + New in `2.0`. + """ - def _info(self, *args, **kwargs): - """Log very verbose debug message.""" - - return self._logger.info(*args, **kwargs) - - def _warning(self, *args, **kwargs): - """Log warning message.""" - - return self._logger.warning(*args, **kwargs) - - def _error(self, *args, **kwargs): - """Log an error message.""" - - return self._logger.error(*args, **kwargs) + def __init__(self, name: str) -> None: + """ + Args: + name: The name of the server. This is used for internal logging. + """ + super().__init__(name) + self._logger = logging.getLogger(f"nserver.b.{self.name}") + return - def _critical(self, *args, **kwargs): - """Log a critical message.""" + def entrypoint(self, query: Query) -> Response: + """Entrypoint into this `Blueprint`. - return self._logger.critical(*args, **kwargs) + This method should be passed to rules as the function to run. + """ + if not self._query_middleware_stack: + self._prepare_query_middleware_stack() + return self._query_middleware_stack[0](query) diff --git a/tests/test_blueprint.py b/tests/test_blueprint.py new file mode 100644 index 0000000..e801dc8 --- /dev/null +++ b/tests/test_blueprint.py @@ -0,0 +1,184 @@ +# pylint: disable=missing-class-docstring,missing-function-docstring,protected-access + +### IMPORTS +### ============================================================================ +## Standard Library +from typing import no_type_check, List +import unittest.mock + +## Installed +import dnslib +import pytest + +from nserver import NameServer, Blueprint, Query, Response, ALL_QTYPES, ZoneRule, A +from nserver.server import Scaffold + +## Application + +### SETUP +### ============================================================================ +IP = "127.0.0.1" +server = NameServer("test_blueprint") +blueprint_1 = Blueprint("blueprint_1") +blueprint_2 = Blueprint("blueprint_2") +blueprint_3 = Blueprint("blueprint_3") + + +## Rules +## ----------------------------------------------------------------------------- +@server.rule("s.com", ["A"]) +@blueprint_1.rule("b1.com", ["A"]) +@blueprint_2.rule("b2.com", ["A"]) +@blueprint_3.rule("b3.b2.com", ["A"]) +def dummy_rule(query: Query) -> A: + return A(query.name, IP) + + +## Hooks +## ----------------------------------------------------------------------------- +def register_hooks(scaff: Scaffold) -> None: + scaff.register_before_first_query(unittest.mock.MagicMock(wraps=lambda: None)) + scaff.register_before_query(unittest.mock.MagicMock(wraps=lambda q: None)) + scaff.register_after_query(unittest.mock.MagicMock(wraps=lambda r: r)) + return + + +@no_type_check +def reset_hooks(scaff: Scaffold) -> None: + scaff.hook_middleware.before_first_query_run = False + scaff.hook_middleware.before_first_query[0].reset_mock() + scaff.hook_middleware.before_query[0].reset_mock() + scaff.hook_middleware.after_query[0].reset_mock() + return + + +def reset_all_hooks() -> None: + reset_hooks(server) + reset_hooks(blueprint_1) + reset_hooks(blueprint_2) + reset_hooks(blueprint_3) + return + + +@no_type_check +def check_hook_call_count(scaff: Scaffold, bfq_count: int, bq_count: int, aq_count: int) -> None: + assert scaff.hook_middleware.before_first_query[0].call_count == bfq_count + assert scaff.hook_middleware.before_query[0].call_count == bq_count + assert scaff.hook_middleware.after_query[0].call_count == aq_count + return + + +register_hooks(server) +register_hooks(blueprint_1) +register_hooks(blueprint_2) +register_hooks(blueprint_3) + + +## Exception handling +## ----------------------------------------------------------------------------- +class ErrorForTesting(Exception): + pass + + +@server.rule("throw-error.com", ["A"]) +def throw_error(query: Query) -> None: + raise ErrorForTesting() + + +def _query_error_handler(query: Query, exception: Exception) -> Response: + # pylint: disable=unused-argument + return Response(error_code=dnslib.RCODE.SERVFAIL) + + +query_error_handler = unittest.mock.MagicMock(wraps=_query_error_handler) +server.register_exception_handler(ErrorForTesting, query_error_handler) + + +class ThrowAnotherError(Exception): + pass + + +@server.rule("throw-another-error.com", ["A"]) +def throw_another_error(query: Query) -> None: + raise ThrowAnotherError() + + +def bad_error_handler(query: Query, exception: Exception) -> Response: + # pylint: disable=unused-argument + raise ErrorForTesting() + + +server.register_exception_handler(ThrowAnotherError, bad_error_handler) + + +def _raw_record_error_handler(record: dnslib.DNSRecord, exception: Exception) -> dnslib.DNSRecord: + # pylint: disable=unused-argument + response = record.reply() + response.header.rcode = dnslib.RCODE.SERVFAIL + return response + + +raw_record_error_handler = unittest.mock.MagicMock(wraps=_raw_record_error_handler) +server.register_raw_exception_handler(ErrorForTesting, raw_record_error_handler) + +## Get server ready +## ----------------------------------------------------------------------------- +server.register_blueprint(blueprint_1, ZoneRule, "b1.com", ALL_QTYPES) +server.register_blueprint(blueprint_2, ZoneRule, "b2.com", ALL_QTYPES) +blueprint_2.register_blueprint(blueprint_3, ZoneRule, "b3.b2.com", ALL_QTYPES) + +server._prepare_middleware_stacks() + + +### TESTS +### ============================================================================ +## Responses +## ----------------------------------------------------------------------------- +@pytest.mark.parametrize("question", ["s.com", "b1.com", "b2.com", "b3.b2.com"]) +def test_response(question: str): + response = server._process_dns_record(dnslib.DNSRecord.question(question)) + assert len(response.rr) == 1 + assert response.rr[0].rtype == 1 + assert response.rr[0].rname == question + return + + +@pytest.mark.parametrize("question", ["miss.s.com", "miss.b1.com", "miss.b2.com", "miss.b3.b2.com"]) +def test_nxdomain(question: str): + response = server._process_dns_record(dnslib.DNSRecord.question(question)) + assert len(response.rr) == 0 + assert response.header.rcode == dnslib.RCODE.NXDOMAIN + return + + +## Hooks +## ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + "question,hook_counts", + [ + ("s.com", [1, 5, 5]), + ("b1.com", [1, 5, 5, 1, 5, 5]), + ("b2.com", [1, 5, 5, 0, 0, 0, 1, 5, 5]), + ("b3.b2.com", [1, 5, 5, 0, 0, 0, 1, 5, 5, 1, 5, 5]), + ], +) +def test_hooks(question: str, hook_counts: List[int]): + ## Setup + # fill unset hook_counts + hook_counts += [0] * (12 - len(hook_counts)) + assert len(hook_counts) == 12 + # reset hooks + reset_all_hooks() + + ## Test + for _ in range(5): + response = server._process_dns_record(dnslib.DNSRecord.question(question)) + assert len(response.rr) == 1 + assert response.rr[0].rtype == 1 + assert response.rr[0].rname == question + + check_hook_call_count(server, *hook_counts[:3]) + check_hook_call_count(blueprint_1, *hook_counts[3:6]) + check_hook_call_count(blueprint_2, *hook_counts[6:9]) + check_hook_call_count(blueprint_3, *hook_counts[9:]) + return diff --git a/tests/test_rules.py b/tests/test_rules.py index f095d9f..f9e33ec 100644 --- a/tests/test_rules.py +++ b/tests/test_rules.py @@ -38,7 +38,7 @@ def run_rule(rule, query, matches): ### ============================================================================ @pytest.mark.parametrize( "rule,expected", - ( + [ ("*", True), ("**", True), ("{base_domain}", True), @@ -47,7 +47,7 @@ def run_rule(rule, query, matches): ("", False), ("foo.com", False), ("{something}", False), - ), + ], ) def test_wildcard_string_regex(rule, expected): assert bool(_wildcard_string_regex.search(rule)) is expected @@ -55,7 +55,7 @@ def test_wildcard_string_regex(rule, expected): @pytest.mark.parametrize( "rule,expected", - ( + [ ("", StaticRule), ("foo", StaticRule), ("example.com", StaticRule), @@ -69,7 +69,7 @@ def test_wildcard_string_regex(rule, expected): ("*.{base_domain}", WildcardStringRule), ("*.{base_domain}", WildcardStringRule), (re.compile(".*"), RegexRule), - ), + ], ) def test_smart_make_rule_class(rule, expected): assert isinstance(smart_make_rule(rule, ["A"], func=DUMMY_FUNCTION), expected) @@ -80,11 +80,11 @@ def test_smart_make_rule_class(rule, expected): class TestStaticRule: @pytest.mark.parametrize( "qtype,matches", - ( + [ ("A", True), ("AAAA", True), ("TXT", False), - ), + ], ) def test_qtypes(self, qtype, matches): rule = StaticRule("", ["A", "AAAA"], DUMMY_FUNCTION) @@ -94,7 +94,7 @@ def test_qtypes(self, qtype, matches): @pytest.mark.parametrize("match_string", ("test.com", "TEST.com", "test.COM", "TeSt.CoM")) @pytest.mark.parametrize( "name,matches", - ( + [ ("test.com", True), ("TEST.com", True), ("test.COM", True), @@ -103,7 +103,7 @@ def test_qtypes(self, qtype, matches): ("", False), ("com", False), ("foo.test.com", False), - ), + ], ) def test_case_insensitive(self, match_string, name, matches): rule = StaticRule(match_string, ["A"], DUMMY_FUNCTION, False) @@ -112,7 +112,7 @@ def test_case_insensitive(self, match_string, name, matches): @pytest.mark.parametrize( "name,matches", - ( + [ ("test.com", True), ("TEST.com", False), ("test.COM", False), @@ -121,7 +121,7 @@ def test_case_insensitive(self, match_string, name, matches): ("", False), ("com", False), ("foo.test.com", False), - ), + ], ) def test_case_sensitive(self, name, matches): rule = StaticRule("test.com", ["A"], DUMMY_FUNCTION, True) @@ -134,11 +134,11 @@ def test_case_sensitive(self, name, matches): class TestZoneRule: @pytest.mark.parametrize( "qtype,matches", - ( + [ ("A", True), ("AAAA", True), ("TXT", False), - ), + ], ) def test_qtypes(self, qtype, matches): rule = ZoneRule("", ["A", "AAAA"], DUMMY_FUNCTION) @@ -148,7 +148,7 @@ def test_qtypes(self, qtype, matches): @pytest.mark.parametrize("zone", ("test.com", "TEST.com", "test.COM", "TeSt.CoM")) @pytest.mark.parametrize( "name,matches", - ( + [ ("test.com", True), ("TEST.com", True), ("test.COM", True), @@ -160,7 +160,7 @@ def test_qtypes(self, qtype, matches): ("__dmarc.TeSt.CoM", True), ("", False), ("com", False), - ), + ], ) def test_case_insensitive(self, zone, name, matches): rule = ZoneRule(zone, ["A"], DUMMY_FUNCTION, False) @@ -169,7 +169,7 @@ def test_case_insensitive(self, zone, name, matches): @pytest.mark.parametrize( "name,matches", - ( + [ ("test.com", True), ("foo.test.com", True), ("FOO.test.com", True), @@ -182,7 +182,7 @@ def test_case_insensitive(self, zone, name, matches): ("TeSt.CoM", False), ("", False), ("com", False), - ), + ], ) def test_case_sensitive(self, name, matches): rule = ZoneRule("test.com", ["A"], DUMMY_FUNCTION, True) @@ -195,11 +195,11 @@ def test_case_sensitive(self, name, matches): class TestRegexRule: @pytest.mark.parametrize( "qtype,matches", - ( + [ ("A", True), ("AAAA", True), ("TXT", False), - ), + ], ) def test_qtypes(self, qtype, matches): rule = RegexRule(re.compile(".*"), ["A", "AAAA"], DUMMY_FUNCTION) @@ -208,7 +208,7 @@ def test_qtypes(self, qtype, matches): @pytest.mark.parametrize( "name,matches", - ( + [ ("cat.test.com", True), ("cats.test.com", True), ("cat.kitten.test.com", True), @@ -217,7 +217,7 @@ def test_qtypes(self, qtype, matches): ("cat.test.coms", False), ("dog.test.com", False), ("dog.cat.test.com", False), - ), + ], ) def test_case_insensitive_same_case(self, name, matches): rule = RegexRule(re.compile(r"cat.*\.test\.com"), ["A"], DUMMY_FUNCTION, False) @@ -226,12 +226,12 @@ def test_case_insensitive_same_case(self, name, matches): @pytest.mark.parametrize( "name,matches", - ( + [ ("Cat.TEST.com", True), ("Cats.TEST.com", True), ("Cat.kitten.TEST.com", True), ("Cats.kittens.TEST.com", True), - ), + ], ) def test_case_insensitive_query_mixed(self, name, matches): rule = RegexRule(re.compile(r"cat.*\.test\.com"), ["A"], DUMMY_FUNCTION, False) @@ -240,7 +240,7 @@ def test_case_insensitive_query_mixed(self, name, matches): @pytest.mark.parametrize( "name,matches", - ( + [ ("cat.test.com", True), ("cats.test.com", True), ("cat.kitten.test.com", True), @@ -253,7 +253,7 @@ def test_case_insensitive_query_mixed(self, name, matches): ("cat.test.coms", False), ("dog.test.com", False), ("dog.cat.test.com", False), - ), + ], ) def test_case_insensitive_regex_mixed(self, name, matches): rule = RegexRule(re.compile(r"Cat.*\.TEST\.com"), ["A"], DUMMY_FUNCTION, False) @@ -262,7 +262,7 @@ def test_case_insensitive_regex_mixed(self, name, matches): @pytest.mark.parametrize( "name,matches", - ( + [ ("cat.test.com", False), ("cats.test.com", False), ("cat.kitten.test.com", False), @@ -275,7 +275,7 @@ def test_case_insensitive_regex_mixed(self, name, matches): ("cat.test.coms", False), ("dog.test.com", False), ("dog.cat.test.com", False), - ), + ], ) def test_case_sensitive(self, name, matches): rule = RegexRule(re.compile(r"Cat.*\.TEST\.com"), ["A"], DUMMY_FUNCTION, True) @@ -288,11 +288,11 @@ def test_case_sensitive(self, name, matches): class TestWildcardStringRule: @pytest.mark.parametrize( "qtype,matches", - ( + [ ("A", True), ("AAAA", True), ("TXT", False), - ), + ], ) def test_qtypes(self, qtype, matches): rule = WildcardStringRule("**", ["A", "AAAA"], DUMMY_FUNCTION) @@ -301,14 +301,14 @@ def test_qtypes(self, qtype, matches): @pytest.mark.parametrize( "name,matches", - ( + [ ("cat.test.com", True), ("kitten.test.com", True), ("test.com", False), ("cat.fail.com", False), ("cat.test.fail", False), ("fail.cat.test.com", False), - ), + ], ) def test_single_wildcard_expansion(self, name, matches): rule = WildcardStringRule("*.test.com", ["A"], DUMMY_FUNCTION, False) @@ -317,13 +317,13 @@ def test_single_wildcard_expansion(self, name, matches): @pytest.mark.parametrize( "name,matches", - ( + [ ("cat.kitten.test.com", True), ("lion.cat.kitten.test.com", True), ("test.com", False), ("cat.fail.com", False), ("cat.test.fail", False), - ), + ], ) def test_double_wildcard_expansion(self, name, matches): rule = WildcardStringRule("**.test.com", ["A"], DUMMY_FUNCTION, False) @@ -332,7 +332,7 @@ def test_double_wildcard_expansion(self, name, matches): @pytest.mark.parametrize( "name,matches", - ( + [ ("cat.1.dog.1.test.com", True), ("cat.1.2.dog.1.test.com", True), ("cat.1.2.3.dog.1.test.com", True), @@ -340,7 +340,7 @@ def test_double_wildcard_expansion(self, name, matches): ("cat.dog.1.test.com", False), ("cat.1.2.dog.1.2.test.com", False), ("1.cat.3.dog.1.test.com", False), - ), + ], ) def test_multi_wildcard_expansion(self, name, matches): rule = WildcardStringRule("cat.**.dog.*.test.com", ["A"], DUMMY_FUNCTION, False) @@ -349,7 +349,7 @@ def test_multi_wildcard_expansion(self, name, matches): @pytest.mark.parametrize( "name,matches", - ( + [ ("internal", True), ("local", True), ("asdfasdfasdf", True), @@ -359,7 +359,7 @@ def test_multi_wildcard_expansion(self, name, matches): ("psl.au", True), ("nope.test.com", False), ("nope.foo.com.au", False), - ), + ], ) def test_base_domain_case_insensitive(self, name, matches): rule = WildcardStringRule("{base_domain}", ["A"], DUMMY_FUNCTION, False) @@ -368,7 +368,7 @@ def test_base_domain_case_insensitive(self, name, matches): @pytest.mark.parametrize( "name,matches", - ( + [ ("internal", True), ("local", True), ("asdfasdfasdf", True), @@ -388,7 +388,7 @@ def test_base_domain_case_insensitive(self, name, matches): ("NOPE.test.com", False), ("nope.TEST.com", False), ("nope.test.COM", False), - ), + ], ) def test_base_domain_case_sensitive(self, name, matches): rule = WildcardStringRule("{base_domain}", ["A"], DUMMY_FUNCTION, True) @@ -397,7 +397,7 @@ def test_base_domain_case_sensitive(self, name, matches): @pytest.mark.parametrize( "name,matches", - ( + [ # local domain ("cat.1.dog.1.internal", True), ("cat.1.2.dog.1.internal", True), @@ -422,7 +422,7 @@ def test_base_domain_case_sensitive(self, name, matches): ("cat.dog.1.etld.com.au", False), ("cat.1.2.dog.1.2.etld.com.au", False), ("1.cat.3.dog.1.etld.com.au", False), - ), + ], ) def test_base_domain_multi_wildcard_expansion(self, name, matches): rule = WildcardStringRule("cat.**.dog.*.{base_domain}", ["A"], DUMMY_FUNCTION, False) @@ -431,14 +431,14 @@ def test_base_domain_multi_wildcard_expansion(self, name, matches): @pytest.mark.parametrize( "name,matches", - ( + [ ("cat.kitten.test.com", True), ("cats.dogs.test.com", False), ("cat.com", False), ("cat.test.coms", False), ("dog.test.com", False), ("dog.cat.test.com", False), - ), + ], ) def test_case_insensitive_same_case(self, name, matches): rule = WildcardStringRule("cat.**.test.com", ["A"], DUMMY_FUNCTION, False) @@ -447,7 +447,7 @@ def test_case_insensitive_same_case(self, name, matches): @pytest.mark.parametrize( "name,matches", - ( + [ ("cat.kitten.test.com", True), ("cat.lion.kitten.test.com", True), ("cats.dogs.test.com", False), @@ -465,7 +465,7 @@ def test_case_insensitive_same_case(self, name, matches): ("Cat.TEST.coms", False), ("dog.TEST.com", False), ("dog.Cat.TEST.com", False), - ), + ], ) def test_case_insensitive_query_mixed(self, name, matches): rule = WildcardStringRule("cat.**.test.com", ["A"], DUMMY_FUNCTION, False) @@ -474,7 +474,7 @@ def test_case_insensitive_query_mixed(self, name, matches): @pytest.mark.parametrize( "name,matches", - ( + [ ("cat.kitten.test.com", True), ("cat.lion.kitten.test.com", True), ("cats.dogs.test.com", False), @@ -492,7 +492,7 @@ def test_case_insensitive_query_mixed(self, name, matches): ("Cat.TEST.coms", False), ("dog.TEST.com", False), ("dog.Cat.TEST.com", False), - ), + ], ) def test_case_insensitive_expansion_mixed(self, name, matches): rule = WildcardStringRule("Cat.**.TEST.com", ["A"], DUMMY_FUNCTION, False) @@ -501,7 +501,7 @@ def test_case_insensitive_expansion_mixed(self, name, matches): @pytest.mark.parametrize( "name,matches", - ( + [ ("cat.kitten.test.com", False), ("cat.lion.kitten.test.com", False), ("cats.dogs.test.com", False), @@ -519,7 +519,7 @@ def test_case_insensitive_expansion_mixed(self, name, matches): ("Cat.TEST.coms", False), ("dog.TEST.com", False), ("dog.Cat.TEST.com", False), - ), + ], ) def test_case_sensitive(self, name, matches): rule = WildcardStringRule("Cat.**.TEST.com", ["A"], DUMMY_FUNCTION, True)