Skip to content

Commit

Permalink
working RunTests script
Browse files Browse the repository at this point in the history
  • Loading branch information
amyasnikov committed Jan 8, 2024
1 parent 4f955f5 commit 42131cb
Show file tree
Hide file tree
Showing 14 changed files with 259 additions and 61 deletions.
21 changes: 19 additions & 2 deletions validity/compliance/eval/default_nameset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import jq as pyjq

from validity.utils.config import config # noqa
from validity.models import VDevice


builtins = [
Expand All @@ -14,6 +14,7 @@
"bool",
"bytes",
"callable",
"classmethod",
"chr",
"complex",
"dict",
Expand All @@ -36,20 +37,22 @@
"oct",
"ord",
"pow",
"property",
"range",
"reversed",
"round",
"set",
"slice",
"sorted",
"staticmethod",
"str",
"sum",
"tuple",
"zip",
]


__all__ = ["jq", "config"] + builtins
__all__ = ["jq", "config", "state"] + builtins


class jq:
Expand All @@ -58,3 +61,17 @@ class jq:

def __init__(self, *args, **kwargs) -> None:
raise TypeError("jq is not callable")


def state(device):
    """Return the state of *device*, re-bound to the script-level data source/poller.

    state() relies on the "_data_source" and "_poller" module-level globals,
    which are injected at runtime by the RunTests script (hence the noqa marks);
    this function is not usable outside that evaluation context.
    """
    vdevice = VDevice()
    # Shallow-copy the device's attribute dict onto a fresh VDevice so the
    # data source / poller can be swapped without mutating the original object.
    vdevice.__dict__ = device.__dict__.copy()
    vdevice.data_source = _data_source  # noqa
    # NOTE(review): assigns private "_poller" while "data_source" is public;
    # presumably VDevice exposes a "poller" property backed by "_poller" — confirm.
    vdevice._poller = _poller  # noqa
    return vdevice.state


def config(device):
    """Shortcut for the configuration portion of the device state (see state())."""
    device_state = state(device)
    return device_state.config
