From fb346fe20f4c26167bb603451d70ccf1cb58ca88 Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 15:18:03 +0200 Subject: [PATCH 01/32] =?UTF-8?q?=F0=9F=90=9B=20fix:=20remove=20unused=20i?= =?UTF-8?q?nit?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/coinapi/_hooks/hooks.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/coinapi/_hooks/hooks.py b/src/coinapi/_hooks/hooks.py index 1044192..da581fd 100644 --- a/src/coinapi/_hooks/hooks.py +++ b/src/coinapi/_hooks/hooks.py @@ -24,9 +24,6 @@ class SDKHooks(Hooks): after_success_hooks: ClassVar[list[AfterSuccessHook]] = [] after_error_hooks: ClassVar[list[AfterErrorHook]] = [] - def __init__(self) -> None: - pass - def register_sdk_init_hook(self, hook: SDKInitHook) -> None: """Register an SDK init hook.""" self.sdk_init_hooks.append(hook) From 04300faa87385b40e6241823dc9765e4e8aeda0d Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 15:26:43 +0200 Subject: [PATCH 02/32] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20=20refactor:=20base?= =?UTF-8?q?=20operations=20class?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/coinapi/base.py | 201 +++++++++++++++++++++++++++----------------- 1 file changed, 123 insertions(+), 78 deletions(-) diff --git a/src/coinapi/base.py b/src/coinapi/base.py index 696315e..407d753 100644 --- a/src/coinapi/base.py +++ b/src/coinapi/base.py @@ -1,14 +1,14 @@ """Base class for operation collections.""" import enum -from typing import TypeVar +from typing import Any, TypeVar import httpx import msgspec from httpx import codes from coinapi import utils -from coinapi._hooks import HookContext +from coinapi._hooks import BeforeRequestContext, HookContext from coinapi.config import CoinAPIConfig from coinapi.models import errors from coinapi.models.operations.base import CoinAPIRequest, CoinAPIResponse @@ -34,88 +34,115 @@ class Base: def __init__(self, sdk_config: CoinAPIConfig) -> None: self.sdk_configuration = sdk_config - def _make_request( # noqa: PLR0912, C901 + def _make_request( # type: ignore[return] self, operation_id: str, request: RequestT, response_cls: type[ResponseT], accept_header_override: AcceptEnum | None = None, ) -> ResponseT: - """Send a request.""" - hook_ctx = HookContext( + """Send an HTTP request.""" + hook_ctx = self._create_hook_context(operation_id) + prepared_request = self._prepare_request(request, accept_header_override) + client = self._configure_security_client() + + try: + http_res = self._execute_request(hook_ctx, prepared_request, client) + return self._process_response(http_res, response_cls) + except Exception as e: # noqa: BLE001 + self._handle_request_error(hook_ctx, e) + + def _create_hook_context(self, operation_id: str) -> BeforeRequestContext: + """Create a hook context.""" + return BeforeRequestContext( operation_id=operation_id, oauth2_scopes=[], security_source=self.sdk_configuration.security, ) + + def _prepare_request( + self, + request: RequestT, + accept_header_override: AcceptEnum | None, + ) -> httpx.Request: + """Prepare an HTTP request.""" base_url = utils.template_url(*self.sdk_configuration.get_server_details()) url = utils.generate_url(type(request), base_url, request.endpoint, request) # type: ignore[arg-type] + headers = self._prepare_headers(request, accept_header_override) + data, form = self._prepare_body(request) + query_params = utils.get_query_params(type(request), request) or None # type: ignore[arg-type] + + return httpx.Request( + request.method, + url, + params=query_params, + data=data, + files=form, + headers=headers, + ) + + def _prepare_headers( + self, + request: RequestT, + accept_header_override: AcceptEnum | None, + ) -> dict[str, str]: + """Prepare request headers.""" headers = {} - data, form = None, None if request.method in {"POST", "PUT", "PATCH"}: - req_content_type, data, form = utils.serialize_request_body( - request, - "body", - ) + req_content_type, _, _ = utils.serialize_request_body(request, "body") if req_content_type is not None and req_content_type not in ( "multipart/form-data", "multipart/mixed", ): headers["content-type"] = req_content_type - query_params = utils.get_query_params(type(request), request) or None # type: ignore[arg-type] - if accept_header_override is not None: - headers["Accept"] = accept_header_override.value - else: - headers["Accept"] = ( - "application/json;q=1, text/json;q=0.8, text/plain;q=0.5, application/x-msgpack;q=0" - ) + + headers["Accept"] = ( + accept_header_override.value + if accept_header_override is not None + else "application/json;q=1, text/json;q=0.8, text/plain;q=0.5, application/x-msgpack;q=0" + ) headers["user-agent"] = self.sdk_configuration.user_agent + return headers + def _prepare_body(self, request: RequestT) -> tuple[Any, Any]: + """Prepare request body.""" + if request.method in {"POST", "PUT", "PATCH"}: + _, data, form = utils.serialize_request_body(request, "body") + return data, form + return None, None + + def _configure_security_client(self) -> utils.SecurityClient: + """Configure the security client.""" security = ( self.sdk_configuration.security() if callable(self.sdk_configuration.security) else self.sdk_configuration.security ) - client = utils.configure_security_client( - self.sdk_configuration.client, - security, - ) + return utils.configure_security_client(self.sdk_configuration.client, security) - try: - req = self.sdk_configuration.get_hooks().before_request( - hook_ctx, # type: ignore[arg-type] - httpx.Request( - request.method, - url, - params=query_params, - data=data, - files=form, - headers=headers, - ), - ) - http_res = client.send(req) - except Exception as e: - _, exc = self.sdk_configuration.get_hooks().after_error(hook_ctx, None, e) # type: ignore[arg-type] - raise exc from e # type: ignore[misc] + def _execute_request( + self, + hook_ctx: BeforeRequestContext, + prepared_request: httpx.Request, + client: utils.SecurityClient, + ) -> httpx.Response: + """Execute an HTTP request.""" + req = self.sdk_configuration.get_hooks().before_request( + hook_ctx, + prepared_request, + ) + return client.send(req) + def _process_response( + self, + http_res: httpx.Response, + response_cls: type[ResponseT], + ) -> ResponseT: + """Process an HTTP response.""" if utils.match_status_codes(["4XX", "5XX"], http_res.status_code): - http_res, exc = self.sdk_configuration.get_hooks().after_error( # type: ignore[assignment] - hook_ctx, # type: ignore[arg-type] - http_res, - None, - ) - if exc: - raise exc - else: - result = self.sdk_configuration.get_hooks().after_success( - hook_ctx, # type: ignore[arg-type] - http_res, - ) - if isinstance(result, Exception): - raise result - http_res = result + self._handle_error_response(http_res) content_type = http_res.headers.get("Content-Type", "") - res = response_cls( status_code=http_res.status_code, content_type=content_type, @@ -123,30 +150,45 @@ def _make_request( # noqa: PLR0912, C901 ) if httpx.codes.is_success(http_res.status_code): - if utils.match_content_type(content_type, "text/plain"): - res.content_plain = http_res.text - elif utils.match_content_type( - content_type, - "application/json", - ) or utils.match_content_type(content_type, "text/json"): - content_cls = next( - field - for field in msgspec.structs.fields(response_cls) - if field.name == "content" - ).type - out = msgspec.json.decode(http_res.content, type=content_cls) - res.content = out - elif utils.match_content_type(content_type, "application/x-msgpack"): - res.body = http_res.content - else: - msg = f"unknown content-type received: {content_type}" - raise errors.CoinAPIError( - msg, - http_res.status_code, - http_res.text, - http_res, - ) - elif codes.is_client_error(http_res.status_code) or codes.is_server_error( + self._set_response_content(res, http_res, content_type, response_cls) + + return res + + def _set_response_content( + self, + res: ResponseT, + http_res: httpx.Response, + content_type: str, + response_cls: type[ResponseT], + ) -> None: + """Set the response content.""" + if utils.match_content_type(content_type, "text/plain"): + res.content_plain = http_res.text + elif utils.match_content_type( + content_type, + "application/json", + ) or utils.match_content_type(content_type, "text/json"): + content_cls = next( + field + for field in msgspec.structs.fields(response_cls) + if field.name == "content" + ).type + out = msgspec.json.decode(http_res.content, type=content_cls) + res.content = out + elif utils.match_content_type(content_type, "application/x-msgpack"): + res.body = http_res.content + else: + msg = f"unknown content-type received: {content_type}" + raise errors.CoinAPIError( + msg, + http_res.status_code, + http_res.text, + http_res, + ) + + def _handle_error_response(self, http_res: httpx.Response) -> None: + """Handle an error response.""" + if codes.is_client_error(http_res.status_code) or codes.is_server_error( http_res.status_code, ): raise errors.CoinAPIError( @@ -156,4 +198,7 @@ def _make_request( # noqa: PLR0912, C901 http_res, ) - return res + def _handle_request_error(self, hook_ctx: HookContext, error: Exception) -> None: + """Handle a request error.""" + _, exc = self.sdk_configuration.get_hooks().after_error(hook_ctx, None, error) # type: ignore[arg-type] + raise exc from error # type: ignore[misc] From 4a4c027c1bae7efa1f5557f89be966acedd5eba8 Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 15:28:55 +0200 Subject: [PATCH 03/32] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20=20refactor:=20use?= =?UTF-8?q?=20constant=20for=20exception=20message?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/coinapi/utils/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/coinapi/utils/utils.py b/src/coinapi/utils/utils.py index f426a0c..3aa68d4 100644 --- a/src/coinapi/utils/utils.py +++ b/src/coinapi/utils/utils.py @@ -21,6 +21,8 @@ import msgspec from typing_inspect import is_optional_type +NOT_SUPPORTED = "not supported" + class SecurityClient: """Client with security settings.""" @@ -143,7 +145,7 @@ def _parse_security_scheme_value( elif sub_type == "query": client.query_params[header_name] = value else: - raise ValueError("not supported") + raise ValueError(NOT_SUPPORTED) elif scheme_type == "openIdConnect": client.headers[header_name] = _apply_bearer(value) elif scheme_type == "oauth2": @@ -153,9 +155,9 @@ def _parse_security_scheme_value( if sub_type == "bearer": client.headers[header_name] = _apply_bearer(value) else: - raise ValueError("not supported") + raise ValueError(NOT_SUPPORTED) else: - raise ValueError("not supported") + raise ValueError(NOT_SUPPORTED) def _apply_bearer(token: str) -> str: From 7151696d418aa7c597ae2f83c01e5ce5bcf1b2ab Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 15:36:26 +0200 Subject: [PATCH 04/32] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20=20refactor:=20gener?= =?UTF-8?q?ate=20url=20function?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/coinapi/base.py | 2 +- src/coinapi/utils/utils.py | 220 +++++++++++++++++++++++-------------- 2 files changed, 137 insertions(+), 85 deletions(-) diff --git a/src/coinapi/base.py b/src/coinapi/base.py index 407d753..20cb292 100644 --- a/src/coinapi/base.py +++ b/src/coinapi/base.py @@ -67,7 +67,7 @@ def _prepare_request( ) -> httpx.Request: """Prepare an HTTP request.""" base_url = utils.template_url(*self.sdk_configuration.get_server_details()) - url = utils.generate_url(type(request), base_url, request.endpoint, request) # type: ignore[arg-type] + url = utils.generate_url(type(request), base_url, request.endpoint, request) headers = self._prepare_headers(request, accept_header_override) data, form = self._prepare_body(request) query_params = utils.get_query_params(type(request), request) or None # type: ignore[arg-type] diff --git a/src/coinapi/utils/utils.py b/src/coinapi/utils/utils.py index 3aa68d4..a6ce2d0 100644 --- a/src/coinapi/utils/utils.py +++ b/src/coinapi/utils/utils.py @@ -10,6 +10,7 @@ from typing import ( Any, ClassVar, + Protocol, Union, cast, get_args, @@ -197,8 +198,129 @@ def get_metadata(field_info: msgspec.structs.FieldInfo) -> dict[str, Any]: return {} -def generate_url( # noqa: PLR0912, C901 - clazz: msgspec.Struct, +class PathParamHandler(Protocol): + """Protocol for path parameter handlers.""" + + def handle(self, param: Any, metadata: dict[str, Any]) -> str: + """Handle a path parameter.""" + ... + + +class ListParamHandler: + """Handler for list parameters.""" + + def handle(self, param: list[Any], _metadata: dict[str, Any]) -> str: + """Handle a list parameter.""" + pp_vals = [_val_to_string(pp_val) for pp_val in param if pp_val is not None] + return ",".join(pp_vals) + + +class DictParamHandler: + """Handler for dictionary parameters.""" + + def handle(self, param: dict[str, Any], metadata: dict[str, Any]) -> str: + """Handle a dictionary parameter.""" + pp_vals = [] + for pp_key, pp_value in param.items(): + if pp_value is None: + continue + if metadata.get("explode"): + pp_vals.append(f"{pp_key}={_val_to_string(pp_value)}") + else: + pp_vals.append(f"{pp_key},{_val_to_string(pp_value)}") + return ",".join(pp_vals) + + +class StructParamHandler: + """Handler for msgspec.Struct parameters.""" + + def handle(self, param: msgspec.Struct, metadata: dict[str, Any]) -> str: + """Handle a msgspec.Struct parameter.""" + pp_vals = [] + param_fields: tuple[msgspec.structs.FieldInfo, ...] = msgspec.structs.fields( + param, + ) + for param_field in param_fields: + param_value_metadata = get_metadata(param_field).get("path_param") + if not param_value_metadata: + continue + + parm_name = param_value_metadata.get("field_name", param_field.name) + param_field_val = getattr(param, param_field.name) + if param_field_val is None: + continue + if metadata.get("explode"): + pp_vals.append(f"{parm_name}={_val_to_string(param_field_val)}") + else: + pp_vals.append(f"{parm_name},{_val_to_string(param_field_val)}") + return ",".join(pp_vals) + + +class DefaultParamHandler: + """Handler for default parameter types.""" + + def handle( + self, + param: str | complex | bool | Decimal, + _metadata: dict[str, Any], + ) -> str: + """Handle a default parameter type.""" + return _val_to_string(param) + + +def get_param_handler(param: Any) -> PathParamHandler: + """Get the appropriate parameter handler based on the parameter type.""" + if isinstance(param, list): + return ListParamHandler() + if isinstance(param, dict): + return DictParamHandler() + if isinstance(param, msgspec.Struct): + return StructParamHandler() + return DefaultParamHandler() + + +def handle_single_path_param(param: Any, metadata: dict[str, Any]) -> str: + """Handle a single path parameter.""" + handler = get_param_handler(param) + return handler.handle(param, metadata) + + +def serialize_param( + param: Any, + metadata: dict[str, Any], + field_type: type[Any], + field_name: str, +) -> dict[str, str]: + """Serialize a parameter based on metadata.""" + params: dict[str, str] = {} + serialization = metadata.get("serialization", "") + if serialization == "json": + params[metadata.get("field_name", field_name)] = marshal_json(param, field_type) + return params + + +def replace_url_placeholder(path: str, field_name: str, value: str) -> str: + """Replace a placeholder in the URL path.""" + return path.replace("{" + field_name + "}", value, 1) + + +def is_path_param(field: msgspec.structs.FieldInfo) -> bool: + """Check if a field is a path parameter.""" + return get_metadata(field).get("path_param") is not None + + +def get_param_value( + field: msgspec.structs.FieldInfo, + path_params: msgspec.Struct | None, + gbls: dict[str, dict[str, dict[str, Any]]] | None, +) -> Any: + """Get the parameter value from path_params or globals.""" + param = getattr(path_params, field.name) if path_params is not None else None + return _populate_from_globals(field.name, param, "pathParam", gbls) + + +def generate_url( + clazz: type[msgspec.Struct], server_url: str, path: str, path_params: msgspec.Struct | None, @@ -209,94 +331,24 @@ def generate_url( # noqa: PLR0912, C901 clazz, ) for field in path_param_fields: - request_metadata = get_metadata(field).get("request") - if request_metadata is not None: - continue - - param_metadata = get_metadata(field).get("path_param") - if param_metadata is None: + if not is_path_param(field): continue - param = getattr(path_params, field.name) if path_params is not None else None - param = _populate_from_globals(field.name, param, "pathParam", gbls) - + param = get_param_value(field, path_params, gbls) if param is None: continue - f_name = param_metadata.get("field_name", field.name) - serialization = param_metadata.get("serialization", "") - if serialization != "": - serialized_params = _get_serialized_params( - param_metadata, - field.type, - f_name, - param, - ) - for key, value in serialized_params.items(): - path = path.replace("{" + key + "}", value, 1) - elif param_metadata.get("style", "simple") == "simple": - if isinstance(param, list): - pp_vals: list[str] = [] - for pp_val in param: - if pp_val is None: - continue - pp_vals.append(_val_to_string(pp_val)) - path = path.replace( - "{" + param_metadata.get("field_name", field.name) + "}", - ",".join(pp_vals), - 1, - ) - elif isinstance(param, dict): - pp_vals = [] - for pp_key in param: - if param[pp_key] is None: - continue - if param_metadata.get("explode"): - pp_vals.append(f"{pp_key}={_val_to_string(param[pp_key])}") - else: - pp_vals.append(f"{pp_key},{_val_to_string(param[pp_key])}") - path = path.replace( - "{" + param_metadata.get("field_name", field.name) + "}", - ",".join(pp_vals), - 1, - ) - elif not isinstance(param, str | int | float | complex | bool | Decimal): - pp_vals = [] - param_fields: tuple[ - msgspec.structs.FieldInfo, - ..., - ] = msgspec.structs.fields(param) - for param_field in param_fields: - param_value_metadata = get_metadata(param_field).get( - "path_param", - ) - if not param_value_metadata: - continue - - parm_name = param_value_metadata.get("field_name", field.name) + metadata = get_metadata(field).get("path_param", {}) + f_name = metadata.get("field_name", field.name) + serialization = metadata.get("serialization", "") - param_field_val = getattr(param, param_field.name) - if param_field_val is None: - continue - if param_metadata.get("explode"): - pp_vals.append( - f"{parm_name}={_val_to_string(param_field_val)}", - ) - else: - pp_vals.append( - f"{parm_name},{_val_to_string(param_field_val)}", - ) - path = path.replace( - "{" + param_metadata.get("field_name", field.name) + "}", - ",".join(pp_vals), - 1, - ) - else: - path = path.replace( - "{" + param_metadata.get("field_name", field.name) + "}", - _val_to_string(param), - 1, - ) + if serialization: + serialized_params = serialize_param(param, metadata, field.type, f_name) + for key, value in serialized_params.items(): + path = replace_url_placeholder(path, key, value) + elif metadata.get("style", "simple") == "simple": + serialized_value = handle_single_path_param(param, metadata) + path = replace_url_placeholder(path, f_name, serialized_value) return remove_suffix(server_url, "/") + path From bbfcc0cdebb1c49303aff37e0e2c2e9667fad254 Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 15:55:48 +0200 Subject: [PATCH 05/32] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20=20refactor:=20get?= =?UTF-8?q?=20query=20params?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/coinapi/base.py | 2 +- src/coinapi/utils/utils.py | 141 +++++++++++++++++++++++++------------ 2 files changed, 98 insertions(+), 45 deletions(-) diff --git a/src/coinapi/base.py b/src/coinapi/base.py index 20cb292..ce389d1 100644 --- a/src/coinapi/base.py +++ b/src/coinapi/base.py @@ -70,7 +70,7 @@ def _prepare_request( url = utils.generate_url(type(request), base_url, request.endpoint, request) headers = self._prepare_headers(request, accept_header_override) data, form = self._prepare_body(request) - query_params = utils.get_query_params(type(request), request) or None # type: ignore[arg-type] + query_params = utils.get_query_params(type(request), request) or None return httpx.Request( request.method, diff --git a/src/coinapi/utils/utils.py b/src/coinapi/utils/utils.py index a6ce2d0..4d73fa4 100644 --- a/src/coinapi/utils/utils.py +++ b/src/coinapi/utils/utils.py @@ -366,62 +366,115 @@ def template_url(url_with_params: str, params: dict[str, str]) -> str: return url_with_params +class QueryParamHandler(Protocol): + """Protocol for query parameter handlers.""" + + def handle( + self, + metadata: dict[str, Any], + field_name: str, + value: Any, + ) -> dict[str, list[str]]: + """Handle query parameter processing.""" + ... + + +class FormQueryParamHandler: + """Handler for form style query parameters.""" + + def handle( + self, + metadata: dict[str, Any], + field_name: str, + value: Any, + ) -> dict[str, list[str]]: + """Process form style query parameters.""" + return _get_delimited_query_params(metadata, field_name, value, ",") + + +class DeepObjectQueryParamHandler: + """Handler for deepObject style query parameters.""" + + def handle( + self, + metadata: dict[str, Any], + field_name: str, + value: Any, + ) -> dict[str, list[str]]: + """Process deepObject style query parameters.""" + return _get_deep_object_query_params(metadata, field_name, value) + + +class PipeDelimitedQueryParamHandler: + """Handler for pipeDelimited style query parameters.""" + + def handle( + self, + metadata: dict[str, Any], + field_name: str, + value: Any, + ) -> dict[str, list[str]]: + """Process pipeDelimited style query parameters.""" + return _get_delimited_query_params(metadata, field_name, value, "|") + + +def get_query_param_handler(style: str) -> QueryParamHandler: + """Get the appropriate query parameter handler based on style.""" + handlers = { + "form": FormQueryParamHandler(), + "deepObject": DeepObjectQueryParamHandler(), + "pipeDelimited": PipeDelimitedQueryParamHandler(), + } + # Default to form style + return handlers.get(style, FormQueryParamHandler()) # type: ignore[return-value] + + +def process_query_param( + field: msgspec.structs.FieldInfo, + value: Any, + _gbls: dict[str, dict[str, dict[str, Any]]] | None, +) -> dict[str, list[str]]: + """Process a single query parameter.""" + metadata = get_metadata(field).get("query_param", {}) + field_name = metadata.get("field_name", field.name) + + if metadata.get("serialization"): + serialized_params = _get_serialized_params( + metadata, + field.type, + field_name, + value, + ) + return {key: [value] for key, value in serialized_params.items()} + + style = metadata.get("style", "form") + handler = get_query_param_handler(style) + return handler.handle(metadata, field_name, value) + + def get_query_params( - clazz: msgspec.Struct, + clazz: type[msgspec.Struct], query_params: msgspec.Struct, gbls: dict[str, dict[str, dict[str, Any]]] | None = None, ) -> dict[str, list[str]]: - """Get query parameters.""" + """Get query parameters for a request.""" params: dict[str, list[str]] = {} - param_fields: tuple[msgspec.structs.FieldInfo, ...] = msgspec.structs.fields(clazz) - for field in param_fields: - request_metadata = get_metadata(field).get("request") - if request_metadata is not None: + for field in msgspec.structs.fields(clazz): + if get_metadata(field).get("request") is not None: continue - metadata = get_metadata(field).get("query_param") - if not metadata: + if not get_metadata(field).get("query_param"): continue - param_name = field.name - value = getattr(query_params, param_name) if query_params is not None else None + value = getattr(query_params, field.name) if query_params is not None else None + value = _populate_from_globals(field.name, value, "queryParam", gbls) + + if value is None: + continue - value = _populate_from_globals(param_name, value, "queryParam", gbls) + params.update(process_query_param(field, value, gbls)) - f_name = metadata.get("field_name") - serialization = metadata.get("serialization", "") - if serialization != "": - serialized_parms = _get_serialized_params( - metadata, - field.type, - f_name, - value, - ) - for key, value in serialized_parms.items(): - if key in params: - params[key].extend(value) - else: - params[key] = [value] - else: - style = metadata.get("style", "form") - if style == "deepObject": - params = { - **params, - **_get_deep_object_query_params(metadata, f_name, value), - } - elif style == "form": - params = { - **params, - **_get_delimited_query_params(metadata, f_name, value, ","), - } - elif style == "pipeDelimited": - params = { - **params, - **_get_delimited_query_params(metadata, f_name, value, "|"), - } - else: - raise NotImplementedError("not yet implemented") return params From d0450fbd94de14fe1350a0984d8b95dd92ad5c22 Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 15:59:19 +0200 Subject: [PATCH 06/32] =?UTF-8?q?=F0=9F=92=9A=20ci:=20upload=20correct=20c?= =?UTF-8?q?overage=20file?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/pythonpackage.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index edc20c7..055514b 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -70,7 +70,7 @@ jobs: uses: actions/upload-artifact@v4 with: name: coverage-${{ matrix.platform }}-${{ matrix.python-version }} - path: reports/.coverage + path: reports/coverage.xml coveralls-finish: needs: [python-test] From c287bb3453b0ccd64ca94cf583e0704f3f22bef2 Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 16:24:52 +0200 Subject: [PATCH 07/32] =?UTF-8?q?=F0=9F=92=9A=20ci:=20simplify=20lint=20jo?= =?UTF-8?q?b=20and=20upload=20reports?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/pythonpackage.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 055514b..0856f55 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -7,18 +7,13 @@ on: jobs: python-lint: - strategy: - matrix: - python-version: ["3.10"] - platform: [ubuntu-latest] - fail-fast: false - runs-on: ${{ matrix.platform }} + runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v5 with: - python-version: ${{ matrix.python-version }} + python-version: "3.10" - uses: actions/cache@v4 id: cache with: @@ -33,6 +28,11 @@ jobs: run: python -m pdm install -G lint -G dev - name: Lint run: python -m pdm run lint + - name: Archive lint reports + uses: actions/upload-artifact@v4 + with: + name: lint-reports + path: reports python-test: strategy: From 4e1ea04bfa63acb99b1e1694efeb4575825dc97a Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 16:31:34 +0200 Subject: [PATCH 08/32] =?UTF-8?q?=F0=9F=92=9A=20ci:=20combine=20coverage?= =?UTF-8?q?=20reports?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/pythonpackage.yml | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 0856f55..07df8ff 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -70,7 +70,8 @@ jobs: uses: actions/upload-artifact@v4 with: name: coverage-${{ matrix.platform }}-${{ matrix.python-version }} - path: reports/coverage.xml + path: reports/.coverage + include-hidden-files: true coveralls-finish: needs: [python-test] @@ -92,6 +93,22 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 # Shallow clones should be disabled for a better relevancy of analysis + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + - name: Install coverage + run: pip install coverage + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + path: reports + pattern: coverage-* + merge-multiple: true + - name: Combine coverage + run: | + coverage combine reports + coverage xml - name: Generate sonar properties run: | cat << EOF > sonar-project.properties From 01e74ac755566affb80f8fe088aa729d0d5c67ca Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 16:32:24 +0200 Subject: [PATCH 09/32] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20=20refactor:=20get?= =?UTF-8?q?=5Fdeep=5Fobject=5Fquery=5Fparams?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/coinapi/utils/utils.py | 105 +++++++++++++++++++------------------ 1 file changed, 54 insertions(+), 51 deletions(-) diff --git a/src/coinapi/utils/utils.py b/src/coinapi/utils/utils.py index 4d73fa4..21bfe8e 100644 --- a/src/coinapi/utils/utils.py +++ b/src/coinapi/utils/utils.py @@ -520,18 +520,29 @@ def _get_serialized_params( return params -def _get_deep_object_query_params( # noqa: PLR0912, C901 - metadata: dict[str, Any], - field_name: str, - obj: Any, -) -> dict[str, list[str]]: - """Get deep object query parameters.""" - params: dict[str, list[str]] = {} - - if obj is None: - return params - - if isinstance(obj, msgspec.Struct): +class DeepObjectQueryParamProcessor: + """Processor for deep object query parameters.""" + + def __init__(self, metadata: dict[str, Any], field_name: str) -> None: + """Initialize the processor.""" + self.metadata = metadata + self.field_name = field_name + self.params: dict[str, list[str]] = {} + + def process(self, obj: Any) -> dict[str, list[str]]: + """Process the input object.""" + if obj is None: + return self.params + if isinstance(obj, msgspec.Struct): + self._process_struct(obj) + elif isinstance(obj, dict): + self._process_dict(obj) + elif isinstance(obj, list): + self._process_list(obj) + return self.params + + def _process_struct(self, obj: msgspec.Struct) -> None: + """Process a msgspec.Struct object.""" obj_fields: tuple[msgspec.structs.FieldInfo, ...] = msgspec.structs.fields(obj) for obj_field in obj_fields: obj_param_metadata = get_metadata(obj_field).get("query_param") @@ -542,52 +553,44 @@ def _get_deep_object_query_params( # noqa: PLR0912, C901 if obj_val is None: continue - if isinstance(obj_val, list): - for val in obj_val: - if val is None: - continue + param_name = f'{self.metadata.get("field_name", self.field_name)}[{obj_param_metadata.get("field_name", obj_field.name)}]' + self._add_param(param_name, obj_val) - if ( - params.get( - f'{metadata.get("field_name", field_name)}[{obj_param_metadata.get("field_name", obj_field.name)}]', - ) - is None - ): - params[ - f'{metadata.get("field_name", field_name)}[{obj_param_metadata.get("field_name", obj_field.name)}]' - ] = [] - - params[ - f'{metadata.get("field_name", field_name)}[{obj_param_metadata.get("field_name", obj_field.name)}]' - ].append(_val_to_string(val)) - else: - params[ - f'{metadata.get("field_name", field_name)}[{obj_param_metadata.get("field_name", obj_field.name)}]' - ] = [_val_to_string(obj_val)] - elif isinstance(obj, dict): + def _process_dict(self, obj: dict[str, Any]) -> None: + """Process a dictionary object.""" for key, value in obj.items(): if value is None: continue - if isinstance(value, list): - for val in value: - if val is None: - continue + param_name = f'{self.metadata.get("field_name", self.field_name)}[{key}]' + self._add_param(param_name, value) + + def _process_list(self, obj: list[Any]) -> None: + """Process a list object.""" + # This method is not used in the original implementation + # but added for completeness + raise NotImplementedError + + def _add_param(self, key: str, value: Any) -> None: + """Add a parameter to the result dictionary.""" + if isinstance(value, list): + if key not in self.params: + self.params[key] = [] + for val in value: + if val is not None: + self.params[key].append(_val_to_string(val)) + else: + self.params[key] = [_val_to_string(value)] - if ( - params.get(f'{metadata.get("field_name", field_name)}[{key}]') - is None - ): - params[f'{metadata.get("field_name", field_name)}[{key}]'] = [] - params[f'{metadata.get("field_name", field_name)}[{key}]'].append( - _val_to_string(val), - ) - else: - params[f'{metadata.get("field_name", field_name)}[{key}]'] = [ - _val_to_string(value), - ] - return params +def _get_deep_object_query_params( + metadata: dict[str, Any], + field_name: str, + obj: Any, +) -> dict[str, list[str]]: + """Get deep object query parameters.""" + processor = DeepObjectQueryParamProcessor(metadata, field_name) + return processor.process(obj) def _get_query_param_field_name(obj_field: msgspec.structs.FieldInfo) -> str: From 04448229f2770158c3429c97e375adf46376231e Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 16:34:20 +0200 Subject: [PATCH 10/32] =?UTF-8?q?=F0=9F=92=9A=20ci:=20install=20coverage?= =?UTF-8?q?=20with=20toml=20extra?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/pythonpackage.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 07df8ff..642a579 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -98,7 +98,7 @@ jobs: with: python-version: "3.10" - name: Install coverage - run: pip install coverage + run: pip install coverage[toml] - name: Download artifacts uses: actions/download-artifact@v4 with: From 7bf62ad450de569bf0a3e9768c637c96cd37ec4f Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 16:38:01 +0200 Subject: [PATCH 11/32] =?UTF-8?q?=F0=9F=92=9A=20ci:=20check=20reports?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/pythonpackage.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 642a579..29e3e8a 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -107,6 +107,7 @@ jobs: merge-multiple: true - name: Combine coverage run: | + ls -al reports coverage combine reports coverage xml - name: Generate sonar properties From 5a07802c31db8de3c61192d7717d6043e1bea2e0 Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 16:42:41 +0200 Subject: [PATCH 12/32] =?UTF-8?q?=F0=9F=92=9A=20ci:=20don't=20merge=20arti?= =?UTF-8?q?facts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/pythonpackage.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 29e3e8a..58611ca 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -104,7 +104,6 @@ jobs: with: path: reports pattern: coverage-* - merge-multiple: true - name: Combine coverage run: | ls -al reports From 4f4fabd6dbf866201d6234c70701ad87c19f94fd Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 16:51:17 +0200 Subject: [PATCH 13/32] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20=20refactor:=20seria?= =?UTF-8?q?lize=20multipart=20form?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/coinapi/utils/utils.py | 178 +++++++++++++++++++++++++------------ 1 file changed, 122 insertions(+), 56 deletions(-) diff --git a/src/coinapi/utils/utils.py b/src/coinapi/utils/utils.py index 21bfe8e..4fcbe8f 100644 --- a/src/coinapi/utils/utils.py +++ b/src/coinapi/utils/utils.py @@ -687,62 +687,6 @@ def serialize_content_type( raise ValueError(msg) -def serialize_multipart_form( # noqa: PLR0912, C901 - media_type: str, - request: msgspec.Struct, -) -> tuple[str, Any, list[list[Any]]]: - """Serialize a multipart form.""" - form: list[list[Any]] = [] - request_fields = msgspec.structs.fields(request) - - for field in request_fields: - val = getattr(request, field.name) - if val is None: - continue - - field_metadata = get_metadata(field).get("multipart_form") - if not field_metadata: - continue - - if field_metadata.get("file") is True: - file_fields = msgspec.structs.fields(val) - - file_name = "" - field_name = "" - content = b"" - - for file_field in file_fields: - file_metadata = get_metadata(file_field).get("multipart_form") - if file_metadata is None: - continue - - if file_metadata.get("content") is True: - content = getattr(val, file_field.name) - else: - field_name = file_metadata.get("field_name", file_field.name) - file_name = getattr(val, file_field.name) - if field_name == "" or file_name == "" or content == b"": - raise ValueError("invalid multipart/form-data file") - - form.append([field_name, [file_name, content]]) - elif field_metadata.get("json") is True: - to_append = [ - field_metadata.get("field_name", field.name), - [None, marshal_json(val, field.type), "application/json"], - ] - form.append(to_append) - else: - field_name = field_metadata.get("field_name", field.name) - if isinstance(val, list): - for value in val: - if value is None: - continue - form.append([field_name + "[]", [None, _val_to_string(value)]]) - else: - form.append([field_name, [None, _val_to_string(val)]]) - return media_type, None, form - - def serialize_dict( original: dict[str, Any], explode: bool, # noqa: FBT001 @@ -769,6 +713,128 @@ def serialize_dict( return existing +class MultipartFormField: + """Represents a field in a multipart form.""" + + def __init__(self, name: str, value: Any, metadata: dict[str, Any]) -> None: + self.name = name + self.value = value + self.metadata = metadata + + +class FieldSerializer(Protocol): + """Protocol for field serializers.""" + + def serialize(self, field: MultipartFormField) -> list[Any]: + """Serialize a field.""" + ... + + +class FileFieldSerializer: + """Serializer for file fields.""" + + def serialize(self, field: MultipartFormField) -> list[Any]: + """Serialize a file field.""" + file_fields = msgspec.structs.fields(field.value) + file_name = "" + content = b"" + + for file_field in file_fields: + file_metadata = get_metadata(file_field).get("multipart_form") + if file_metadata is None: + continue + + if file_metadata.get("content") is True: + content = getattr(field.value, file_field.name) + else: + file_name = getattr(field.value, file_field.name) + + if not file_name or not content: + raise ValueError("Invalid multipart/form-data file") + + return [[field.name, [file_name, content]]] + + +class JsonFieldSerializer: + """Serializer for JSON fields.""" + + def serialize(self, field: MultipartFormField) -> list[Any]: + """Serialize a JSON field.""" + return [ + [ + field.metadata.get("field_name", field.name), + [ + None, + marshal_json(field.value, type(field.value)), + "application/json", + ], + ], + ] + + +class RegularFieldSerializer: + """Serializer for regular fields.""" + + def serialize(self, field: MultipartFormField) -> list[Any]: + """Serialize a regular field.""" + field_name = field.metadata.get("field_name", field.name) + if isinstance(field.value, list): + return [ + [f"{field_name}[]", [None, _val_to_string(value)]] + for value in field.value + if value is not None + ] + return [[field_name, [None, _val_to_string(field.value)]]] + + +class MultipartFormSerializer: + """Serializes a multipart form.""" + + def __init__(self) -> None: + self.serializers: dict[str, FieldSerializer] = { + "file": FileFieldSerializer(), + "json": JsonFieldSerializer(), + "regular": RegularFieldSerializer(), + } + + def serialize(self, request: msgspec.Struct) -> tuple[str, Any, list[list[Any]]]: + """Serialize the entire multipart form.""" + form: list[list[Any]] = [] + for field in self._get_fields(request): + serializer = self._get_serializer(field) + form.extend(serializer.serialize(field)) + return "multipart/form-data", None, form + + def _get_fields(self, request: msgspec.Struct) -> list[MultipartFormField]: + """Extract fields from the request.""" + fields = [] + for field in msgspec.structs.fields(request): + value = getattr(request, field.name) + if value is None: + continue + metadata = get_metadata(field).get("multipart_form", {}) + if metadata: + fields.append(MultipartFormField(field.name, value, metadata)) + return fields + + def _get_serializer(self, field: MultipartFormField) -> FieldSerializer: + """Get the appropriate serializer for a field.""" + if field.metadata.get("file") is True: + return self.serializers["file"] + if field.metadata.get("json") is True: + return self.serializers["json"] + return self.serializers["regular"] + + +def serialize_multipart_form( + _media_type: str, + request: msgspec.Struct, +) -> tuple[str, Any, list[list[Any]]]: + """Serialize a multipart form.""" + serializer = MultipartFormSerializer() + return serializer.serialize(request) + + def serialize_form_data(field_name: str, data: Any) -> dict[str, Any]: """Serialize form data.""" form: dict[str, list[str]] = {} From d9695ae7f03fd4f304035ae3e937cf58d7a0f18a Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 16:52:32 +0200 Subject: [PATCH 14/32] =?UTF-8?q?=F0=9F=92=9A=20ci:=20fix=20combining=20co?= =?UTF-8?q?verage?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/pythonpackage.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 58611ca..13ab061 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -107,8 +107,9 @@ jobs: - name: Combine coverage run: | ls -al reports - coverage combine reports + coverage combine reports/**/.coverage coverage xml + ls -al reports - name: Generate sonar properties run: | cat << EOF > sonar-project.properties From d6b9ab18e383edf148e68002f65fe08680dccc87 Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 16:58:41 +0200 Subject: [PATCH 15/32] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20=20refactor:=20seria?= =?UTF-8?q?lize=20form=20data?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/coinapi/utils/utils.py | 60 +++++++++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 13 deletions(-) diff --git a/src/coinapi/utils/utils.py b/src/coinapi/utils/utils.py index 4fcbe8f..59409c0 100644 --- a/src/coinapi/utils/utils.py +++ b/src/coinapi/utils/utils.py @@ -1,5 +1,6 @@ """Utilities.""" +import abc import base64 import datetime as dt import re @@ -835,11 +836,20 @@ def serialize_multipart_form( return serializer.serialize(request) -def serialize_form_data(field_name: str, data: Any) -> dict[str, Any]: - """Serialize form data.""" - form: dict[str, list[str]] = {} +class FormDataSerializer(metaclass=abc.ABCMeta): + """Base class for form data serializers.""" - if isinstance(data, msgspec.Struct): + @abc.abstractmethod + def serialize(self, field_name: str, data: Any) -> dict[str, list[str]]: + """Serialize form data.""" + + +class StructFormDataSerializer(FormDataSerializer): + """Serializer for Struct form data.""" + + def serialize(self, field_name: str, data: msgspec.Struct) -> dict[str, list[str]]: + """Serialize Struct form data.""" + form: dict[str, list[str]] = {} for field in msgspec.structs.fields(data): val = getattr(data, field.name) if val is None: @@ -854,27 +864,51 @@ def serialize_form_data(field_name: str, data: Any) -> dict[str, Any]: if metadata.get("json"): form[field_name] = [marshal_json(val, field.type)] elif metadata.get("style", "form") == "form": - form = { - **form, - **_populate_form( + form.update( + _populate_form( field_name, metadata.get("explode", True), val, _get_form_field_name, ",", ), - } + ) else: msg = f"Invalid form style for field {field.name}" raise ValueError(msg) - elif isinstance(data, dict): - for key, value in data.items(): - form[key] = [_val_to_string(value)] - else: + return form + + +class DictFormDataSerializer(FormDataSerializer): + """Serializer for Dict form data.""" + + def serialize(self, _field_name: str, data: dict[str, Any]) -> dict[str, list[str]]: + """Serialize Dict form data.""" + return {key: [_val_to_string(value)] for key, value in data.items()} + + +class DefaultFormDataSerializer(FormDataSerializer): + """Serializer for default form data.""" + + def serialize(self, field_name: str, _data: Any) -> dict[str, list[str]]: + """Serialize default form data.""" msg = f"Invalid request body type for field {field_name}" raise TypeError(msg) - return form + +def get_form_data_serializer(data: Any) -> FormDataSerializer: + """Get the appropriate form data serializer.""" + if isinstance(data, msgspec.Struct): + return StructFormDataSerializer() + if isinstance(data, dict): + return DictFormDataSerializer() + return DefaultFormDataSerializer() + + +def serialize_form_data(field_name: str, data: Any) -> dict[str, list[str]]: + """Serialize form data.""" + serializer = get_form_data_serializer(data) + return serializer.serialize(field_name, data) def _get_form_field_name(obj_field: msgspec.structs.FieldInfo) -> str: From 867aea97f866f3567877416aa928bdd9c5fa520f Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 16:59:05 +0200 Subject: [PATCH 16/32] =?UTF-8?q?=F0=9F=92=9A=20ci:=20remove=20lses?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/pythonpackage.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 13ab061..2115f62 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -106,10 +106,8 @@ jobs: pattern: coverage-* - name: Combine coverage run: | - ls -al reports coverage combine reports/**/.coverage coverage xml - ls -al reports - name: Generate sonar properties run: | cat << EOF > sonar-project.properties From 2cc33d9ec2d6718b358d23882d5e66a3507245ad Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 17:05:51 +0200 Subject: [PATCH 17/32] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20=20refactor:=20popul?= =?UTF-8?q?ate=20form?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/coinapi/utils/utils.py | 146 +++++++++++++++++++++++++++++++------ 1 file changed, 122 insertions(+), 24 deletions(-) diff --git a/src/coinapi/utils/utils.py b/src/coinapi/utils/utils.py index 59409c0..967a9cd 100644 --- a/src/coinapi/utils/utils.py +++ b/src/coinapi/utils/utils.py @@ -613,10 +613,10 @@ def _get_delimited_query_params( """Get delimited query parameters.""" return _populate_form( field_name, - metadata.get("explode", True), obj, _get_query_param_field_name, delimiter, + explode=metadata.get("explode", True), ) @@ -867,10 +867,10 @@ def serialize(self, field_name: str, data: msgspec.Struct) -> dict[str, list[str form.update( _populate_form( field_name, - metadata.get("explode", True), val, _get_form_field_name, ",", + explode=metadata.get("explode", True), ), ) else: @@ -921,20 +921,36 @@ def _get_form_field_name(obj_field: msgspec.structs.FieldInfo) -> str: return cast(str, obj_param_metadata.get("field_name", obj_field.name)) -def _populate_form( # noqa: PLR0912, C901 - field_name: str, - explode: bool, # noqa: FBT001 - obj: Any, - get_field_name_func: Callable[..., str], - delimiter: str, -) -> dict[str, list[str]]: - """Populate a form.""" - params: dict[str, list[str]] = {} +class FormPopulator(metaclass=abc.ABCMeta): + """Abstract base class for form populators.""" - if obj is None: - return params + @abc.abstractmethod + def populate( + self, + field_name: str, + obj: Any, + get_field_name_func: Callable[..., str], + delimiter: str, + *, + explode: bool, + ) -> dict[str, list[str]]: + """Populate form data.""" - if isinstance(obj, msgspec.Struct): + +class StructFormPopulator(FormPopulator): + """Populator for msgspec.Struct objects.""" + + def populate( + self, + field_name: str, + obj: Any, + get_field_name_func: Callable[..., str], + delimiter: str, + *, + explode: bool, + ) -> dict[str, list[str]]: + """Populate form data for Struct objects.""" + params: dict[str, list[str]] = {} items = [] obj_fields: tuple[msgspec.structs.FieldInfo, ...] = msgspec.structs.fields(obj) @@ -952,22 +968,57 @@ def _populate_form( # noqa: PLR0912, C901 else: items.append(f"{obj_field_name}{delimiter}{_val_to_string(val)}") - if len(items) > 0: + if items: params[field_name] = [delimiter.join(items)] - elif isinstance(obj, dict): + + return params + + +class DictFormPopulator(FormPopulator): + """Populator for dictionary objects.""" + + def populate( + self, + field_name: str, + obj: dict[str, Any], + _get_field_name_func: Callable[..., str], + delimiter: str, + *, + explode: bool, + ) -> dict[str, list[str]]: + """Populate form data for dictionary objects.""" + params: dict[str, list[str]] = {} items = [] + for key, value in obj.items(): if value is None: continue if explode: - params[key] = _val_to_string(value) # type: ignore[assignment] + params[key] = [_val_to_string(value)] else: items.append(f"{key}{delimiter}{_val_to_string(value)}") - if len(items) > 0: + if items: params[field_name] = [delimiter.join(items)] - elif isinstance(obj, list): + + return params + + +class ListFormPopulator(FormPopulator): + """Populator for list objects.""" + + def populate( + self, + field_name: str, + obj: list[Any], + _get_field_name_func: Callable[..., str], + delimiter: str, + *, + explode: bool, + ) -> dict[str, list[str]]: + """Populate form data for list objects.""" + params: dict[str, list[str]] = {} items = [] for value in obj: @@ -981,12 +1032,59 @@ def _populate_form( # noqa: PLR0912, C901 else: items.append(_val_to_string(value)) - if len(items) > 0: - params[field_name] = [delimiter.join([str(item) for item in items])] - else: - params[field_name] = [_val_to_string(obj)] + if items: + params[field_name] = [delimiter.join(items)] - return params + return params + + +class DefaultFormPopulator(FormPopulator): + """Default populator for other object types.""" + + def populate( + self, + field_name: str, + obj: Any, + _get_field_name_func: Callable[..., str], + _delimiter: str, + *, + explode: bool, # noqa: ARG002 + ) -> dict[str, list[str]]: + """Populate form data for default object types.""" + return {field_name: [_val_to_string(obj)]} + + +def get_form_populator(obj: Any) -> FormPopulator: + """Get the appropriate form populator based on object type.""" + if isinstance(obj, msgspec.Struct): + return StructFormPopulator() + if isinstance(obj, dict): + return DictFormPopulator() + if isinstance(obj, list): + return ListFormPopulator() + return DefaultFormPopulator() + + +def _populate_form( + field_name: str, + obj: Any, + get_field_name_func: Callable[..., str], + delimiter: str, + *, + explode: bool, +) -> dict[str, list[str]]: + """Populate a form using the appropriate populator.""" + if obj is None: + return {} + + populator = get_form_populator(obj) + return populator.populate( + field_name, + obj, + get_field_name_func, + delimiter, + explode=explode, + ) def _serialize_header(explode: bool, obj: Any) -> str: # noqa: PLR0912, C901, FBT001 From 49012a2ef7fb16a473b4c86872260ded53805cef Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 17:19:42 +0200 Subject: [PATCH 18/32] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20=20refactor:=20seria?= =?UTF-8?q?lize=20header?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/coinapi/utils/utils.py | 105 +++++++++++++++++++++---------------- 1 file changed, 61 insertions(+), 44 deletions(-) diff --git a/src/coinapi/utils/utils.py b/src/coinapi/utils/utils.py index 967a9cd..b74ee9e 100644 --- a/src/coinapi/utils/utils.py +++ b/src/coinapi/utils/utils.py @@ -21,6 +21,7 @@ import httpx import msgspec +from mypy_extensions import NamedArg from typing_inspect import is_optional_type NOT_SUPPORTED = "not supported" @@ -495,8 +496,8 @@ def get_headers(headers_params: msgspec.Struct | None) -> dict[str, str]: continue value = _serialize_header( - metadata.get("explode", False), getattr(headers_params, field.name), + explode=metadata.get("explode", False), ) if value != "": @@ -1087,66 +1088,82 @@ def _populate_form( ) -def _serialize_header(explode: bool, obj: Any) -> str: # noqa: PLR0912, C901, FBT001 +def _serialize_header(obj: Any, *, explode: bool) -> str: """Serialize a header.""" if obj is None: return "" + serializer = _get_header_serializer(obj) + return serializer(obj, explode=explode) + + +def _get_header_serializer(obj: Any) -> Callable[[Any, NamedArg(bool, "explode")], str]: + """Get the appropriate header serializer based on object type.""" if isinstance(obj, msgspec.Struct): - items = [] - obj_fields: tuple[msgspec.structs.FieldInfo, ...] = msgspec.structs.fields(obj) - for obj_field in obj_fields: - obj_param_metadata = get_metadata(obj_field).get("header") + return _serialize_struct_header + if isinstance(obj, dict): + return _serialize_dict_header + if isinstance(obj, list): + return _serialize_list_header + return _serialize_simple_header - if not obj_param_metadata: - continue - obj_field_name = obj_param_metadata.get("field_name", obj_field.name) - if obj_field_name == "": - continue +def _serialize_struct_header(obj: msgspec.Struct, *, explode: bool) -> str: + """Serialize a msgspec.Struct header.""" + items = _get_struct_items(obj, explode=explode) + return ",".join(items) if items else "" - val = getattr(obj, obj_field.name) - if val is None: - continue - if explode: - items.append(f"{obj_field_name}={_val_to_string(val)}") - else: - items.append(obj_field_name) - items.append(_val_to_string(val)) +def _get_struct_items(obj: msgspec.Struct, *, explode: bool) -> list[str]: + """Get serialized items from a msgspec.Struct.""" + items = [] + for obj_field in msgspec.structs.fields(obj): + obj_param_metadata = get_metadata(obj_field).get("header") + if not obj_param_metadata: + continue - if len(items) > 0: - return ",".join(items) - elif isinstance(obj, dict): - items = [] + obj_field_name = obj_param_metadata.get("field_name", obj_field.name) + if obj_field_name == "": + continue - for key, value in obj.items(): - if value is None: - continue + val = getattr(obj, obj_field.name) + if val is None: + continue - if explode: - items.append(f"{key}={_val_to_string(value)}") - else: - items.append(key) - items.append(_val_to_string(value)) + items.extend(_format_item(obj_field_name, val, explode=explode)) + return items - if len(items) > 0: - return ",".join([str(item) for item in items]) - elif isinstance(obj, list): - items = [] - for value in obj: - if value is None: - continue +def _serialize_dict_header(obj: dict[str, Any], *, explode: bool) -> str: + """Serialize a dictionary header.""" + items = _get_dict_items(obj, explode=explode) + return ",".join(str(item) for item in items) if items else "" - items.append(_val_to_string(value)) - if len(items) > 0: - return ",".join(items) - else: - return f"{_val_to_string(obj)}" +def _get_dict_items(obj: dict[str, Any], *, explode: bool) -> list[str]: + """Get serialized items from a dictionary.""" + items = [] + for key, value in obj.items(): + if value is not None: + items.extend(_format_item(key, value, explode=explode)) + return items + + +def _serialize_list_header(obj: list[Any], *, explode: bool) -> str: # noqa: ARG001 + """Serialize a list header.""" + return ",".join(_val_to_string(value) for value in obj if value is not None) + + +def _serialize_simple_header(obj: Any, *, explode: bool) -> str: # noqa: ARG001 + """Serialize a simple header.""" + return _val_to_string(obj) + - return "" +def _format_item(key: str, value: Any, *, explode: bool) -> list[str]: + """Format a key-value pair for header serialization.""" + if explode: + return [f"{key}={_val_to_string(value)}"] + return [key, _val_to_string(value)] def marshal_json( From 78f02b7d0022178c55ccde52dccf822ea9704c8f Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 17:21:41 +0200 Subject: [PATCH 19/32] =?UTF-8?q?=F0=9F=90=9B=20fix:=20rename=20unused=20p?= =?UTF-8?q?aram?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/coinapi/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/coinapi/utils/utils.py b/src/coinapi/utils/utils.py index b74ee9e..11d32b3 100644 --- a/src/coinapi/utils/utils.py +++ b/src/coinapi/utils/utils.py @@ -848,7 +848,7 @@ def serialize(self, field_name: str, data: Any) -> dict[str, list[str]]: class StructFormDataSerializer(FormDataSerializer): """Serializer for Struct form data.""" - def serialize(self, field_name: str, data: msgspec.Struct) -> dict[str, list[str]]: + def serialize(self, _field_name: str, data: msgspec.Struct) -> dict[str, list[str]]: """Serialize Struct form data.""" form: dict[str, list[str]] = {} for field in msgspec.structs.fields(data): From 5b4d2d325687fb1ea3d695ee8e29d52dcc356a93 Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 17:26:53 +0200 Subject: [PATCH 20/32] =?UTF-8?q?=F0=9F=92=9A=20ci:=20use=20only=20latest?= =?UTF-8?q?=20coverage?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/pythonpackage.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 2115f62..311fda2 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -102,11 +102,11 @@ jobs: - name: Download artifacts uses: actions/download-artifact@v4 with: + name: coverage-ubuntu-latest-3.12 path: reports - pattern: coverage-* - - name: Combine coverage + - name: Create coverage XML run: | - coverage combine reports/**/.coverage + ls -al reports coverage xml - name: Generate sonar properties run: | From cbf0d0e70316af111df2e98d9d784f1d0379cefe Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 17:57:26 +0200 Subject: [PATCH 21/32] =?UTF-8?q?=F0=9F=94=A7=20config(coverage):=20try=20?= =?UTF-8?q?another=20workaround?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bce4345..c24ff36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,15 +60,19 @@ branch = true command_line = "--module pytest" data_file = "reports/.coverage" source = ["src"] +# include = ["src/*"] +# omit = ["tests/*"] [tool.coverage.paths] -source = ["src/", "/home/runner/**/src", "D:\\**\\src"] +source = ["src/"] [tool.coverage.report] fail_under = 50 precision = 1 show_missing = true skip_covered = true +# include = ["src/*"] +# omit = ["tests/*"] [tool.coverage.xml] output = "reports/coverage.xml" From 4478c89aa7e99091644c28f15831b16e1e327085 Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 18:03:38 +0200 Subject: [PATCH 22/32] =?UTF-8?q?=F0=9F=94=A7=20config(coverage):=20add=20?= =?UTF-8?q?workaround=20for=20source=20path=20issue?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit for the issue see [1] for the workaround see [2] [1]: https://github.com/nedbat/coveragepy/issues/578 [2]: https://github.com/LibraryOfCongress/concordia/pull/857/files --- pyproject.toml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c24ff36..04a9e48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,9 +59,8 @@ warn_unreachable = true branch = true command_line = "--module pytest" data_file = "reports/.coverage" -source = ["src"] -# include = ["src/*"] -# omit = ["tests/*"] +include = ["src/*"] +omit = ["tests/*"] [tool.coverage.paths] source = ["src/"] @@ -71,8 +70,8 @@ fail_under = 50 precision = 1 show_missing = true skip_covered = true -# include = ["src/*"] -# omit = ["tests/*"] +include = ["src/*"] +omit = ["tests/*"] [tool.coverage.xml] output = "reports/coverage.xml" From e5098bfbe14facd61640404a017ca883378cde5a Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 18:07:09 +0200 Subject: [PATCH 23/32] =?UTF-8?q?=F0=9F=92=9A=20ci:=20update=20triggers=20?= =?UTF-8?q?to=20run=20on=20PRs=20too?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/pythonpackage.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 311fda2..61b46f0 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -2,6 +2,13 @@ name: Python package on: push: + branches: + - "main" + pull_request: + branches: + - "**" + types: [opened, synchronize, reopened] + create: branches: - "**" From 809d71f453386a125294e86c983c32a367204c3e Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 19:21:49 +0200 Subject: [PATCH 24/32] =?UTF-8?q?=E2=9C=85=20test:=20add=20utils=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_utils.py | 577 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 577 insertions(+) create mode 100644 tests/test_utils.py diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..11ff062 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,577 @@ +"""Tests for the utils module.""" + +import base64 +import datetime as dt +import enum +from decimal import Decimal +from typing import Annotated, Any, Union + +import httpx +import msgspec + +from coinapi.utils import utils + + +class TestSecurityClient: + """Tests for SecurityClient.""" + + def test_init(self) -> None: + """Test initialization of SecurityClient.""" + client = utils.SecurityClient() + assert client.client is None + assert client.timeout == 60 + assert isinstance(client.limits, httpx.Limits) + + def test_send_with_client(self) -> None: + """Test send method with a client.""" + mock_client = httpx.Client() + client = utils.SecurityClient(client=mock_client) + request = httpx.Request("GET", "https://example.com") + response = client.send(request) + assert isinstance(response, httpx.Response) + + def test_send_without_client(self) -> None: + """Test send method without a client.""" + client = utils.SecurityClient() + request = httpx.Request("GET", "https://example.com") + response = client.send(request) + assert isinstance(response, httpx.Response) + + +class TestConfigureSecurityClient: + """Tests for configure_security_client.""" + + def test_configure_security_client_no_security(self) -> None: + """Test configuring security client with no security.""" + client = utils.configure_security_client(None, None) + assert isinstance(client, utils.SecurityClient) + assert client.headers == {} + assert client.query_params == {} + + def test_configure_security_client_with_api_key(self) -> None: + """Test configuring security client with API key.""" + + class Security(msgspec.Struct): + api_key: Annotated[ + str, + msgspec.Meta( + extra={ + "security": { + "scheme": True, + "type": "apiKey", + "sub_type": "header", + "field_name": "X-API-Key", + }, + }, + ), + ] + + security = Security(api_key="test_key") + client = utils.configure_security_client(None, security) + assert client.headers == {"X-API-Key": "test_key"} + + def test_configure_security_client_with_basic_auth(self) -> None: + """Test configuring security client with basic auth.""" + + class Security(msgspec.Struct): + username: Annotated[ + str, + msgspec.Meta( + extra={ + "security": { + "scheme": True, + "type": "http", + "sub_type": "basic", + "field_name": "username", + }, + }, + ), + ] + password: Annotated[ + str, + msgspec.Meta( + extra={ + "security": { + "scheme": True, + "type": "http", + "sub_type": "basic", + "field_name": "password", + }, + }, + ), + ] + + security = Security(username="user", password="pass") # noqa: S106 + client = utils.configure_security_client(None, security) + expected_auth = base64.b64encode(b"user:pass").decode() + assert client.headers == {"Authorization": f"Basic {expected_auth}"} + + +class TestPathParamHandlers: + """Tests for path param handlers.""" + + def test_list_param_handler(self) -> None: + """Test ListParamHandler.""" + handler = utils.ListParamHandler() + result = handler.handle([1, "two", 3.0], {}) + assert result == "1,two,3.0" + + def test_dict_param_handler(self) -> None: + """Test DictParamHandler.""" + handler = utils.DictParamHandler() + result = handler.handle({"a": 1, "b": "two"}, {"explode": True}) + assert result == "a=1,b=two" + + def test_struct_param_handler(self) -> None: + """Test StructParamHandler.""" + + class TestStruct(msgspec.Struct): + a: Annotated[ + int, + msgspec.Meta(extra={"path_param": {"field_name": "a"}}), + ] + b: Annotated[ + str, + msgspec.Meta(extra={"path_param": {"field_name": "b"}}), + ] + + handler = utils.StructParamHandler() + result = handler.handle(TestStruct(a=1, b="two"), {"explode": True}) + assert result == "a=1,b=two" + + def test_default_param_handler(self) -> None: + """Test DefaultParamHandler.""" + handler = utils.DefaultParamHandler() + result = handler.handle(42, {}) + assert result == "42" + + +class TestHandleSinglePathParam: + """Tests for handle_single_path_param.""" + + def test_handle_single_path_param_list(self) -> None: + """Test handling single path param for list.""" + result = utils.handle_single_path_param([1, 2, 3], {}) + assert result == "1,2,3" + + def test_handle_single_path_param_dict(self) -> None: + """Test handling single path param for dict.""" + result = utils.handle_single_path_param({"a": 1, "b": 2}, {"explode": True}) + assert result == "a=1,b=2" + + def test_handle_single_path_param_struct(self) -> None: + """Test handling single path param for struct.""" + + class TestStruct(msgspec.Struct): + a: Annotated[ + int, + msgspec.Meta(extra={"path_param": {"field_name": "a"}}), + ] + b: Annotated[ + str, + msgspec.Meta(extra={"path_param": {"field_name": "b"}}), + ] + + result = utils.handle_single_path_param( + TestStruct(a=1, b="two"), + {"explode": True}, + ) + assert result == "a=1,b=two" + + def test_handle_single_path_param_default(self) -> None: + """Test handling single path param for default case.""" + result = utils.handle_single_path_param(42, {}) + assert result == "42" + + +class TestSerializeParam: + """Tests for serialize_param.""" + + def test_serialize_param_json(self) -> None: + """Test serializing param to JSON.""" + + class TestStruct(msgspec.Struct): + a: int + b: str + + result = utils.serialize_param( + TestStruct(a=1, b="two"), + {"serialization": "json", "field_name": "test"}, + TestStruct, + "test", + ) + assert result == {"test": '{"a":1,"b":"two"}'} + + def test_serialize_param_non_json(self) -> None: + """Test serializing param for non-JSON case.""" + result = utils.serialize_param(42, {}, int, "test") + assert result == {} + + +class TestReplaceUrlPlaceholder: + """Tests for replace_url_placeholder.""" + + def test_replace_url_placeholder(self) -> None: + """Test replacing URL placeholder.""" + path = "/api/{version}/users/{id}" + result = utils.replace_url_placeholder(path, "version", "v1") + assert result == "/api/v1/users/{id}" + + +class TestIsPathParam: + """Tests for is_path_param.""" + + def test_is_path_param_true(self) -> None: + """Test is_path_param for true case.""" + + class TestStruct(msgspec.Struct): + param: Annotated[str, msgspec.Meta(extra={"path_param": {}})] + + field = msgspec.structs.fields(TestStruct)[0] + assert utils.is_path_param(field) is True + + def test_is_path_param_false(self) -> None: + """Test is_path_param for false case.""" + + class TestStruct(msgspec.Struct): + param: str + + field = msgspec.structs.fields(TestStruct)[0] + assert utils.is_path_param(field) is False + + +class TestGetParamValue: + """Tests for get_param_value.""" + + def test_get_param_value_from_path_params(self) -> None: + """Test getting param value from path params.""" + + class PathParams(msgspec.Struct): + param: str + + field = msgspec.structs.fields(PathParams)[0] + path_params = PathParams(param="value") + result = utils.get_param_value(field, path_params, None) + assert result == "value" + + def test_get_param_value_from_globals(self) -> None: + """Test getting param value from globals.""" + + class PathParams(msgspec.Struct): + param: str + + field = msgspec.structs.fields(PathParams)[0] + gbls = {"parameters": {"pathParam": {"param": "global_value"}}} + result = utils.get_param_value(field, None, gbls) + assert result == "global_value" + + +class TestGenerateUrl: + """Tests for generate_url.""" + + def test_generate_url(self) -> None: + """Test generating URL.""" + + class PathParams(msgspec.Struct): + version: Annotated[ + str, + msgspec.Meta( + extra={"path_param": {"field_name": "version"}}, + ), + ] + id: Annotated[ + int, + msgspec.Meta(extra={"path_param": {"field_name": "id"}}), + ] + + path_params = PathParams(version="v1", id=123) + result = utils.generate_url( + PathParams, + "https://api.example.com", + "/api/{version}/users/{id}", + path_params, + ) + assert result == "https://api.example.com/api/v1/users/123" + + +class TestIsOptional: + """Tests for is_optional.""" + + def test_is_optional_true(self) -> None: + """Test is_optional for true case.""" + # TODO: this only works with `Union` which is not how optionals are + # used, fix + assert utils.is_optional(Union[str, None]) is True # type: ignore[arg-type] # noqa: UP007 + + def test_is_optional_false(self) -> None: + """Test is_optional for false case.""" + assert utils.is_optional(str) is False + + +class TestTemplateUrl: + """Tests for template_url.""" + + def test_template_url(self) -> None: + """Test templating URL.""" + url = "https://api.example.com/{version}/users/{id}" + params = {"version": "v1", "id": "123"} + result = utils.template_url(url, params) + assert result == "https://api.example.com/v1/users/123" + + def test_init(self) -> None: + """Test initialization of SecurityClient.""" + client = utils.SecurityClient() + assert client.client is None + assert client.timeout == 60 + assert isinstance(client.limits, httpx.Limits) + + def test_send_with_client(self) -> None: + """Test send method with a client.""" + mock_client = httpx.Client() + client = utils.SecurityClient(client=mock_client) + request = httpx.Request("GET", "https://example.com") + response = client.send(request) + assert isinstance(response, httpx.Response) + + def test_send_without_client(self) -> None: + """Test send method without a client.""" + client = utils.SecurityClient() + request = httpx.Request("GET", "https://example.com") + response = client.send(request) + assert isinstance(response, httpx.Response) + + +class TestProcessQueryParam: + """Tests for process_query_param.""" + + def test_process_query_param_form(self) -> None: + """Test processing query param with form style.""" + + class QueryParams(msgspec.Struct): + param: Annotated[ + list[str], + msgspec.Meta( + extra={"query_param": {"style": "form", "explode": True}}, + ), + ] + + field = msgspec.structs.fields(QueryParams)[0] + result = utils.process_query_param(field, ["a", "b", "c"], None) + assert result == {"param": ["a", "b", "c"]} + + def test_process_query_param_deep_object(self) -> None: + """Test processing query param with deepObject style.""" + + class QueryParams(msgspec.Struct): + param: Annotated[ + dict[str, int], + msgspec.Meta( + extra={"query_param": {"style": "deepObject", "explode": True}}, + ), + ] + + field = msgspec.structs.fields(QueryParams)[0] + result = utils.process_query_param(field, {"x": 1, "y": 2}, None) + assert result == {"param[x]": ["1"], "param[y]": ["2"]} + + +class TestGetQueryParams: + """Tests for get_query_params.""" + + def test_get_query_params(self) -> None: + """Test getting query parameters.""" + + class QueryParams(msgspec.Struct): + param1: Annotated[ + str, + msgspec.Meta(extra={"query_param": {"field_name": "p1"}}), + ] + param2: Annotated[ + int, + msgspec.Meta(extra={"query_param": {"field_name": "p2"}}), + ] + + query_params = QueryParams(param1="value1", param2=42) + result = utils.get_query_params(QueryParams, query_params) + assert result == {"p1": ["value1"], "p2": ["42"]} + + +class TestGetHeaders: + """Tests for get_headers.""" + + def test_get_headers(self) -> None: + """Test getting headers.""" + + class Headers(msgspec.Struct): + header1: Annotated[ + str, + msgspec.Meta(extra={"header": {"field_name": "X-Header-1"}}), + ] + header2: Annotated[ + int, + msgspec.Meta(extra={"header": {"field_name": "X-Header-2"}}), + ] + + headers = Headers(header1="value1", header2=42) + result = utils.get_headers(headers) + assert result == {"X-Header-1": "value1", "X-Header-2": "42"} + + +class TestSerializeRequestBody: + """Tests for serialize_request_body.""" + + def test_serialize_request_body_json(self) -> None: + """Test serializing request body to JSON.""" + + class RequestBody(msgspec.Struct): + data: Annotated[ + dict[str, Any], + msgspec.Meta( + extra={"request": {"media_type": "application/json"}}, + ), + ] + + request = RequestBody(data={"key": "value"}) + result = utils.serialize_request_body(request, "data") + assert result == ("application/json", '{"key":"value"}', None) + + +class TestSerializeContentType: + """Tests for serialize_content_type.""" + + def test_serialize_content_type_json(self) -> None: + """Test serializing content type to JSON.""" + result = utils.serialize_content_type( + "data", + dict, + "application/json", + {"key": "value"}, # type: ignore[arg-type] + ) + assert result == ("application/json", '{"key":"value"}', None) + + def test_serialize_content_type_form(self) -> None: + """Test serializing content type to form data.""" + result = utils.serialize_content_type( + "data", + dict, + "application/x-www-form-urlencoded", + {"key": "value"}, # type: ignore[arg-type] + ) + assert result == ("application/x-www-form-urlencoded", {"key": ["value"]}, None) + + +class TestSerializeFormData: + """Tests for serialize_form_data.""" + + def test_serialize_form_data_struct(self) -> None: + """Test serializing form data for Struct.""" + + class FormData(msgspec.Struct): + field1: Annotated[ + str, + msgspec.Meta(extra={"form": {"field_name": "f1"}}), + ] + field2: Annotated[ + int, + msgspec.Meta(extra={"form": {"field_name": "f2"}}), + ] + + form_data = FormData(field1="value1", field2=42) + result = utils.serialize_form_data("data", form_data) + assert result == {"f1": ["value1"], "f2": ["42"]} + + def test_serialize_form_data_dict(self) -> None: + """Test serializing form data for dict.""" + result = utils.serialize_form_data("data", {"key1": "value1", "key2": 42}) + assert result == {"key1": ["value1"], "key2": ["42"]} + + +class TestMarshalJson: + """Tests for marshal_json.""" + + def test_marshal_json(self) -> None: + """Test marshalling JSON.""" + + class TestStruct(msgspec.Struct): + field1: str + field2: int + + data = TestStruct(field1="value", field2=42) + result = utils.marshal_json(data, TestStruct) + assert result == '{"field1":"value","field2":42}' + + +class TestMatchContentType: + """Tests for match_content_type.""" + + def test_match_content_type_exact(self) -> None: + """Test matching content type exactly.""" + assert utils.match_content_type("application/json", "application/json") is True + + def test_match_content_type_wildcard(self) -> None: + """Test matching content type with wildcard.""" + assert utils.match_content_type("application/json", "*/*") is True + + def test_match_content_type_partial(self) -> None: + """Test matching content type partially.""" + assert utils.match_content_type("application/json", "application/*") is True + + +class TestMatchStatusCodes: + """Tests for match_status_codes.""" + + def test_match_status_codes_exact(self) -> None: + """Test matching status codes exactly.""" + assert utils.match_status_codes(["200", "201"], 200) is True + + def test_match_status_codes_range(self) -> None: + """Test matching status codes with range.""" + assert utils.match_status_codes(["2XX"], 201) is True + + def test_match_status_codes_no_match(self) -> None: + """Test matching status codes with no match.""" + assert utils.match_status_codes(["200", "201"], 404) is False + + +class TestValToString: + """Tests for _val_to_string.""" + + def test_val_to_string_bool(self) -> None: + """Test converting bool to string.""" + assert utils._val_to_string(True) == "true" # noqa: FBT003 + assert utils._val_to_string(False) == "false" # noqa: FBT003 + + def test_val_to_string_datetime(self) -> None: + """Test converting datetime to string.""" + dt_val = dt.datetime(2023, 1, 1, 12, 0, 0, tzinfo=dt.timezone.utc) + assert utils._val_to_string(dt_val) == "2023-01-01T12:00:00Z" + + def test_val_to_string_enum(self) -> None: + """Test converting Enum to string.""" + + class TestEnum(enum.Enum): + VALUE = "test_value" + + assert utils._val_to_string(TestEnum.VALUE) == "test_value" + + def test_val_to_string_other(self) -> None: + """Test converting other types to string.""" + assert utils._val_to_string(42) == "42" + assert utils._val_to_string("test") == "test" + assert utils._val_to_string(Decimal("3.14")) == "3.14" + + +class TestRemoveSuffix: + """Tests for remove_suffix.""" + + def test_remove_suffix_present(self) -> None: + """Test removing suffix when present.""" + assert utils.remove_suffix("test_string_suffix", "_suffix") == "test_string" + + def test_remove_suffix_not_present(self) -> None: + """Test removing suffix when not present.""" + assert utils.remove_suffix("test_string", "_suffix") == "test_string" + + def test_remove_suffix_empty(self) -> None: + """Test removing empty suffix.""" + assert utils.remove_suffix("test_string", "") == "test_string" From ba01cf6edfcb4de418f6f6671840aa464560fd9a Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 20:21:21 +0200 Subject: [PATCH 25/32] =?UTF-8?q?=F0=9F=90=9B=20fix:=20remove=20class=20va?= =?UTF-8?q?rs=20on=20security=20client?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/coinapi/utils/utils.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/coinapi/utils/utils.py b/src/coinapi/utils/utils.py index 11d32b3..a3fadf4 100644 --- a/src/coinapi/utils/utils.py +++ b/src/coinapi/utils/utils.py @@ -10,7 +10,6 @@ from enum import Enum from typing import ( Any, - ClassVar, Protocol, Union, cast, @@ -30,12 +29,16 @@ class SecurityClient: """Client with security settings.""" - client: httpx.Client | None - query_params: ClassVar[dict[str, str]] = {} - headers: ClassVar[dict[str, str]] = {} - - def __init__(self, client: httpx.Client | None = None, timeout: int = 60) -> None: + def __init__( + self, + client: httpx.Client | None = None, + query_params: dict[str, str] | None = None, + headers: dict[str, str] | None = None, + timeout: int = 60, + ) -> None: self.client = client + self.query_params = query_params or {} + self.headers = headers or {} self.timeout = timeout self.limits = httpx.Limits(max_keepalive_connections=1, max_connections=1) From 87f1467fdee2d2a0696dae64762367027dcab600 Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 20:21:36 +0200 Subject: [PATCH 26/32] =?UTF-8?q?=F0=9F=92=9A=20ci:=20remove=20always=20fr?= =?UTF-8?q?om=20sonarqube?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/pythonpackage.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 61b46f0..8674143 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -93,7 +93,6 @@ jobs: sonarcloud: needs: [python-test] - if: ${{ always() }} name: SonarCloud runs-on: ubuntu-latest steps: From 9af9d805a96fdb9e0583d61152f29cf2343dd77b Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 23:37:01 +0200 Subject: [PATCH 27/32] =?UTF-8?q?=F0=9F=90=9B=20fix:=20handle=20annotated?= =?UTF-8?q?=20file=20fields?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/coinapi/utils/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/coinapi/utils/utils.py b/src/coinapi/utils/utils.py index a3fadf4..e06e923 100644 --- a/src/coinapi/utils/utils.py +++ b/src/coinapi/utils/utils.py @@ -9,6 +9,7 @@ from email.message import Message from enum import Enum from typing import ( + Annotated, Any, Protocol, Union, @@ -645,6 +646,9 @@ def serialize_request_body( if field.name == request_field_name ).type + if get_origin(request_type) is Annotated: + request_type = get_args(request_type)[0] + if request_val is None and is_optional_type(request_type): return None, None, None From 46c15db3acbcde1713b96114328d99b2e8e13cd6 Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 23:37:15 +0200 Subject: [PATCH 28/32] =?UTF-8?q?=E2=9C=85=20test(utils):=20add=20more=20t?= =?UTF-8?q?ests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_utils.py | 190 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 190 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index 11ff062..b04d49c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -106,6 +106,28 @@ class Security(msgspec.Struct): expected_auth = base64.b64encode(b"user:pass").decode() assert client.headers == {"Authorization": f"Basic {expected_auth}"} + def test_configure_security_client_with_bearer_token(self) -> None: + """Test configuring security client with bearer token.""" + + class Security(msgspec.Struct): + bearer_token: Annotated[ + str, + msgspec.Meta( + extra={ + "security": { + "scheme": True, + "type": "http", + "sub_type": "bearer", + "field_name": "Authorization", + }, + }, + ), + ] + + security = Security(bearer_token="test_token") # noqa: S106 + client = utils.configure_security_client(None, security) + assert client.headers == {"Authorization": "Bearer test_token"} + class TestPathParamHandlers: """Tests for path param handlers.""" @@ -374,6 +396,19 @@ class QueryParams(msgspec.Struct): result = utils.process_query_param(field, {"x": 1, "y": 2}, None) assert result == {"param[x]": ["1"], "param[y]": ["2"]} + def test_process_query_param_serialization(self) -> None: + """Test processing query param with serialization.""" + + class QueryParams(msgspec.Struct): + param: Annotated[ + dict[str, Any], + msgspec.Meta(extra={"query_param": {"serialization": "json"}}), + ] + + field = msgspec.structs.fields(QueryParams)[0] + result = utils.process_query_param(field, {"key": "value"}, None) + assert result == {"param": ['{"key":"value"}']} + class TestGetQueryParams: """Tests for get_query_params.""" @@ -435,6 +470,19 @@ class RequestBody(msgspec.Struct): result = utils.serialize_request_body(request, "data") assert result == ("application/json", '{"key":"value"}', None) + def test_serialize_request_body_optional(self) -> None: + """Test serializing request body with optional field.""" + + class RequestBody(msgspec.Struct): + data: Annotated[ + dict[str, Any] | None, + msgspec.Meta(extra={"request": {"media_type": "application/json"}}), + ] + + request = RequestBody(data=None) + result = utils.serialize_request_body(request, "data") + assert result == (None, None, None) + class TestSerializeContentType: """Tests for serialize_content_type.""" @@ -459,6 +507,37 @@ def test_serialize_content_type_form(self) -> None: ) assert result == ("application/x-www-form-urlencoded", {"key": ["value"]}, None) + def test_serialize_content_type_multipart(self) -> None: + """Test serializing content type to multipart form data.""" + + class MultipartData(msgspec.Struct): + file: Annotated[ + File, + msgspec.Meta( + extra={ + "multipart_form": { + "file": True, + "content": True, + "field_name": "file", + }, + }, + ), + ] + + data = MultipartData(file=File(filename="test.txt", content=b"content")) + result = utils.serialize_content_type( + "data", + MultipartData, + "multipart/form-data", + data, + ) + assert result[0] == "multipart/form-data" + assert isinstance(result[2], list) + assert len(result[2]) == 1 + assert result[2][0][0] == "file" + assert result[2][0][1][0] == "test.txt" + assert result[2][0][1][1] == b"content" + class TestSerializeFormData: """Tests for serialize_form_data.""" @@ -575,3 +654,114 @@ def test_remove_suffix_not_present(self) -> None: def test_remove_suffix_empty(self) -> None: """Test removing empty suffix.""" assert utils.remove_suffix("test_string", "") == "test_string" + + +class TestSecurityClientSend: + """Tests for SecurityClient.send.""" + + def test_send_with_query_params_and_headers(self) -> None: + """Test sending request with query parameters and headers.""" + client = utils.SecurityClient( + query_params={"key": "value"}, + headers={"X-Test": "test"}, + ) + request = httpx.Request("GET", "https://example.com") + response = client.send(request) + assert "key=value" in str(response.request.url) + assert response.request.headers["X-Test"] == "test" + + +class TestGetQueryParamHandler: + """Tests for get_query_param_handler.""" + + def test_get_query_param_handler_form(self) -> None: + """Test getting query param handler for form.""" + handler = utils.get_query_param_handler("form") + assert isinstance(handler, utils.FormQueryParamHandler) + + def test_get_query_param_handler_deep_object(self) -> None: + """Test getting query param handler for deep object.""" + handler = utils.get_query_param_handler("deepObject") + assert isinstance(handler, utils.DeepObjectQueryParamHandler) + + def test_get_query_param_handler_pipe_delimited(self) -> None: + """Test getting query param handler for pipe delimited.""" + handler = utils.get_query_param_handler("pipeDelimited") + assert isinstance(handler, utils.PipeDelimitedQueryParamHandler) + + +class File(msgspec.Struct): + """File struct for testing serialize_multipart_form.""" + + filename: Annotated[ + str, + msgspec.Meta(extra={"multipart_form": {"field_name": "filename"}}), + ] + content: Annotated[ + bytes, + msgspec.Meta( + extra={ + "multipart_form": {"field_name": "content", "content": True}, + }, + ), + ] + + +class TestSerializeMultipartForm: + """Tests for serialize_multipart_form.""" + + def test_serialize_multipart_form(self) -> None: + """Test serializing multipart form data.""" + + class MultipartRequest(msgspec.Struct): + file: Annotated[ + File, + msgspec.Meta(extra={"multipart_form": {"file": True}}), + ] + + request = MultipartRequest( + file=File(filename="test.txt", content=b"file content"), + ) + + media_type, _, form = utils.serialize_multipart_form( + "multipart/form-data", + request, + ) + assert media_type == "multipart/form-data" + assert len(form) == 1 + assert form[0][0] == "file" + assert form[0][1][0] == "test.txt" + assert form[0][1][1] == b"file content" + + +class TestMultipartFormSerializer: + """Tests for MultipartFormSerializer.""" + + def test_serialize_file_field(self) -> None: + """Test serializing file field.""" + serializer = utils.MultipartFormSerializer() + field = utils.MultipartFormField( + name="file", + value=File(filename="test.txt", content=b"file content"), + metadata={"file": True}, + ) + result = serializer._get_serializer(field).serialize(field) + assert result == [["file", ["test.txt", b"file content"]]] + + def test_serialize_json_field(self) -> None: + """Test serializing JSON field.""" + serializer = utils.MultipartFormSerializer() + field = utils.MultipartFormField( + name="json_data", + value={"key": "value"}, + metadata={"json": True}, + ) + result = serializer._get_serializer(field).serialize(field) + assert result == [["json_data", [None, '{"key":"value"}', "application/json"]]] + + def test_serialize_regular_field(self) -> None: + """Test serializing regular field.""" + serializer = utils.MultipartFormSerializer() + field = utils.MultipartFormField(name="text", value="content", metadata={}) + result = serializer._get_serializer(field).serialize(field) + assert result == [["text", [None, "content"]]] From 194deb21256e8a2b9de1378ad913bcac48b8f1aa Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 23:57:30 +0200 Subject: [PATCH 29/32] =?UTF-8?q?=E2=9C=85=20test(base):=20add=20some=20te?= =?UTF-8?q?sts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_base.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 tests/test_base.py diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 0000000..02c7296 --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,53 @@ +"""Tests for the operations base.""" + +import httpx +import pytest + +from coinapi._hooks import SDKHooks +from coinapi.base import Base +from coinapi.config import CoinAPIConfig +from coinapi.models.errors import CoinAPIError + + +@pytest.fixture(name="config") +def config_fixture() -> CoinAPIConfig: + """Return a CoinAPIConfig instance.""" + config = CoinAPIConfig(None) + config._hooks = SDKHooks() + return config + + +def test_handle_request_error(config: CoinAPIConfig) -> None: + """Test handle request error.""" + base = Base(config) + hook_ctx = base._create_hook_context("test_operation") + + with pytest.raises(ValueError, match="Test error"): + base._handle_request_error(hook_ctx, ValueError("Test error")) + + +def test_handle_error_response(config: CoinAPIConfig) -> None: + """Test handle error response.""" + base = Base(config) + + response = httpx.Response(400, text="Bad Request") + with pytest.raises(CoinAPIError): + base._handle_error_response(response) + + +def test_set_response_content_text_plain(config: CoinAPIConfig) -> None: + """Test set response content text plain.""" + base = Base(config) + + class MockResponse: + content_plain = None + + res = MockResponse() + http_res = httpx.Response( + 200, + text="Test content", + headers={"Content-Type": "text/plain"}, + ) + base._set_response_content(res, http_res, "text/plain", type(res)) # type: ignore[type-var] + + assert res.content_plain == "Test content" From ba3bfc4c3f69ed72314952e550a5341436bc3a90 Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sat, 12 Oct 2024 23:57:39 +0200 Subject: [PATCH 30/32] =?UTF-8?q?=E2=9C=85=20test(utils):=20more=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_utils.py | 91 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index b04d49c..f8a878e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -37,6 +37,21 @@ def test_send_without_client(self) -> None: response = client.send(request) assert isinstance(response, httpx.Response) + def test_parse_security_scheme_value_api_key_header(self) -> None: + """Test parsing security scheme value for API key in header.""" + client = utils.SecurityClient() + scheme_metadata = {"type": "apiKey", "sub_type": "header"} + security_metadata = {"field_name": "X-API-Key"} + + utils._parse_security_scheme_value( + client, + scheme_metadata, + security_metadata, + "test_api_key", + ) + + assert client.headers["X-API-Key"] == "test_api_key" + class TestConfigureSecurityClient: """Tests for configure_security_client.""" @@ -409,6 +424,30 @@ class QueryParams(msgspec.Struct): result = utils.process_query_param(field, {"key": "value"}, None) assert result == {"param": ['{"key":"value"}']} + def test_populate_form_struct(self) -> None: + """Test populating form for Struct.""" + + class TestStruct(msgspec.Struct): + field1: Annotated[ + str, + msgspec.Meta(extra={"query_param": {"field_name": "f1"}}), + ] + field2: Annotated[ + int, + msgspec.Meta(extra={"query_param": {"field_name": "f2"}}), + ] + + obj = TestStruct(field1="test", field2=42) + result = utils._populate_form( + "test_field", + obj, + utils._get_query_param_field_name, + ",", + explode=True, + ) + + assert result == {"f1": ["test"], "f2": ["42"]} + class TestGetQueryParams: """Tests for get_query_params.""" @@ -451,6 +490,24 @@ class Headers(msgspec.Struct): result = utils.get_headers(headers) assert result == {"X-Header-1": "value1", "X-Header-2": "42"} + def test_serialize_header_struct(self) -> None: + """Test serializing header for Struct.""" + + class TestStruct(msgspec.Struct): + field1: Annotated[ + str, + msgspec.Meta(extra={"header": {"field_name": "X-Field-1"}}), + ] + field2: Annotated[ + int, + msgspec.Meta(extra={"header": {"field_name": "X-Field-2"}}), + ] + + obj = TestStruct(field1="test", field2=42) + result = utils._serialize_header(obj, explode=True) + + assert result == "X-Field-1=test,X-Field-2=42" + class TestSerializeRequestBody: """Tests for serialize_request_body.""" @@ -733,6 +790,40 @@ class MultipartRequest(msgspec.Struct): assert form[0][1][0] == "test.txt" assert form[0][1][1] == b"file content" + def test_serialize_multipart_form_with_multiple_fields(self) -> None: + """Test serializing multipart form data with multiple fields.""" + + class TestRequest(msgspec.Struct): + file_field: Annotated[ + File, + msgspec.Meta(extra={"multipart_form": {"file": True}}), + ] + json_field: Annotated[ + dict[str, Any], + msgspec.Meta(extra={"multipart_form": {"json": True}}), + ] + regular_field: Annotated[ + str, + msgspec.Meta(extra={"multipart_form": {"content": True}}), + ] + + request = TestRequest( + file_field=File(filename="test.txt", content=b"test content"), + json_field={"key": "value"}, + regular_field="test", + ) + + media_type, _, form = utils.serialize_multipart_form( + "multipart/form-data", + request, + ) + + assert media_type == "multipart/form-data" + assert len(form) == 3 + assert any(item[0] == "file_field" for item in form) + assert any(item[0] == "json_field" for item in form) + assert any(item[0] == "regular_field" for item in form) + class TestMultipartFormSerializer: """Tests for MultipartFormSerializer.""" From 708a7222e1c9a808bf798cb9cb4f3d251086592f Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sun, 13 Oct 2024 00:13:56 +0200 Subject: [PATCH 31/32] =?UTF-8?q?=E2=9C=85=20test(utils):=20more=20coverag?= =?UTF-8?q?e?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_utils.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index f8a878e..345fe65 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -80,6 +80,8 @@ class Security(msgspec.Struct): }, ), ] + field_with_no_meta: str = "test" + none_field: str | None = None security = Security(api_key="test_key") client = utils.configure_security_client(None, security) @@ -143,6 +145,41 @@ class Security(msgspec.Struct): client = utils.configure_security_client(None, security) assert client.headers == {"Authorization": "Bearer test_token"} + def test_configure_security_client_with_option(self) -> None: + """Test configuring security client with security option.""" + + class SecurityOption(msgspec.Struct): + api_key: Annotated[ + str, + msgspec.Meta( + extra={ + "security": { + "scheme": True, + "type": "apiKey", + "sub_type": "header", + "field_name": "X-API-Key", + }, + }, + ), + ] + other_field: str = "test" + + class Security(msgspec.Struct): + option: Annotated[ + SecurityOption, + msgspec.Meta( + extra={ + "security": { + "option": True, + }, + }, + ), + ] + + security = Security(option=SecurityOption(api_key="test_option_key")) + client = utils.configure_security_client(None, security) + assert client.headers == {"X-API-Key": "test_option_key"} + class TestPathParamHandlers: """Tests for path param handlers.""" From dd425ccf8eff51537213c329930d9214be11c113 Mon Sep 17 00:00:00 2001 From: ljnsn Date: Sun, 13 Oct 2024 00:22:21 +0200 Subject: [PATCH 32/32] =?UTF-8?q?=E2=9C=85=20test(utils):=20more=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_utils.py | 135 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index 345fe65..3bee500 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,6 +8,7 @@ import httpx import msgspec +import pytest from coinapi.utils import utils @@ -52,6 +53,77 @@ def test_parse_security_scheme_value_api_key_header(self) -> None: assert client.headers["X-API-Key"] == "test_api_key" + def test_parse_security_scheme_value_oauth2(self) -> None: + """Test parsing security scheme value for OAuth2.""" + client = utils.SecurityClient() + scheme_metadata = {"type": "oauth2", "sub_type": "other"} + security_metadata = {"field_name": "Authorization"} + utils._parse_security_scheme_value( + client, + scheme_metadata, + security_metadata, + "test_token", + ) + assert client.headers["Authorization"] == "Bearer test_token" + + def test_parse_security_scheme_value_open_id_connect(self) -> None: + """Test parsing security scheme value for OpenID Connect.""" + client = utils.SecurityClient() + scheme_metadata = {"type": "openIdConnect"} + security_metadata = {"field_name": "Authorization"} + utils._parse_security_scheme_value( + client, + scheme_metadata, + security_metadata, + "test_token", + ) + assert client.headers["Authorization"] == "Bearer test_token" + + def test_parse_security_scheme_value_api_key_query(self) -> None: + """Test parsing security scheme value for API key in query.""" + client = utils.SecurityClient() + scheme_metadata = {"type": "apiKey", "sub_type": "query"} + security_metadata = {"field_name": "api_key"} + utils._parse_security_scheme_value( + client, + scheme_metadata, + security_metadata, + "test_key", + ) + assert client.query_params["api_key"] == "test_key" + + def test_parse_security_scheme_value_unsupported(self) -> None: + """Test parsing security scheme value for unsupported type.""" + client = utils.SecurityClient() + scheme_metadata = {"type": "unsupported"} + security_metadata = {"field_name": "test"} + with pytest.raises(ValueError, match="not supported"): + utils._parse_security_scheme_value( + client, + scheme_metadata, + security_metadata, + "test", + ) + + +class TestUtilityFunctions: + """Tests for utility functions.""" + + def test_apply_bearer(self) -> None: + """Test applying bearer token.""" + assert utils._apply_bearer("test_token") == "Bearer test_token" + assert utils._apply_bearer("Bearer test_token") == "Bearer test_token" + + def test_get_metadata(self) -> None: + """Test getting metadata from a field.""" + + class TestStruct(msgspec.Struct): + field: Annotated[str, msgspec.Meta(extra={"test": "metadata"})] + + field_info = msgspec.structs.fields(TestStruct)[0] + metadata = utils.get_metadata(field_info) + assert metadata == {"test": "metadata"} + class TestConfigureSecurityClient: """Tests for configure_security_client.""" @@ -281,6 +353,16 @@ def test_serialize_param_non_json(self) -> None: result = utils.serialize_param(42, {}, int, "test") assert result == {} + def test_serialize_param_unsupported(self) -> None: + """Test serializing param with unsupported serialization.""" + result = utils.serialize_param( + "test", + {"serialization": "unsupported"}, + str, + "test", + ) + assert result == {} + class TestReplaceUrlPlaceholder: """Tests for replace_url_placeholder.""" @@ -339,6 +421,17 @@ class PathParams(msgspec.Struct): result = utils.get_param_value(field, None, gbls) assert result == "global_value" + def test_get_param_value_from_globals_none(self) -> None: + """Test getting param value from globals when it's None.""" + + class PathParams(msgspec.Struct): + param: str + + field = msgspec.structs.fields(PathParams)[0] + gbls = {"parameters": {"pathParam": {"param": None}}} + result = utils.get_param_value(field, None, gbls) + assert result is None + class TestGenerateUrl: """Tests for generate_url.""" @@ -485,6 +578,21 @@ class TestStruct(msgspec.Struct): assert result == {"f1": ["test"], "f2": ["42"]} + def test_process_query_param_pipe_delimited(self) -> None: + """Test processing query param with pipeDelimited style.""" + + class QueryParams(msgspec.Struct): + param: Annotated[ + list[str], + msgspec.Meta( + extra={"query_param": {"style": "pipeDelimited", "explode": False}}, + ), + ] + + field = msgspec.structs.fields(QueryParams)[0] + result = utils.process_query_param(field, ["a", "b", "c"], None) + assert result == {"param": ["a|b|c"]} + class TestGetQueryParams: """Tests for get_query_params.""" @@ -632,6 +740,16 @@ class MultipartData(msgspec.Struct): assert result[2][0][1][0] == "test.txt" assert result[2][0][1][1] == b"content" + def test_serialize_content_type_invalid(self) -> None: + """Test serializing content type with invalid media type.""" + with pytest.raises(ValueError, match="invalid request body type"): + utils.serialize_content_type( + "field", + dict, + "invalid/media-type", + {"key": "value"}, # type: ignore[arg-type] + ) + class TestSerializeFormData: """Tests for serialize_form_data.""" @@ -861,6 +979,23 @@ class TestRequest(msgspec.Struct): assert any(item[0] == "json_field" for item in form) assert any(item[0] == "regular_field" for item in form) + def test_serialize_multipart_form_invalid_file(self) -> None: + """Test serializing multipart form with invalid file.""" + + class InvalidFile(msgspec.Struct): + invalid: str + + class MultipartRequest(msgspec.Struct): + file: Annotated[ + InvalidFile, + msgspec.Meta(extra={"multipart_form": {"file": True}}), + ] + + request = MultipartRequest(file=InvalidFile(invalid="test")) + + with pytest.raises(ValueError, match="Invalid multipart/form-data file"): + utils.serialize_multipart_form("multipart/form-data", request) + class TestMultipartFormSerializer: """Tests for MultipartFormSerializer."""