Skip to content

Commit

Permalink
[rules] Add Static and Zone rules, refactor rule decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
nhairs committed Dec 18, 2023
1 parent 76531a5 commit e8fb9ec
Show file tree
Hide file tree
Showing 7 changed files with 343 additions and 65 deletions.
3 changes: 3 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

- Implement [Middleware][middleware]
- This includes adding error handling middleware that facilitates [error handling][error-handling].
- Add [`StaticRule`][nserver.rules.StaticRule] and [`ZoneRule`][nserver.rules.ZoneRule].
- Refector [`NameServer.rule`][nserver.server.NameServer.rule] to use expanded [`smart_make_rule`][nserver.rules.smart_make_rule] function.
- Although this change should not affect rules using this decorator from being called correctly, it may change the precise rule type being used. Specifically it may use `StaticRule` instead of `WildcardStringRule` for strings with no substitutions.

## 1.0.0

Expand Down
10 changes: 8 additions & 2 deletions docs/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,20 @@ example.com. 300 IN A 1.2.3.4

## Rules

[Rules][nserver.rules] tell our server which queries to send to which functions. NServer ships with two rule types:
[Rules][nserver.rules] tell our server which queries to send to which functions. NServer ships with a number of rule types.

- [`StaticRule`][nserver.rules.StaticRule] matches on an exact string.
- [`ZoneRule`][nserver.rules.ZoneRule] matches the given domain and all subdomains.
- [`WildcardStringRule`][nserver.rules.WildcardStringRule] which allows writing rules using a shorthand syntax.
- [`RegexRule`][nserver.rules.RegexRule] which uses regular expressions for matching.

When using the [`NameServer.rule`][nserver.server.NameServer.rule] decorator string (`str`) rules will be used to create a `WildcardStringRule` whilst regular expression (`typing.Pattern`) rules will create a `RegexRule`. This decorator also return the original function unchanged meaning it is possible to decorate the same function with multiple rules.
The [`NameServer.rule`][nserver.server.NameServer.rule] decorator uses [`smart_make_rule`][nserver.rules.smart_make_rule] to automatically select the "best" matching rule type based on the input. This will result in string (`str`) rules will be used to create either a `WildcardStringRule` or a `StaticRule`, whilst regular expression (`typing.Pattern`) rules will create a `RegexRule`. This decorator also return the original function unchanged meaning it is possible to decorate the same function with multiple rules.