9 changes: 6 additions & 3 deletions validity/compliance/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@ class NoComponentError(SerializationError):
Indicates lack of the required component (e.g. serializer) to do serialization
"""

def __init__(self, missing_component: str, orig_error: Exception | None = None) -> None:
def __init__(self, missing_component: str, parent: str | None = None) -> None:
self.missing_component = missing_component
super().__init__(orig_error)
self.parent = parent

def __str__(self) -> str:
return f"There is no bound {self.missing_component}"
result = f"There is no bound {self.missing_component}"
if self.parent:
result += f' for "{self.parent}"'
return result


class BadDataFileContentsError(SerializationError):
Expand Down
10 changes: 9 additions & 1 deletion validity/compliance/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from validity.compliance.serialization import Serializable
from ..utils.misc import reraise
from .exceptions import SerializationError, StateKeyError
from .exceptions import NoComponentError, SerializationError, StateKeyError


if TYPE_CHECKING:
Expand Down Expand Up @@ -41,6 +41,14 @@ def error(self) -> SerializationError | None:
except SerializationError as exc:
return exc

@property
def serialized(self):
    """Delegate serialization to the parent implementation.

    If serialization fails for lack of a bound component, tag the raised
    NoComponentError with this item's name so the error message can say
    which state item was affected, then re-raise the same exception.
    """
    try:
        return super().serialized
    except NoComponentError as exc:
        exc.parent = self.name
        raise


class State(dict):
def __init__(self, items, config_command_label: str | None = None):
Expand Down
9 changes: 8 additions & 1 deletion validity/forms/helpers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import json
from typing import Sequence
from typing import Any, Sequence

from django.forms import ChoiceField, Select
from utilities.forms import get_field_value


class IntegerChoiceField(ChoiceField):
    """A ChoiceField whose cleaned values are ints rather than strings.

    NOTE(review): assumes submitted values are int-parsable; int() raises
    ValueError on e.g. "" — presumably prevented by choice validation upstream.
    """

    def to_python(self, value: Any | None) -> Any | None:
        """Coerce *value* to int, passing None through untouched."""
        return None if value is None else int(value)


class SelectWithPlaceholder(Select):
def __init__(self, attrs=None, choices=()) -> None:
super().__init__(attrs, choices)
Expand Down
3 changes: 3 additions & 0 deletions validity/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ class VDeviceQS(CustomPrefetchMixin, SetAttributesMixin, RestrictedQuerySet):
def set_selector(self, selector):
    """Attach *selector* to the queryset (retrievable later as an attribute)."""
    updated_qs = self.set_attribute("selector", selector)
    return updated_qs

def set_datasource(self, data_source):
    """Attach *data_source* to the queryset (retrievable later as an attribute)."""
    updated_qs = self.set_attribute("data_source", data_source)
    return updated_qs

def annotate_datasource_id(self):
from validity.models import VDataSource

Expand Down
5 changes: 5 additions & 0 deletions validity/models/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,8 @@ def new_data_file(path):
new_datafiles = (new_data_file(path) for path in paths)
created = len(DataFile.objects.bulk_create(new_datafiles, batch_size=batch_size))
logger.debug("%s new files were created and %s existing files were updated during sync", created, updated)

def sync(self, device_filter: Q | None = None):
    """Sync this data source, narrowing to *device_filter* when possible.

    A partial sync is performed only for "device_polling" sources when a
    filter is supplied; otherwise a full sync is delegated to the parent.
    """
    is_polling_source = self.type == "device_polling"
    if is_polling_source and device_filter is not None:
        return self.partial_sync(device_filter)
    return super().sync()
6 changes: 5 additions & 1 deletion validity/models/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,8 @@ def dynamic_pair(self) -> Optional["VDevice"]:
filter_ = self.selector.dynamic_pair_filter(self)
if filter_ is None:
return
return type(self).objects.filter(filter_).first()
pair = type(self).objects.filter(filter_).first()
if pair:
pair.data_source = self.data_source
pair.poller = self.poller
return pair
3 changes: 1 addition & 2 deletions validity/models/nameset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import ast
import builtins
from functools import cached_property
from inspect import getmembers
from typing import Any, Callable

Expand Down Expand Up @@ -56,7 +55,7 @@ def clean(self):
def effective_definitions(self):
return self.effective_text_field()

@cached_property
@property
def _globals(self):
    """Evaluation namespace: every Python builtin plus the default nameset's exports."""
    default_names = {name: getattr(default_nameset, name) for name in default_nameset.__all__}
    return dict(getmembers(builtins)) | default_names

Expand Down
14 changes: 14 additions & 0 deletions validity/models/test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import ast
from functools import partial
from typing import Any, Callable

from django.core.exceptions import ValidationError
from django.db import models
from django.utils.translation import gettext_lazy as _

from validity.choices import SeverityChoices
from validity.compliance.eval import ExplanationalEval
from validity.managers import ComplianceTestQS
from .base import BaseModel, DataSourceMixin

Expand All @@ -20,6 +23,7 @@ class ComplianceTest(DataSourceMixin, BaseModel):

clone_fields = ("expression", "selectors", "severity", "data_source", "data_file")
text_db_field_name = "expression"
evaluator_cls = partial(ExplanationalEval, load_defaults=True)

objects = ComplianceTestQS.as_manager()

Expand All @@ -46,3 +50,13 @@ def get_severity_color(self):
@property
def effective_expression(self):
return self.effective_text_field()

def run(
self, device, functions: dict[str, Callable], extra_names: dict[str, Any] | None = None, verbosity: int = 2
) -> tuple[bool, list]:
names = {"device": device, "_poller": device.poller, "_data_source": device.data_source}
if extra_names:
names |= extra_names
evaluator = self.evaluator_cls(names=names, functions=functions, verbosity=verbosity)
passed = bool(evaluator.eval(self.effective_expression))
return passed, evaluator.explanation
98 changes: 56 additions & 42 deletions validity/scripts/run_tests.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
import operator
import time
from functools import reduce
from itertools import chain
from typing import Any, Callable, Generator, Iterable

