diff --git a/docs/changelog.md b/docs/changelog.md index 68f5f3b..fd35d2d 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,9 @@ - Implement [Middleware][middleware] - This includes adding error handling middleware that facilitates [error handling][error-handling]. +- Add [`StaticRule`][nserver.rules.StaticRule] and [`ZoneRule`][nserver.rules.ZoneRule]. +- Refector [`NameServer.rule`][nserver.server.NameServer.rule] to use expanded [`smart_make_rule`][nserver.rules.smart_make_rule] function. + - Although this change should not affect rules using this decorator from being called correctly, it may change the precise rule type being used. Specifically it may use `StaticRule` instead of `WildcardStringRule` for strings with no substitutions. ## 1.0.0 diff --git a/docs/quickstart.md b/docs/quickstart.md index 05e9abe..25d25fc 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -79,14 +79,20 @@ example.com. 300 IN A 1.2.3.4 ## Rules -[Rules][nserver.rules] tell our server which queries to send to which functions. NServer ships with two rule types: +[Rules][nserver.rules] tell our server which queries to send to which functions. NServer ships with a number of rule types. +- [`StaticRule`][nserver.rules.StaticRule] matches on an exact string. +- [`ZoneRule`][nserver.rules.ZoneRule] matches the given domain and all subdomains. - [`WildcardStringRule`][nserver.rules.WildcardStringRule] which allows writing rules using a shorthand syntax. - [`RegexRule`][nserver.rules.RegexRule] which uses regular expressions for matching. -When using the [`NameServer.rule`][nserver.server.NameServer.rule] decorator string (`str`) rules will be used to create a `WildcardStringRule` whilst regular expression (`typing.Pattern`) rules will create a `RegexRule`. This decorator also return the original function unchanged meaning it is possible to decorate the same function with multiple rules. +The [`NameServer.rule`][nserver.server.NameServer.rule] decorator uses [`smart_make_rule`][nserver.rules.smart_make_rule] to automatically select the "best" matching rule type based on the input. This will result in string (`str`) rules will be used to create either a `WildcardStringRule` or a `StaticRule`, whilst regular expression (`typing.Pattern`) rules will create a `RegexRule`. This decorator also return the original function unchanged meaning it is possible to decorate the same function with multiple rules. ```python +@saerver.rule("google-dns", ["A"]) +def this_will_be_a_static_rule(query): + return A(query.name, "8.8.8.8") + @server.rule("{base_name}", ["A"]) @server.rule("www.{base_name}", ["A"]) @server.rule("mail.{base_name}", ["A"]) diff --git a/src/nserver/__init__.py b/src/nserver/__init__.py index 0c83051..7844bde 100644 --- a/src/nserver/__init__.py +++ b/src/nserver/__init__.py @@ -1,5 +1,5 @@ from .models import Query, Response -from .rules import RegexRule, WildcardStringRule +from .rules import StaticRule, ZoneRule, RegexRule, WildcardStringRule from .records import A, AAAA, NS, CNAME, PTR, SOA, MX, TXT, CAA from .server import NameServer from .settings import Settings diff --git a/src/nserver/rules.py b/src/nserver/rules.py index 0848f6a..4b521eb 100644 --- a/src/nserver/rules.py +++ b/src/nserver/rules.py @@ -4,15 +4,62 @@ ### ============================================================================ ## Standard Library import re -from typing import Callable, List, Optional, Pattern, Union +from typing import Callable, List, Optional, Pattern, Union, Type ## Installed +import dnslib import tldextract ## Application from .models import Query, Response from .records import RecordBase +### CONSTANTS +### ============================================================================ +ALL_QTYPES: List[str] = list(dnslib.QTYPE.reverse.keys()) +"""All supported Query Types + +New in `1.1.0`. +""" + +_wildcard_string_regex = re.compile(r"[*]|\{base_domain\}") + + +### FUNCTIONS +### ============================================================================ +def smart_make_rule(rule: "Union[Type[RuleBase], str, Pattern]", *args, **kwargs) -> "RuleBase": + """Create a rule using shorthand notation. + + The exact type of rule returned depends on what is povided by `rule`. + + If rule is a + + - `RuleBase` class, then it is used directly. + - `str` then it is checked to see if it contains substitutions. If it does then + it will be a `WildcardStringRule`, else a `StaticRule`. + - `Pattern` then a `RegexRule`. + + New in `1.1.0` + + Args: + rule: input to process + args: extra arguments to provide to the constructor + kwargs: extra keyword arguments to provide to the constructor + """ + if isinstance(rule, str): + if _wildcard_string_regex.search(rule): + return WildcardStringRule(rule, *args, **kwargs) + return StaticRule(rule, *args, **kwargs) + + # pylint: disable=isinstance-second-argument-not-valid-type + if isinstance(rule, Pattern): + # Note: I've disabled this type check thing as it currently works and it might + # vary between versions of python and other bugs. + # see also: https://stackoverflow.com/questions/6102019/type-of-compiled-regex-object-in-python + return RegexRule(rule, *args, **kwargs) + raise ValueError(f"Could not handle rule: {rule!r}") + + ### CLASSES ### ============================================================================ RuleResult = Union[Response, RecordBase, List[RecordBase], None] @@ -41,6 +88,106 @@ def get_func(self, query: Query) -> Optional[ResponseFunction]: raise NotImplementedError() +class StaticRule(RuleBase): + """Rule that matches only the given string + + `StaticRule` is more efficient than using a `WildcardStringRule` for static strings. + + New in `1.1.0`. + """ + + def __init__( + self, + match_string: str, + allowed_qtypes: List[str], + func: ResponseFunction, + case_sensitive: bool = False, + ) -> None: + """ + Args: + match_string: string to match + allowed_qtypes: match only the given query types + func: response function to call + case_sensitive: how to case when matching + """ + self.match_string = match_string if case_sensitive else match_string.lower() + self.allowed_qtypes = set(allowed_qtypes) + self.func = func + self.case_sensitive = case_sensitive + return + + def get_func(self, query: Query) -> Optional[ResponseFunction]: + """Same as parent class""" + if query.type not in self.allowed_qtypes: + return None + + check_string = query.name + if not self.case_sensitive: + check_string = check_string.lower() + + if check_string == self.match_string: + return self.func + return None + + def __repr__(self): + return f"{self.__class__.__name__}(match_string={self.match_string!r}, allowed_qtypes={self.allowed_qtypes!r}, func={self.func!r}, case_sensitive={self.case_sensitive!r})" + + def __str__(self): + return f"{self.__class__.__name__}({self.match_string!r}, {self.allowed_qtypes!r})" + + +class ZoneRule(RuleBase): + """Rule that matches the given domain or any subdomain + + An empty zone (`""`) will match any domain as this refers to the domain root (`.`). + + New in `1.1.0`. + """ + + def __init__( + self, + zone: str, + allowed_qtypes: List[str], + func: ResponseFunction, + case_sensitive: bool = False, + ) -> None: + """ + Args: + zone: zone root + allowed_qtypes: match only the given query types. + func: response function to call + case_sensitive: how to case when matching + """ + zone = zone.strip(".") + self.zone = zone if case_sensitive else zone.lower() + self.allowed_qtypes = set(allowed_qtypes) if allowed_qtypes else None + self.func = func + self.case_sensitive = case_sensitive + return + + def get_func(self, query: Query) -> Optional[ResponseFunction]: + """Same as parent class""" + if self.allowed_qtypes is not None and query.type not in self.allowed_qtypes: + return None + + if self.zone == "": + return self.func + + check_string = query.name + if not self.case_sensitive: + check_string = check_string.lower() + + if check_string == self.zone or check_string.endswith(f".{self.zone}"): + return self.func + return None + + def __repr__(self): + return f"{self.__class__.__name__}(zone={self.zone!r}, allowed_qtypes={self.allowed_qtypes!r}, func={self.func!r}, case_sensitive={self.case_sensitive!r})" + + def __str__(self): + return f"{self.__class__.__name__}({self.zone!r}, {self.allowed_qtypes!r})" + + class RegexRule(RuleBase): """Rule that uses the provided regex to attempt to match the query name.""" @@ -82,7 +229,7 @@ def get_func(self, query: Query) -> Optional[ResponseFunction]: return None def __repr__(self): - return f"{self.__class__.__name__}(regex={self.regex.pattern!r}, allowed_qtypes={self.allowed_qtypes!r}, func={self.func!r})" + return f"{self.__class__.__name__}(regex={self.regex.pattern!r}, allowed_qtypes={self.allowed_qtypes!r}, func={self.func!r}, case_sensitive={self.case_sensitive!r})" def __str__(self): return f"{self.__class__.__name__}({self.regex.pattern!r}, {self.allowed_qtypes!r})" @@ -175,7 +322,7 @@ def _get_regex(self, query_domain: str) -> Pattern: return regex def __repr__(self): - return f"{self.__class__.__name__}(wildcard_string={self.wildcard_string!r}, allowed_qtypes={self.allowed_qtypes!r}, func={self.func!r})" + return f"{self.__class__.__name__}(wildcard_string={self.wildcard_string!r}, allowed_qtypes={self.allowed_qtypes!r}, func={self.func!r}, case_sensitive={self.case_sensitive!r})" def __str__(self): return f"{self.__class__.__name__}({self.wildcard_string!r}, {self.allowed_qtypes!r})" diff --git a/src/nserver/server.py b/src/nserver/server.py index 69a6db5..ed9bc0c 100644 --- a/src/nserver/server.py +++ b/src/nserver/server.py @@ -11,7 +11,7 @@ ## Application from .exceptions import InvalidMessageError -from .rules import RuleBase, WildcardStringRule, RegexRule, ResponseFunction +from .rules import smart_make_rule, RuleBase, ResponseFunction from .settings import Settings from .transport import TransportBase, UDPv4Transport, UDPv6Transport, TCPv4Transport @@ -263,38 +263,31 @@ def run(self) -> int: ## Decorators ## ------------------------------------------------------------------------- - def rule( - self, rule_: Union[str, Pattern], allowed_qtypes: List[str], case_sensitive: bool = False - ): # pylint: disable=unused-argument - """Decorator for registering a function as a rule. + def rule(self, rule_: Union[Type[RuleBase], str, Pattern], *args, **kwargs): + """Decorator for registering a function using an appropriate rule. + + Changed in `1.1.0`: This function now uses [`smart_make_rule`][nserver.rules.smart_make_rule]. + At the time of writing this allows for Rule classes be used directly, + and `str` inputs may result in a `StaticRule`. Future changes to + `smart_make_rule` will not be documented here. Args: - rule_: if `Pattern` then `RegexRule`, if `str` then `WildcardStringRule`. - allowed_qtypes: Only match the given DNS query types - case_sensitive: how to handle case when matching the rule + rule_: rule as per `nserver.rules.smart_make_rule` + args: extra arguments to provide + kwargs: extra keyword arguments to provide + + Raises: + ValueError: if `func` is provided in `kwargs`. """ + if "func" in kwargs: + raise ValueError("Must not provide `func` in kwargs") + def decorator(func: ResponseFunction): nonlocal rule_ - nonlocal allowed_qtypes - nonlocal case_sensitive - actual_rule: RuleBase - - if isinstance(rule_, str): - actual_rule = WildcardStringRule( - rule_, allowed_qtypes, func, case_sensitive=case_sensitive - ) - elif isinstance( # pylint: disable=isinstance-second-argument-not-valid-type - rule_, Pattern - ): - # Note: I've disabled this type check thing as it currently works and it might - # vary between versions of python and other bugs. - # see also: https://stackoverflow.com/questions/6102019/type-of-compiled-regex-object-in-python - actual_rule = RegexRule(rule_, allowed_qtypes, func, case_sensitive=case_sensitive) - else: - raise ValueError(f"Could not handle rule: {rule_!r}") - - self.register_rule(actual_rule) + nonlocal args + nonlocal kwargs + self.register_rule(smart_make_rule(rule_, *args, func=func, **kwargs)) return func return decorator diff --git a/tests/test_rules.py b/tests/test_rules.py index 2d391f1..f095d9f 100644 --- a/tests/test_rules.py +++ b/tests/test_rules.py @@ -8,7 +8,14 @@ ## Installed import pytest -from nserver.rules import RegexRule, WildcardStringRule +from nserver.rules import ( + _wildcard_string_regex, + smart_make_rule, + StaticRule, + ZoneRule, + RegexRule, + WildcardStringRule, +) from nserver.models import Query ## Application @@ -29,6 +36,158 @@ def run_rule(rule, query, matches): ### TESTS ### ============================================================================ +@pytest.mark.parametrize( + "rule,expected", + ( + ("*", True), + ("**", True), + ("{base_domain}", True), + ("*.com", True), + ("**.com", True), + ("", False), + ("foo.com", False), + ("{something}", False), + ), +) +def test_wildcard_string_regex(rule, expected): + assert bool(_wildcard_string_regex.search(rule)) is expected + + +@pytest.mark.parametrize( + "rule,expected", + ( + ("", StaticRule), + ("foo", StaticRule), + ("example.com", StaticRule), + ("foo-bar.com", StaticRule), + ("__dmarc.foo.com", StaticRule), + ("*.example.com", WildcardStringRule), + ("**.example.com", WildcardStringRule), + ("*.mail.{base_domain}", WildcardStringRule), + ("**.mail.{base_domain}", WildcardStringRule), + ("{base_domain}", WildcardStringRule), + ("*.{base_domain}", WildcardStringRule), + ("*.{base_domain}", WildcardStringRule), + (re.compile(".*"), RegexRule), + ), +) +def test_smart_make_rule_class(rule, expected): + assert isinstance(smart_make_rule(rule, ["A"], func=DUMMY_FUNCTION), expected) + + +## StaticRule +## ----------------------------------------------------------------------------- +class TestStaticRule: + @pytest.mark.parametrize( + "qtype,matches", + ( + ("A", True), + ("AAAA", True), + ("TXT", False), + ), + ) + def test_qtypes(self, qtype, matches): + rule = StaticRule("", ["A", "AAAA"], DUMMY_FUNCTION) + run_rule(rule, Query(qtype, ""), matches) + return + + @pytest.mark.parametrize("match_string", ("test.com", "TEST.com", "test.COM", "TeSt.CoM")) + @pytest.mark.parametrize( + "name,matches", + ( + ("test.com", True), + ("TEST.com", True), + ("test.COM", True), + ("TEST.COM", True), + ("TeSt.CoM", True), + ("", False), + ("com", False), + ("foo.test.com", False), + ), + ) + def test_case_insensitive(self, match_string, name, matches): + rule = StaticRule(match_string, ["A"], DUMMY_FUNCTION, False) + run_rule(rule, Query("A", name), matches) + return + + @pytest.mark.parametrize( + "name,matches", + ( + ("test.com", True), + ("TEST.com", False), + ("test.COM", False), + ("TEST.COM", False), + ("TeSt.CoM", False), + ("", False), + ("com", False), + ("foo.test.com", False), + ), + ) + def test_case_sensitive(self, name, matches): + rule = StaticRule("test.com", ["A"], DUMMY_FUNCTION, True) + run_rule(rule, Query("A", name), matches) + return + + +## ZoneRule +## ----------------------------------------------------------------------------- +class TestZoneRule: + @pytest.mark.parametrize( + "qtype,matches", + ( + ("A", True), + ("AAAA", True), + ("TXT", False), + ), + ) + def test_qtypes(self, qtype, matches): + rule = ZoneRule("", ["A", "AAAA"], DUMMY_FUNCTION) + run_rule(rule, Query(qtype, ""), matches) + return + + @pytest.mark.parametrize("zone", ("test.com", "TEST.com", "test.COM", "TeSt.CoM")) + @pytest.mark.parametrize( + "name,matches", + ( + ("test.com", True), + ("TEST.com", True), + ("test.COM", True), + ("TEST.COM", True), + ("TeSt.CoM", True), + ("foo.TEST.com", True), + ("BAR.FOO.test.COM", True), + ("CAR.bar.FOO.TEST.COM", True), + ("__dmarc.TeSt.CoM", True), + ("", False), + ("com", False), + ), + ) + def test_case_insensitive(self, zone, name, matches): + rule = ZoneRule(zone, ["A"], DUMMY_FUNCTION, False) + run_rule(rule, Query("A", name), matches) + return + + @pytest.mark.parametrize( + "name,matches", + ( + ("test.com", True), + ("foo.test.com", True), + ("FOO.test.com", True), + ("bar.foo.test.com", True), + ("BAR.foo.test.com", True), + ("car.bar.foo.test.com", True), + ("TEST.com", False), + ("test.COM", False), + ("TEST.COM", False), + ("TeSt.CoM", False), + ("", False), + ("com", False), + ), + ) + def test_case_sensitive(self, name, matches): + rule = ZoneRule("test.com", ["A"], DUMMY_FUNCTION, True) + run_rule(rule, Query("A", name), matches) + return ## RegexRule diff --git a/tests/test_server.py b/tests/test_server.py index 73fb1d6..f8561f4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -3,7 +3,6 @@ ### IMPORTS ### ============================================================================ ## Standard Library -import re from typing import List import unittest.mock @@ -11,7 +10,7 @@ import dnslib import pytest -from nserver import NameServer, Query, Response, A, RegexRule, WildcardStringRule +from nserver import NameServer, Query, Response, A ## Application @@ -28,16 +27,6 @@ def dummy_rule(query: Query) -> A: return A(query.name, IP) -@server.rule("wildcard-rule-expected.com", ["A"]) -def wildcard_rule_expected(query: Query) -> A: - return A(query.name, IP) - - -@server.rule(re.compile(r"regex-rule-expected\.com"), ["A"]) -def regex_rule_expected(query: Query) -> A: - return A(query.name, IP) - - @server.rule("none-response.com", ["A"]) def none_response(query: Query) -> None: # pylint: disable=unused-argument return None @@ -126,25 +115,6 @@ def _raw_record_error_handler(record: dnslib.DNSRecord, exception: Exception) -> ### TESTS ### ============================================================================ -## NameServer.rule -## ----------------------------------------------------------------------------- -def test_rule_decorator_type(): - wildcard_tested = False - regex_tested = False - - for rule in server.rules: - if rule.func is wildcard_rule_expected: - wildcard_tested = True - assert isinstance(rule, WildcardStringRule) - elif rule.func is regex_rule_expected: - regex_tested = True - assert isinstance(rule, RegexRule) - - # Check all tests run - assert all([wildcard_tested, regex_tested]) - return - - ## NameServer._process_dns_record ## ----------------------------------------------------------------------------- def test_none_response():