diff --git a/tests/test_models.py b/tests/test_models.py index bd20f3c..6f3bce1 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -2,9 +2,7 @@ import pytest import requests_mock -from ape_pie import APIClient -from zgw_consumers.client import ServiceConfigAdapter from zgw_consumers.constants import APITypes, AuthTypes from zgw_consumers.models import Service from zgw_consumers.test.factories import ServiceFactory @@ -83,42 +81,40 @@ def test_model_validation_with_oas_fields_disabled_both_provided(settings): @pytest.mark.django_db -def test_health_check_indication_service_model_badly_configured(settings): +def test_connection_check_service_model_badly_configured(settings): + settings.ZGW_CONSUMERS_IGNORE_OAS_FIELDS = True service = ServiceFactory.create( api_root="https://example.com/", - api_health_check_endpoint="foo", + api_connection_check_path="foo", auth_type=AuthTypes.zgw, client_id="my-client-id", secret="my-secret", ) - adapter = ServiceConfigAdapter(service) - client = APIClient.configure_from(adapter) - with requests_mock.Mocker() as m, client: + with requests_mock.Mocker() as m: m.get( "https://example.com/foo", status_code=404, ) service.refresh_from_db() - assert service.get_health_check_indication == False + assert service.connection_check == 404 @pytest.mark.django_db -def test_health_check_indication_service_model_correctly_configured(settings): +def test_connection_check_service_model_correctly_configured(settings): + settings.ZGW_CONSUMERS_IGNORE_OAS_FIELDS = True service = ServiceFactory.create( api_root="https://example.com/", - api_health_check_endpoint="foo", + api_connection_check_path="foo", auth_type=AuthTypes.zgw, client_id="my-client-id", secret="my-secret", ) - adapter = ServiceConfigAdapter(service) - client = APIClient.configure_from(adapter) - with requests_mock.Mocker() as m, client: + with requests_mock.Mocker() as m: m.get( "https://example.com/foo", status_code=200, ) service.refresh_from_db() - assert service.get_health_check_indication == True + assert service.connection_check == 200 diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 0000000..ec07110 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,38 @@ +from django.core.exceptions import ValidationError + +import pytest + +from zgw_consumers.models.validators import IsNotUrlValidator, StartWithValidator + + +def test_start_with_validator_return_value_false(): + validator = StartWithValidator(prefix="/", return_value=False) + + assert validator.__call__("no_leading_slash") is None + + with pytest.raises(ValidationError) as exc_context: + validator.__call__("/with_leading_slash") + + assert "The given value cannot start with '/'" in exc_context.value + + +def test_start_with_validator_return_value_true(): + validator = StartWithValidator(prefix="/") + + with pytest.raises(ValidationError) as exc_context: + validator.__call__("no_leading_slash") + + assert "The given value must start with '/'" in exc_context.value + + assert validator.__call__("/with_leading_slash") is None + + +def test_is_not_url_validator(): + validator = IsNotUrlValidator() + + assert validator.__call__("some random text") is None + + with pytest.raises(ValidationError) as exc_context: + assert validator.__call__("http://www.example.com") + + assert "String cannot be a URL" in exc_context.value diff --git a/zgw_consumers/admin.py b/zgw_consumers/admin.py index 93e3827..8c0f194 100644 --- a/zgw_consumers/admin.py +++ b/zgw_consumers/admin.py @@ -13,22 +13,14 @@ @admin.register(Service) class ServiceAdmin(admin.ModelAdmin): - list_display = ( - "label", - "api_type", - "api_root", - "nlx", - "auth_type", - ) + list_display = ("label", "api_type", "api_root", "nlx", "auth_type") list_filter = ("api_type", "auth_type") search_fields = ("label", "api_root", "nlx", "uuid") - readonly_fields = [ - "get_health_check_indication", - ] + readonly_fields = ("get_connection_check",) - @admin.display(description="Health Check", boolean=True) - def get_health_check_indication(self, obj): - return obj.get_health_check_indication + @admin.display(description="Connection Check") + def get_connection_check(self, obj): + return obj.connection_check def get_fields(self, request: HttpRequest, obj: models.Model | None = None): fields = super().get_fields(request, obj=obj) diff --git a/zgw_consumers/migrations/0021_service_api_connection_check_path.py b/zgw_consumers/migrations/0021_service_api_connection_check_path.py new file mode 100644 index 0000000..104acbb --- /dev/null +++ b/zgw_consumers/migrations/0021_service_api_connection_check_path.py @@ -0,0 +1,31 @@ +# Generated by Django 4.2 on 2024-05-21 08:15 + +from django.db import migrations, models + +import zgw_consumers.models.validators + + +class Migration(migrations.Migration): + + dependencies = [ + ("zgw_consumers", "0020_service_timeout"), + ] + + operations = [ + migrations.AddField( + model_name="service", + name="api_connection_check_path", + field=models.CharField( + blank=True, + help_text="An optional API endpoint which will be used to check if the API is configured correctly and is currently up or down. This field is only used for in the admin's 'Connection check' field.", + max_length=255, + validators=[ + zgw_consumers.models.validators.StartWithValidator( + prefix="/", return_value=False + ), + zgw_consumers.models.validators.IsNotUrlValidator(), + ], + verbose_name="connection check endpoint", + ), + ), + ] diff --git a/zgw_consumers/migrations/0021_service_api_health_check_endpoint.py b/zgw_consumers/migrations/0021_service_api_health_check_endpoint.py deleted file mode 100644 index 3b9126d..0000000 --- a/zgw_consumers/migrations/0021_service_api_health_check_endpoint.py +++ /dev/null @@ -1,23 +0,0 @@ -# Generated by Django 3.2 on 2024-05-10 08:58 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ("zgw_consumers", "0020_service_timeout"), - ] - - operations = [ - migrations.AddField( - model_name="service", - name="api_health_check_endpoint", - field=models.CharField( - help_text="An optional API endpoint which will be used to check if the API is configured correctly and is currently up or down. This field is only used for in the admin's 'health check' field.", - max_length=255, - blank=True, - verbose_name="health check endpoint", - ), - ), - ] diff --git a/zgw_consumers/models/services.py b/zgw_consumers/models/services.py index 89ef162..98b3fba 100644 --- a/zgw_consumers/models/services.py +++ b/zgw_consumers/models/services.py @@ -12,7 +12,7 @@ from django.utils.translation import gettext_lazy as _ from privates.fields import PrivateMediaFileField -from requests.exceptions import ConnectionError, RequestException +from requests.exceptions import RequestException from simple_certmanager.models import Certificate from solo.models import SingletonModel from typing_extensions import Self, deprecated @@ -21,6 +21,7 @@ from ..constants import APITypes, AuthTypes, NLXDirectories from .abstract import RestAPIService +from .validators import IsNotUrlValidator, StartWithValidator logger = logging.getLogger(__name__) @@ -32,13 +33,17 @@ class Service(RestAPIService): uuid = models.UUIDField(_("UUID"), default=uuid.uuid4) api_type = models.CharField(_("type"), max_length=20, choices=APITypes.choices) api_root = models.CharField(_("api root url"), max_length=255, unique=True) - api_health_check_endpoint = models.CharField( - _("health check endpoint"), + api_connection_check_path = models.CharField( + _("connection check endpoint"), help_text=_( "An optional API endpoint which will be used to check if the API is configured correctly and " - "is currently up or down. This field is only used for in the admin's 'health check' field." + "is currently up or down. This field is only used for in the admin's 'Connection check' field." ), max_length=255, + validators=[ + StartWithValidator(prefix="/", return_value=False), + IsNotUrlValidator(), + ], blank=True, ) @@ -127,20 +132,22 @@ def clean(self): ) @property - def get_health_check_indication(self) -> bool: + def connection_check(self) -> bool: from zgw_consumers.client import build_client try: client = build_client(self) - if ( - client.get(self.api_health_check_endpoint or self.api_root).status_code - == 200 - ): - return True - except (ConnectionError, RequestException) as e: - logger.exception(self, exc_info=e) - - return False + return client.get( + self.api_connection_check_path or self.api_root + ).status_code + except RequestException as e: + logger.info( + "Encountered an error while performing the connection check to service %s", + self, + exc_info=e, + ) + + return None @deprecated( "The `build_client` method is deprecated and will be removed in the next major release. " diff --git a/zgw_consumers/models/validators.py b/zgw_consumers/models/validators.py new file mode 100644 index 0000000..02685b1 --- /dev/null +++ b/zgw_consumers/models/validators.py @@ -0,0 +1,58 @@ +from django.core.exceptions import ValidationError +from django.core.validators import URLValidator +from django.utils.deconstruct import deconstructible +from django.utils.translation import gettext_lazy as _ + + +@deconstructible +class StartWithValidator: + code = "invalid" + + def __init__( + self, + prefix: str, + message: str = None, + code: str = None, + return_value: bool = True, + ): + self.prefix = prefix + self.return_value = return_value + + if code is not None: + self.code = code + + if message is not None: + self.message = message + else: + self.message = _( + "The given value {must_or_cannot} start with '{prefix}'" + ).format( + must_or_cannot="must" if self.return_value else "cannot", + prefix=self.prefix, + ) + + def __call__(self, value: str) -> bool: + if not value.startswith(self.prefix) == self.return_value: + raise ValidationError(self.message, code=self.code, params={"value": value}) + + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self.prefix == other.prefix + and (self.message == other.message) + and (self.code == other.code) + and (self.return_value == other.return_value) + ) + + +@deconstructible +class IsNotUrlValidator(URLValidator): + message = _("String cannot be a URL") + + def __call__(self, value): + try: + super().__call__(value) + except ValidationError: + return + + raise ValidationError(self.message, code=self.code, params={"value": value})