import yaml
from core.models import DataSource
from dcim.models import Device
from django.db.models import Prefetch, QuerySet
from django.db.models import Prefetch, Q, QuerySet
from django.utils.translation import gettext as __
from extras.choices import ObjectChangeActionChoices
from extras.models import Tag
from extras.scripts import BooleanVar, ChoiceVar, MultiObjectVar
from extras.scripts import BooleanVar, MultiObjectVar, ObjectVar
from extras.webhooks import enqueue_object
from netbox.context import webhooks_queue

import validity
from validity.choices import ExplanationVerbosityChoices
from validity.compliance.eval import ExplanationalEval
from validity.compliance.exceptions import EvalError, SerializationError
from validity.models import (
ComplianceReport,
Expand All @@ -26,23 +28,19 @@
VDevice,
)
from validity.utils.misc import datasource_sync, null_request
from .script_data import RunTestsScriptData, ScriptDataMixin
from .variables import VerbosityVar


class RequiredChoiceVar(ChoiceVar):
def __init__(self, choices, *args, **kwargs):
super().__init__(choices, *args, **kwargs)
self.field_attrs["choices"] = choices


class RunTestsScript:
class RunTestsScript(ScriptDataMixin[RunTestsScriptData]):
_sleep_between_tests = validity.settings.sleep_between_tests
_result_batch_size = validity.settings.result_batch_size

