From a1335059cfd9249e2369480fc1fa52c53205f34c Mon Sep 17 00:00:00 2001 From: cicharka <93913624+cicharka@users.noreply.github.com> Date: Thu, 15 Feb 2024 17:01:22 +0100 Subject: [PATCH] [Fix] using Literal, allow Literal as params, fix aliases (#471) --- .../api/config_device_inventory_api.py | 8 +- catalystwan/endpoints/__init__.py | 14 +- .../configuration_device_inventory.py | 249 ++++++++++++------ catalystwan/tests/test_endpoints.py | 23 ++ 4 files changed, 203 insertions(+), 91 deletions(-) diff --git a/catalystwan/api/config_device_inventory_api.py b/catalystwan/api/config_device_inventory_api.py index 2f5418a82..cf6d6fd5d 100644 --- a/catalystwan/api/config_device_inventory_api.py +++ b/catalystwan/api/config_device_inventory_api.py @@ -29,11 +29,11 @@ def unlock(self, device_uuid: str, device_type: str, device_details: list) -> Ta devices = [] for device_detail in device_details: unlock_device_detail = UnlockDeviceDetail( - deviceId=device_detail["deviceId"], deviceIP=device_detail["deviceIP"] + device_id=device_detail["deviceId"], device_ip=device_detail["deviceIP"] ) devices.append(unlock_device_detail) - payload = DeviceUnlockPayload(deviceType=device_type, devices=devices) + payload = DeviceUnlockPayload(device_type=device_type, devices=devices) task_id = self.endpoint.unlock(device_uuid=device_uuid, payload=payload).parentTaskId return Task(self.session, task_id=task_id) @@ -41,7 +41,7 @@ def unlock(self, device_uuid: str, device_type: str, device_details: list) -> Ta def generate_bootstrap_cfg( self, device_uuid: UUID, - configtype: ConfigType = ConfigType.CLOUDINIT, + configtype: ConfigType = "cloudinit", incl_def_root_cert: bool = False, version: str = "v1", ) -> BoostrapConfigurationDetails: @@ -49,7 +49,7 @@ def generate_bootstrap_cfg( Returns handy model of generated bootstrap config """ params = GenerateBoostrapConfigurationQueryParams( - configtype=configtype, inclDefRootCert=incl_def_root_cert, version=version + configtype=configtype, incl_def_root_cert=incl_def_root_cert, version=version ) reponse = self.endpoint.generate_bootstrap_configuration(uuid=device_uuid, params=params) diff --git a/catalystwan/endpoints/__init__.py b/catalystwan/endpoints/__init__.py index 7d9e5e214..fa9114364 100644 --- a/catalystwan/endpoints/__init__.py +++ b/catalystwan/endpoints/__init__.py @@ -43,6 +43,7 @@ Final, Iterable, List, + Literal, Mapping, Optional, Protocol, @@ -538,8 +539,19 @@ def check_params(self): raise APIEndpointError(f"Missing parameters: {missing} to format url: {self.url}") for parameter in [parameters.get(name) for name in self.url_field_names]: + # Check if 'params' is type of str, UUID or LIteral if not (isclass(parameter.annotation) and issubclass(parameter.annotation, (str, UUID))): - raise APIEndpointError(f"Parameter {parameter} used for url formatting must be 'str' sub-type or UUID") + if not get_origin(parameter.annotation) == Literal: + raise APIEndpointError( + f"Parameter {parameter} used for url formatting must be 'str', UUID or Literal sub-type" + ) + + elif p_args := get_args(parameter.annotation): + # Check if all 'params' Literal values are str + if not all((isinstance(arg, str) for arg in p_args)): + raise APIEndpointError( + f"Literal values for parameter {parameter} used for url formatting must be 'str'" + ) no_purpose_params = { parameters.get(name) for name in general_purpose_arg_names.difference(self.url_field_names) diff --git a/catalystwan/endpoints/configuration_device_inventory.py b/catalystwan/endpoints/configuration_device_inventory.py index 8d2efda64..925450368 100644 --- a/catalystwan/endpoints/configuration_device_inventory.py +++ b/catalystwan/endpoints/configuration_device_inventory.py @@ -1,7 +1,6 @@ # mypy: disable-error-code="empty-body" -from enum import Enum from pathlib import Path -from typing import List, Optional, Union +from typing import List, Literal, Optional, Union from uuid import UUID from pydantic import BaseModel, ConfigDict, Field @@ -12,12 +11,12 @@ class UnlockDeviceDetail(BaseModel): - device_id: str = Field(alias="deviceId") - device_ip: str = Field(alias="deviceIP") + device_id: str = Field(validation_alias="deviceId", serialization_alias="deviceId") + device_ip: str = Field(validation_alias="deviceIP", serialization_alias="deviceIP") class DeviceUnlockPayload(BaseModel): - device_type: str = Field(alias="deviceType") + device_type: str = Field(validation_alias="deviceType", serialization_alias="deviceType") devices: List[UnlockDeviceDetail] @@ -25,9 +24,7 @@ class DeviceUnlockResponse(BaseModel): parentTaskId: str -class Protocol(str, Enum): - DTLS = "DTLS" - TLS = "TLS" +Protocol = Literal["DTLS", "TLS"] class DeviceCreationPayload(BaseModel): @@ -38,108 +35,176 @@ class DeviceCreationPayload(BaseModel): password: str personality: Personality port: Optional[str] = Field(default=None) - protocol: Protocol = Protocol.DTLS + protocol: Protocol = Field(default="DTLS") username: str class DeviceDeletionResponse(BaseModel): model_config = ConfigDict(populate_by_name=True) - local_delete_from_db: Optional[bool] = Field(default=None, alias="localDeleteFromDB") + local_delete_from_db: Optional[bool] = Field( + default=None, validation_alias="localDeleteFromDB", serialization_alias="localDeleteFromDB" + ) id: Optional[str] = Field(default=None) status: Optional[str] = Field(default=None) -class DeviceCategory(str, Enum): - CONTROLLERS = "controllers" - VEDGES = "vedges" +DeviceCategory = Literal["controllers", "vedges"] class DeviceDetailsResponse(BaseModel): # Field "model_sku" has conflict with protected namespace "model_" model_config = ConfigDict(populate_by_name=True, protected_namespaces=()) - device_type: Optional[str] = Field(default=None, alias="deviceType") - serial_number: Optional[str] = Field(default=None, alias="serialNumber") + device_type: Optional[str] = Field(default=None, validation_alias="deviceType", serialization_alias="deviceType") + serial_number: Optional[str] = Field( + default=None, validation_alias="serialNumber", serialization_alias="serialNumber" + ) uuid: Optional[str] = None - management_system_ip: Optional[str] = Field(default=None, alias="managementSystemIP") - chasis_number: Optional[str] = Field(default=None, alias="chasisNumber") - config_operation_mode: Optional[str] = Field(default=None, alias="configOperationMode") - device_model: Optional[str] = Field(default=None, alias="deviceModel") - device_state: Optional[str] = Field(default=None, alias="deviceState") + management_system_ip: Optional[str] = Field( + default=None, validation_alias="managementSystemIP", serialization_alias="managementSystemIP" + ) + chasis_number: Optional[str] = Field( + default=None, validation_alias="chasisNumber", serialization_alias="chasisNumber" + ) + config_operation_mode: Optional[str] = Field( + default=None, validation_alias="configOperationMode", serialization_alias="configOperationMode" + ) + device_model: Optional[str] = Field(default=None, validation_alias="deviceModel", serialization_alias="deviceModel") + device_state: Optional[str] = Field(default=None, validation_alias="deviceState", serialization_alias="deviceState") validity: Optional[str] = None - platform_family: Optional[str] = Field(default=None, alias="platformFamily") + platform_family: Optional[str] = Field( + default=None, validation_alias="platformFamily", serialization_alias="platformFamily" + ) username: Optional[str] = None - device_csr: Optional[str] = Field(default=None, alias="deviceCSR") - device_csr_common_name: Optional[str] = Field(default=None, alias="deviceCSRCommonName") - root_cert_hash: Optional[str] = Field(default=None, alias="rootCertHash") - csr: Optional[str] = Field(default=None, alias="CSR") - csr_detail: Optional[str] = Field(default=None, alias="CSRDetail") + device_csr: Optional[str] = Field(default=None, validation_alias="deviceCSR", serialization_alias="deviceCSR") + device_csr_common_name: Optional[str] = Field( + default=None, validation_alias="deviceCSRCommonName", serialization_alias="deviceCSRCommonName" + ) + root_cert_hash: Optional[str] = Field( + default=None, validation_alias="rootCertHash", serialization_alias="rootCertHash" + ) + csr: Optional[str] = Field(default=None, validation_alias="CSR", serialization_alias="CSR") + csr_detail: Optional[str] = Field(default=None, validation_alias="CSRDetail", serialization_alias="CSRDetail") state: Optional[str] = None - global_state: Optional[str] = Field(default=None, alias="globalState") + global_state: Optional[str] = Field(default=None, validation_alias="globalState", serialization_alias="globalState") valid: Optional[str] = None - request_token_id: Optional[str] = Field(default=None, alias="requestTokenID") - expiration_date: Optional[str] = Field(default=None, alias="expirationDate") - expiration_date_long: Optional[int] = Field(default=None, alias="expirationDateLong") - device_ip: Optional[str] = Field(default=None, alias="deviceIP") + request_token_id: Optional[str] = Field( + default=None, validation_alias="requestTokenID", serialization_alias="requestTokenID" + ) + expiration_date: Optional[str] = Field( + default=None, validation_alias="expirationDate", serialization_alias="expirationDate" + ) + expiration_date_long: Optional[int] = Field( + default=None, validation_alias="expirationDateLong", serialization_alias="expirationDateLong" + ) + device_ip: Optional[str] = Field(default=None, validation_alias="deviceIP", serialization_alias="deviceIP") activity: Optional[List[str]] = None - state_vedge_list: Optional[str] = Field(default=None, alias="state_vedgeList") - cert_install_status: Optional[str] = Field(default=None, alias="certInstallStatus") + state_vedge_list: Optional[str] = Field( + default=None, validation_alias="state_vedgeList", serialization_alias="state_vedgeList" + ) + cert_install_status: Optional[str] = Field( + default=None, validation_alias="certInstallStatus", serialization_alias="certInstallStatus" + ) org: Optional[str] = None personality: Optional[str] = None - expiration_status: Optional[str] = Field(default=None, alias="expirationStatus") - life_cycle_required: Optional[bool] = Field(default=None, alias="lifeCycleRequired") - hardware_cert_serial_number: Optional[str] = Field(default=None, alias="hardwareCertSerialNumber") - subject_serial_number: Optional[str] = Field(default=None, alias="subjectSerialNumber") - resource_group: Optional[str] = Field(default=None, alias="resourceGroup") + expiration_status: Optional[str] = Field( + default=None, validation_alias="expirationStatus", serialization_alias="expirationStatus" + ) + life_cycle_required: Optional[bool] = Field( + default=None, validation_alias="lifeCycleRequired", serialization_alias="lifeCycleRequired" + ) + hardware_cert_serial_number: Optional[str] = Field( + default=None, validation_alias="hardwareCertSerialNumber", serialization_alias="hardwareCertSerialNumber" + ) + subject_serial_number: Optional[str] = Field( + default=None, validation_alias="subjectSerialNumber", serialization_alias="subjectSerialNumber" + ) + resource_group: Optional[str] = Field( + default=None, validation_alias="resourceGroup", serialization_alias="resourceGroup" + ) id: Optional[str] = None tags: Optional[List[str]] = None - draft_mode: Optional[str] = Field(default=None, alias="draftMode") + draft_mode: Optional[str] = Field(default=None, validation_alias="draftMode", serialization_alias="draftMode") solution: Optional[str] = None - device_lock: Optional[str] = Field(default=None, alias="device-lock") - managed_by: Optional[str] = Field(default=None, alias="managed-by") - configured_site_id: Optional[str] = Field(default=None, alias="configuredSiteId") - ncs_device_name: Optional[str] = Field(default=None, alias="ncsDeviceName") - config_status_message: Optional[str] = Field(default=None, alias="configStatusMessage") - template_apply_log: Optional[List[str]] = Field(default=None, alias="templateApplyLog") - template_status: Optional[str] = Field(default=None, alias="templateStatus") - config_status_message_details: Optional[str] = Field(default=None, alias="configStatusMessageDetails") - device_enterprise_certificate: Optional[str] = Field(default=None, alias="deviceEnterpriseCertificate") - service_personality: Optional[str] = Field(default=None, alias="servicePersonality") - upload_source: Optional[str] = Field(default=None, alias="uploadSource") - time_remaining_for_expiration: Optional[int] = Field(default=None, alias="timeRemainingForExpiration") - domain_id: Optional[str] = Field(default=None, alias="domain-id") - local_system_ip: Optional[str] = Field(default=None, alias="local-system-ip") - system_ip: Optional[str] = Field(default=None, alias="system-ip") - model_sku: Optional[str] = Field(default=None) - site_id: Optional[str] = Field(default=None, alias="site-id") - host_name: Optional[str] = Field(default=None, alias="host-name") - sp_organization_name: Optional[str] = Field(default=None, alias="sp-organization-name") - version: Optional[str] = Field(default=None) - vbond: Optional[str] = Field(default=None) - vmanage_system_ip: Optional[str] = Field(default=None, alias="vmanage-system-ip") - vmanage_connection_state: Optional[str] = Field(default=None, alias="vmanageConnectionState") - last_updated: Optional[int] = Field(default=None, alias="lastupdated") - reachability: Optional[str] = Field(default=None) - uptime_date: Optional[int] = Field(default=None, alias="uptime-date") - default_version: Optional[str] = Field(default=None, alias="defaultVersion") - organization_name: Optional[str] = Field(default=None, alias="organization-name") - available_versions: Optional[List[str]] = Field(default=None, alias="availableVersions") - site_name: Optional[str] = Field(default=None, alias="site-name") + device_lock: Optional[str] = Field(default=None, validation_alias="deviceLock", serialization_alias="deviceLock") + managed_by: Optional[str] = Field(default=None, validation_alias="managedBy", serialization_alias="managedBy") + configured_site_id: Optional[str] = Field( + default=None, validation_alias="configuredSiteId", serialization_alias="configuredSiteId" + ) + ncs_device_name: Optional[str] = Field( + default=None, validation_alias="ncsDeviceName", serialization_alias="ncsDeviceName" + ) + config_status_message: Optional[str] = Field( + default=None, validation_alias="configStatusMessage", serialization_alias="configStatusMessage" + ) + template_apply_log: Optional[List[str]] = Field( + default=None, validation_alias="templateApplyLog", serialization_alias="templateApplyLog" + ) + template_status: Optional[str] = Field( + default=None, validation_alias="templateStatus", serialization_alias="templateStatus" + ) + config_status_message_details: Optional[str] = Field( + default=None, validation_alias="configStatusMessageDetails", serialization_alias="configStatusMessageDetails" + ) + device_enterprise_certificate: Optional[str] = Field( + default=None, validation_alias="deviceEnterpriseCertificate", serialization_alias="deviceEnterpriseCertificate" + ) + service_personality: Optional[str] = Field( + default=None, validation_alias="servicePersonality", serialization_alias="servicePersonality" + ) + upload_source: Optional[str] = Field( + default=None, validation_alias="uploadSource", serialization_alias="uploadSource" + ) + time_remaining_for_expiration: Optional[int] = Field( + default=None, validation_alias="timeRemainingForExpiration", serialization_alias="timeRemainingForExpiration" + ) + domain_id: Optional[str] = Field(default=None, validation_alias="domainId", serialization_alias="domainId") + local_system_ip: Optional[str] = Field( + default=None, validation_alias="localSystemIp", serialization_alias="localSystemIp" + ) + system_ip: Optional[str] = Field(default=None, validation_alias="systemIp", serialization_alias="systemIp") + model_sku: Optional[str] = Field(default=None, validation_alias="modelSku", serialization_alias="modelSku") + site_id: Optional[str] = Field(default=None, validation_alias="siteId", serialization_alias="siteId") + host_name: Optional[str] = Field(default=None, validation_alias="hostName", serialization_alias="hostName") + sp_organization_name: Optional[str] = Field( + default=None, validation_alias="spOrganizationName", serialization_alias="spOrganizationName" + ) + version: Optional[str] = Field(default=None, validation_alias="version", serialization_alias="version") + vbond: Optional[str] = Field(default=None, validation_alias="vbond", serialization_alias="vbond") + vmanage_system_ip: Optional[str] = Field( + default=None, validation_alias="vmanageSystemIp", serialization_alias="vmanageSystemIp" + ) + vmanage_connection_state: Optional[str] = Field( + default=None, validation_alias="vmanageConnectionState", serialization_alias="vmanageConnectionState" + ) + last_updated: Optional[int] = Field(default=None, validation_alias="lastUpdated", serialization_alias="lastUpdated") + reachability: Optional[str] = Field( + default=None, validation_alias="reachability", serialization_alias="reachability" + ) + uptime_date: Optional[int] = Field(default=None, validation_alias="uptimeDate", serialization_alias="uptimeDate") + default_version: Optional[str] = Field( + default=None, validation_alias="defaultVersion", serialization_alias="defaultVersion" + ) + organization_name: Optional[str] = Field( + default=None, validation_alias="organizationName", serialization_alias="organizationName" + ) + available_versions: Optional[List[str]] = Field( + default=None, validation_alias="availableVersions", serialization_alias="availableVersions" + ) + site_name: Optional[str] = Field(default=None, validation_alias="siteName", serialization_alias="siteName") class DeviceDetailsQueryParams(BaseModel): model: Optional[str] = None state: Optional[List[str]] = None uuid: Optional[List[str]] = None - device_ip: Optional[List[str]] = Field(default=None, serialization_alias="deviceIP") + device_ip: Optional[List[str]] = Field(default=None, validation_alias="deviceIP", serialization_alias="deviceIP") validity: Optional[List[str]] = None family: Optional[str] = None -class Validity(str, Enum): - VALID = "valid" - INVALID = "invalid" +Validity = Literal["valid", "invalid"] class SmartAccountSyncParams(BaseModel): @@ -147,15 +212,15 @@ class SmartAccountSyncParams(BaseModel): password: str username: str - validity_string: str = Validity.INVALID + validity_string: Validity = Field(default="valid") class ProcessId(BaseModel): - process_id: str = Field(alias="processId") + process_id: str = Field(validation_alias="processId", serialization_alias="processId") class SerialFilePayload(CustomPayloadType): - def __init__(self, image_path: str, validity: Validity = Validity.INVALID): + def __init__(self, image_path: str, validity: Validity = "valid"): self.image_path = image_path self.validity = validity self.data = open(self.image_path, "rb") @@ -165,31 +230,43 @@ def prepared(self) -> PreparedPayload: return PreparedPayload(files={"file": (Path(self.data.name).name, self.data)}, data=self.fields) -class ConfigType(str, Enum): - CLOUDINIT = "cloudinit" - ENCODEDSTRING = "encodedstring" +ConfigType = Literal["cloudinit", "encodedstring"] class GenerateBoostrapConfigurationQueryParams(BaseModel): - configtype: Optional[ConfigType] = Field(default=ConfigType.CLOUDINIT) - incl_def_root_cert: Optional[bool] = Field(default=False, alias="inclDefRootCert") + model_config = ConfigDict(populate_by_name=True) + + configtype: Optional[ConfigType] = Field(default="cloudinit") + incl_def_root_cert: Optional[bool] = Field( + default=False, validation_alias="inclDefRootCert", serialization_alias="inclDefRootCert" + ) version: Optional[str] = Field(default="v1") class BoostrapConfiguration(BaseModel): model_config = ConfigDict(populate_by_name=True) - bootstrap_config: Optional[str] = Field(default=None, alias="bootstrapConfig") + bootstrap_config: Optional[str] = Field( + default=None, validation_alias="bootstrapConfig", serialization_alias="bootstrapConfig" + ) class UploadSerialFileResponse(BaseModel): model_config = ConfigDict(populate_by_name=True) - vedge_list_upload_msg: Optional[str] = Field(default=None, alias="vedgeListUploadMsg") - vedge_list_upload_status: Optional[str] = Field(default=None, alias="vedgeListUploadStatus") + vedge_list_upload_msg: Optional[str] = Field( + default=None, validation_alias="vedgeListUploadMsg", serialization_alias="vedgeListUploadMsg" + ) + vedge_list_upload_status: Optional[str] = Field( + default=None, validation_alias="vedgeListUploadStatus", serialization_alias="vedgeListUploadStatus" + ) id: Optional[str] = None - vedge_list_status_code: Optional[str] = Field(default=None, alias="vedgeListStatusCode") - activity_list: Optional[Union[List, str]] = Field(default=None, alias="activityList") + vedge_list_status_code: Optional[str] = Field( + default=None, validation_alias="vedgeListStatusCode", serialization_alias="vedgeListStatusCode" + ) + activity_list: Optional[Union[List, str]] = Field( + default=None, validation_alias="activityList", serialization_alias="activityList" + ) class ConfigurationDeviceInventory(APIEndpoints): diff --git a/catalystwan/tests/test_endpoints.py b/catalystwan/tests/test_endpoints.py index ef4265e7f..b2715e463 100644 --- a/catalystwan/tests/test_endpoints.py +++ b/catalystwan/tests/test_endpoints.py @@ -858,6 +858,29 @@ class TestAPI(APIEndpoints): def get_data(self, fruit_type: FruitEnum) -> None: # type: ignore [empty-body] ... + def test_request_decorator_format_url_with_literal(self): + FruitType = Literal["banana", "orange", "apple"] + + class TestAPI(APIEndpoints): + @request("GET", "/v1/data/{fruit_type}") + def get_data(self, fruit_type: FruitType, payload: str) -> None: # type: ignore [empty-body] + ... + + api = TestAPI(self.session_mock) + # Act + api.get_data("banana", "not a fruit") + # Assert + self.session_mock.request.assert_called_once_with("GET", self.base_path + "/v1/data/banana", data="not a fruit") + + def test_request_decorator_raises_when_format_url_with_literal_is_not_str_subtype(self): + with self.assertRaises(APIEndpointError): + FruitType = Literal[1, 2, 3] + + class TestAPI(APIEndpoints): + @request("POST", "/v1/data/{fruit_type}") + def get_data(self, fruit_type: FruitType) -> None: # type: ignore [empty-body] + ... + def test_request_decorator_accept_union_of_models(self): class TestAPI(APIEndpoints): @request("GET", "/v1/data")