```python
@saerver.rule("google-dns", ["A"])
def this_will_be_a_static_rule(query):
return A(query.name, "8.8.8.8")

@server.rule("{base_name}", ["A"])
@server.rule("www.{base_name}", ["A"])
@server.rule("mail.{base_name}", ["A"])
Expand Down
2 changes: 1 addition & 1 deletion src/nserver/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .models import Query, Response
from .rules import RegexRule, WildcardStringRule
from .rules import StaticRule, ZoneRule, RegexRule, WildcardStringRule
from .records import A, AAAA, NS, CNAME, PTR, SOA, MX, TXT, CAA
from .server import NameServer
from .settings import Settings
153 changes: 150 additions & 3 deletions src/nserver/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,62 @@
### ============================================================================
## Standard Library
import re
from typing import Callable, List, Optional, Pattern, Union
from typing import Callable, List, Optional, Pattern, Union, Type

## Installed
import dnslib
import tldextract

## Application
from .models import Query, Response
from .records import RecordBase

### CONSTANTS
### ============================================================================
ALL_QTYPES: List[str] = list(dnslib.QTYPE.reverse.keys())
"""All supported Query Types
New in `1.1.0`.
"""

_wildcard_string_regex = re.compile(r"[*]|\{base_domain\}")


### FUNCTIONS
### ============================================================================
def smart_make_rule(rule: "Union[Type[RuleBase], str, Pattern]", *args, **kwargs) -> "RuleBase":
"""Create a rule using shorthand notation.
The exact type of rule returned depends on what is povided by `rule`.
If rule is a
- `RuleBase` class, then it is used directly.
- `str` then it is checked to see if it contains substitutions. If it does then
it will be a `WildcardStringRule`, else a `StaticRule`.
- `Pattern` then a `RegexRule`.
New in `1.1.0`
Args:
rule: input to process
args: extra arguments to provide to the constructor
kwargs: extra keyword arguments to provide to the constructor
"""
if isinstance(rule, str):
if _wildcard_string_regex.search(rule):
return WildcardStringRule(rule, *args, **kwargs)
return StaticRule(rule, *args, **kwargs)

# pylint: disable=isinstance-second-argument-not-valid-type
if isinstance(rule, Pattern):
# Note: I've disabled this type check thing as it currently works and it might
# vary between versions of python and other bugs.
# see also: https://stackoverflow.com/questions/6102019/type-of-compiled-regex-object-in-python
return RegexRule(rule, *args, **kwargs)
raise ValueError(f"Could not handle rule: {rule!r}")


### CLASSES
### ============================================================================
RuleResult = Union[Response, RecordBase, List[RecordBase], None]
Expand Down Expand Up @@ -41,6 +88,106 @@ def get_func(self, query: Query) -> Optional[ResponseFunction]:
raise NotImplementedError()


class StaticRule(RuleBase):
"""Rule that matches only the given string
`StaticRule` is more efficient than using a `WildcardStringRule` for static strings.
New in `1.1.0`.
"""

def __init__(
self,
match_string: str,
allowed_qtypes: List[str],
func: ResponseFunction,
case_sensitive: bool = False,
) -> None:
"""
Args:
match_string: string to match
allowed_qtypes: match only the given query types
func: response function to call
case_sensitive: how to case when matching
"""
self.match_string = match_string if case_sensitive else match_string.lower()
self.allowed_qtypes = set(allowed_qtypes)
self.func = func
self.case_sensitive = case_sensitive
return

def get_func(self, query: Query) -> Optional[ResponseFunction]:
"""Same as parent class"""
if query.type not in self.allowed_qtypes:
return None

check_string = query.name
if not self.case_sensitive:
check_string = check_string.lower()

if check_string == self.match_string:
return self.func
return None

def __repr__(self):
return f"{self.__class__.__name__}(match_string={self.match_string!r}, allowed_qtypes={self.allowed_qtypes!r}, func={self.func!r}, case_sensitive={self.case_sensitive!r})"

def __str__(self):
return f"{self.__class__.__name__}({self.match_string!r}, {self.allowed_qtypes!r})"


class ZoneRule(RuleBase):
"""Rule that matches the given domain or any subdomain
An empty zone (`""`) will match any domain as this refers to the domain root (`.`).
New in `1.1.0`.
"""

def __init__(
self,
zone: str,
allowed_qtypes: List[str],
func: ResponseFunction,
case_sensitive: bool = False,
) -> None:
"""
Args:
zone: zone root
allowed_qtypes: match only the given query types.
func: response function to call
case_sensitive: how to case when matching
"""
zone = zone.strip(".")
self.zone = zone if case_sensitive else zone.lower()
self.allowed_qtypes = set(allowed_qtypes) if allowed_qtypes else None
self.func = func
self.case_sensitive = case_sensitive
return

def get_func(self, query: Query) -> Optional[ResponseFunction]:
"""Same as parent class"""
if self.allowed_qtypes is not None and query.type not in self.allowed_qtypes:
return None

if self.zone == "":
return self.func

check_string = query.name
if not self.case_sensitive:
check_string = check_string.lower()

if check_string == self.zone or check_string.endswith(f".{self.zone}"):
return self.func
return None

def __repr__(self):
return f"{self.__class__.__name__}(zone={self.zone!r}, allowed_qtypes={self.allowed_qtypes!r}, func={self.func!r}, case_sensitive={self.case_sensitive!r})"

def __str__(self):
return f"{self.__class__.__name__}({self.zone!r}, {self.allowed_qtypes!r})"


class RegexRule(RuleBase):
"""Rule that uses the provided regex to attempt to match the query name."""

Expand Down Expand Up @@ -82,7 +229,7 @@ def get_func(self, query: Query) -> Optional[ResponseFunction]:
return None

def __repr__(self):
return f"{self.__class__.__name__}(regex={self.regex.pattern!r}, allowed_qtypes={self.allowed_qtypes!r}, func={self.func!r})"
return f"{self.__class__.__name__}(regex={self.regex.pattern!r}, allowed_qtypes={self.allowed_qtypes!r}, func={self.func!r}, case_sensitive={self.case_sensitive!r})"

def __str__(self):
return f"{self.__class__.__name__}({self.regex.pattern!r}, {self.allowed_qtypes!r})"
Expand Down Expand Up @@ -175,7 +322,7 @@ def _get_regex(self, query_domain: str) -> Pattern:
return regex

def __repr__(self):
return f"{self.__class__.__name__}(wildcard_string={self.wildcard_string!r}, allowed_qtypes={self.allowed_qtypes!r}, func={self.func!r})"
return f"{self.__class__.__name__}(wildcard_string={self.wildcard_string!r}, allowed_qtypes={self.allowed_qtypes!r}, func={self.func!r}, case_sensitive={self.case_sensitive!r})"

def __str__(self):
return f"{self.__class__.__name__}({self.wildcard_string!r}, {self.allowed_qtypes!r})"
47 changes: 20 additions & 27 deletions src/nserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

## Application
from .exceptions import InvalidMessageError
from .rules import RuleBase, WildcardStringRule, RegexRule, ResponseFunction
from .rules import smart_make_rule, RuleBase, ResponseFunction
from .settings import Settings
from .transport import TransportBase, UDPv4Transport, UDPv6Transport, TCPv4Transport

Expand Down Expand Up @@ -263,38 +263,31 @@ def run(self) -> int:

## Decorators
## -------------------------------------------------------------------------
def rule(
self, rule_: Union[str, Pattern], allowed_qtypes: List[str], case_sensitive: bool = False
): # pylint: disable=unused-argument
"""Decorator for registering a function as a rule.
def rule(self, rule_: Union[Type[RuleBase], str, Pattern], *args, **kwargs):
"""Decorator for registering a function using an appropriate rule.
Changed in `1.1.0`: This function now uses [`smart_make_rule`][nserver.rules.smart_make_rule].
At the time of writing this allows for Rule classes be used directly,
and `str` inputs may result in a `StaticRule`. Future changes to
`smart_make_rule` will not be documented here.
Args:
rule_: if `Pattern` then `RegexRule`, if `str` then `WildcardStringRule`.
allowed_qtypes: Only match the given DNS query types
case_sensitive: how to handle case when matching the rule
rule_: rule as per `nserver.rules.smart_make_rule`
args: extra arguments to provide
kwargs: extra keyword arguments to provide
Raises:
ValueError: if `func` is provided in `kwargs`.
"""

if "func" in kwargs:
raise ValueError("Must not provide `func` in kwargs")

def decorator(func: ResponseFunction):
nonlocal rule_
nonlocal allowed_qtypes
nonlocal case_sensitive
actual_rule: RuleBase

if isinstance(rule_, str):
actual_rule = WildcardStringRule(
rule_, allowed_qtypes, func, case_sensitive=case_sensitive
)
elif isinstance( # pylint: disable=isinstance-second-argument-not-valid-type
rule_, Pattern
):
# Note: I've disabled this type check thing as it currently works and it might
# vary between versions of python and other bugs.
# see also: https://stackoverflow.com/questions/6102019/type-of-compiled-regex-object-in-python
actual_rule = RegexRule(rule_, allowed_qtypes, func, case_sensitive=case_sensitive)
else:
raise ValueError(f"Could not handle rule: {rule_!r}")

self.register_rule(actual_rule)
nonlocal args
nonlocal kwargs
self.register_rule(smart_make_rule(rule_, *args, func=func, **kwargs))
return func

return decorator
Expand Down
Loading

0 comments on commit e8fb9ec

Please sign in to comment.