sync_datasources = BooleanVar(
required=False,
default=False,
label=__("Sync Data Sources"),
description=__('Sync all Data Source instances which have "device_config_path" defined'),
description=__("Sync all referenced Data Sources"),
)
make_report = BooleanVar(default=True, label=__("Make Compliance Report"))
selectors = MultiObjectVar(
Expand All @@ -63,12 +61,18 @@ class RunTestsScript:
label=__("Specific Test Tags"),
description=__("Run the tests which contain specific tags only"),
)
explanation_verbosity = RequiredChoiceVar(
explanation_verbosity = VerbosityVar(
choices=ExplanationVerbosityChoices.choices,
default=ExplanationVerbosityChoices.maximum,
label=__("Explanation Verbosity Level"),
required=False,
)
override_datasource = ObjectVar(
model=DataSource,
required=False,
label=__("Override DataSource"),
description=__("Find all devices state/config data in this Data Source instead of bound ones"),
)

class Meta:
name = __("Run Compliance Tests")
Expand All @@ -78,7 +82,6 @@ def __init__(self):
super().__init__()
self._nameset_functions = {}
self.global_namesets = NameSet.objects.filter(_global=True)
self.verbosity = 2
self.results_count = 0
self.results_passed = 0

Expand All @@ -97,11 +100,7 @@ def nameset_functions(self, namesets: Iterable[NameSet]) -> dict[str, Callable]:

def run_test(self, device: VDevice, test: ComplianceTest) -> tuple[bool, list[tuple[Any, Any]]]:
functions = self.nameset_functions(test.namesets.all())
evaluator = ExplanationalEval(
functions=functions, names={"device": device}, load_defaults=True, verbosity=self.verbosity
)
passed = bool(evaluator.eval(test.effective_expression))
return passed, evaluator.explanation
return test.run(device, functions, verbosity=self.script_data.explanation_verbosity)

def run_tests_for_device(
self,
Expand All @@ -112,6 +111,7 @@ def run_tests_for_device(
for test in tests_qs:
explanation = []
try:
device.state
passed, explanation = self.run_test(device, test)
except EvalError as exc:
self.log_failure(f"Failed to execute test **{test}** for device **{device}**, `{exc}`")
Expand All @@ -129,16 +129,20 @@ def run_tests_for_device(
)
time.sleep(self._sleep_between_tests)

def get_device_qs(self, selector: ComplianceSelector) -> QuerySet[VDevice]:
    """Build the device queryset for *selector* according to the script settings.

    The override data source, when chosen, replaces each device's bound
    data source; otherwise bound data sources are prefetched. An explicit
    device list further narrows the queryset.
    """
    devices = selector.devices.select_related().prefetch_serializer().prefetch_poller()
    override = self.script_data.override_datasource
    if override:
        devices = devices.set_datasource(override.obj)
    else:
        devices = devices.prefetch_datasource()
    if self.script_data.devices:
        devices = devices.filter(pk__in=self.script_data.devices)
    return devices

def run_tests_for_selector(
self,
selector: ComplianceSelector,
report: ComplianceReport | None,
device_ids: list[int],
self, selector: ComplianceSelector, report: ComplianceReport | None
) -> Generator[ComplianceTestResult, None, None]:
qs = selector.devices.select_related().prefetch_datasource().prefetch_serializer().prefetch_poller()
if device_ids:
qs = qs.filter(pk__in=device_ids)
for device in qs:
for device in self.get_device_qs(selector):
try:
yield from self.run_tests_for_device(selector.tests.all(), device, report)
except SerializationError as e:
Expand All @@ -156,27 +160,37 @@ def save_to_db(self, results: Iterable[ComplianceTestResult], report: Compliance
if report:
ComplianceReport.objects.delete_old()

def get_selectors(self, data: dict) -> QuerySet[ComplianceSelector]:
selectors = ComplianceSelector.objects.all()
if specific_selectors := data.get("selectors"):
selectors = selectors.filter(pk__in=specific_selectors)
def get_selectors(self) -> QuerySet[ComplianceSelector]:
selectors = self.script_data.selectors.queryset
test_qs = ComplianceTest.objects.all()
if test_tags := data.get("test_tags"):
test_qs = test_qs.filter(tags__pk__in=test_tags).distinct()
selectors = selectors.filter(tests__tags__pk__in=test_tags).distinct()
if self.script_data.test_tags:
test_qs = test_qs.filter(tags__pk__in=self.script_data.test_tags).distinct()
selectors = selectors.filter(tests__tags__pk__in=self.script_data.test_tags).distinct()
return selectors.prefetch_related(Prefetch("tests", test_qs.prefetch_related("namesets")))

def perform_datasource_sync(self) -> None:
    """Sync the data sources needed for the devices the script will test.

    With an override data source set, only that source is synced (restricted
    to the matching devices); otherwise every data source bound to at least
    one matching device is synced.

    Fix: the original `reduce(operator.or_, ...)` raised TypeError when the
    selectors queryset was empty (reduce over an empty sequence with no
    initial value); an empty selector set now falls back to a match-nothing
    filter instead of crashing.
    """
    selector_filters = [selector.filter for selector in self.script_data.selectors.queryset]
    # Match-nothing fallback keeps behavior identical when selectors exist
    # and avoids the empty-sequence TypeError when they do not.
    device_filter = reduce(operator.or_, selector_filters) if selector_filters else Q(pk__in=[])
    if self.script_data.devices:
        device_filter |= Q(pk__in=self.script_data.devices)
    if self.script_data.override_datasource:
        self.script_data.override_datasource.obj.sync(device_filter)
        return
    datasource_ids = (
        VDevice.objects.filter(device_filter)
        .annotate_datasource_id()
        .values_list("data_source_id", flat=True)
        .distinct()
    )
    datasource_sync(VDataSource.objects.filter(pk__in=datasource_ids))

def run(self, data, commit):
self.verbosity = int(data.get("explanation_verbosity", self.verbosity))
if data.get("sync_datasources"):
datasource_sync(VDataSource.objects.exclude(custom_field_data__device_config_path=None))
self.script_data = self.script_data_cls(data)
selectors = self.get_selectors()
if self.script_data.sync_datasources:
self.perform_datasource_sync()
with null_request():
report = ComplianceReport.objects.create() if data.get("make_report") else None
selectors = self.get_selectors(data)
device_ids = data.get("devices", [])
results = chain.from_iterable(
self.run_tests_for_selector(selector, report, device_ids) for selector in selectors
)
report = ComplianceReport.objects.create() if self.script_data.make_report else None
results = chain.from_iterable(self.run_tests_for_selector(selector, report) for selector in selectors)
self.save_to_db(results, report)
output = {"results": {"all": self.results_count, "passed": self.results_passed}}
if report:
Expand Down
Loading

0 comments on commit 42131cb

Please sign in to comment.