diff --git a/src/coinapi/utils/utils.py b/src/coinapi/utils/utils.py index 21bfe8e..4fcbe8f 100644 --- a/src/coinapi/utils/utils.py +++ b/src/coinapi/utils/utils.py @@ -687,62 +687,6 @@ def serialize_content_type( raise ValueError(msg) -def serialize_multipart_form( # noqa: PLR0912, C901 - media_type: str, - request: msgspec.Struct, -) -> tuple[str, Any, list[list[Any]]]: - """Serialize a multipart form.""" - form: list[list[Any]] = [] - request_fields = msgspec.structs.fields(request) - - for field in request_fields: - val = getattr(request, field.name) - if val is None: - continue - - field_metadata = get_metadata(field).get("multipart_form") - if not field_metadata: - continue - - if field_metadata.get("file") is True: - file_fields = msgspec.structs.fields(val) - - file_name = "" - field_name = "" - content = b"" - - for file_field in file_fields: - file_metadata = get_metadata(file_field).get("multipart_form") - if file_metadata is None: - continue - - if file_metadata.get("content") is True: - content = getattr(val, file_field.name) - else: - field_name = file_metadata.get("field_name", file_field.name) - file_name = getattr(val, file_field.name) - if field_name == "" or file_name == "" or content == b"": - raise ValueError("invalid multipart/form-data file") - - form.append([field_name, [file_name, content]]) - elif field_metadata.get("json") is True: - to_append = [ - field_metadata.get("field_name", field.name), - [None, marshal_json(val, field.type), "application/json"], - ] - form.append(to_append) - else: - field_name = field_metadata.get("field_name", field.name) - if isinstance(val, list): - for value in val: - if value is None: - continue - form.append([field_name + "[]", [None, _val_to_string(value)]]) - else: - form.append([field_name, [None, _val_to_string(val)]]) - return media_type, None, form - - def serialize_dict( original: dict[str, Any], explode: bool, # noqa: FBT001 @@ -769,6 +713,128 @@ def serialize_dict( return existing +class MultipartFormField: + """Represents a field in a multipart form.""" + + def __init__(self, name: str, value: Any, metadata: dict[str, Any]) -> None: + self.name = name + self.value = value + self.metadata = metadata + + +class FieldSerializer(Protocol): + """Protocol for field serializers.""" + + def serialize(self, field: MultipartFormField) -> list[Any]: + """Serialize a field.""" + ... + + +class FileFieldSerializer: + """Serializer for file fields.""" + + def serialize(self, field: MultipartFormField) -> list[Any]: + """Serialize a file field.""" + file_fields = msgspec.structs.fields(field.value) + file_name = "" + content = b"" + + for file_field in file_fields: + file_metadata = get_metadata(file_field).get("multipart_form") + if file_metadata is None: + continue + + if file_metadata.get("content") is True: + content = getattr(field.value, file_field.name) + else: + file_name = getattr(field.value, file_field.name) + + if not file_name or not content: + raise ValueError("Invalid multipart/form-data file") + + return [[field.name, [file_name, content]]] + + +class JsonFieldSerializer: + """Serializer for JSON fields.""" + + def serialize(self, field: MultipartFormField) -> list[Any]: + """Serialize a JSON field.""" + return [ + [ + field.metadata.get("field_name", field.name), + [ + None, + marshal_json(field.value, type(field.value)), + "application/json", + ], + ], + ] + + +class RegularFieldSerializer: + """Serializer for regular fields.""" + + def serialize(self, field: MultipartFormField) -> list[Any]: + """Serialize a regular field.""" + field_name = field.metadata.get("field_name", field.name) + if isinstance(field.value, list): + return [ + [f"{field_name}[]", [None, _val_to_string(value)]] + for value in field.value + if value is not None + ] + return [[field_name, [None, _val_to_string(field.value)]]] + + +class MultipartFormSerializer: + """Serializes a multipart form.""" + + def __init__(self) -> None: + self.serializers: dict[str, FieldSerializer] = { + "file": FileFieldSerializer(), + "json": JsonFieldSerializer(), + "regular": RegularFieldSerializer(), + } + + def serialize(self, request: msgspec.Struct) -> tuple[str, Any, list[list[Any]]]: + """Serialize the entire multipart form.""" + form: list[list[Any]] = [] + for field in self._get_fields(request): + serializer = self._get_serializer(field) + form.extend(serializer.serialize(field)) + return "multipart/form-data", None, form + + def _get_fields(self, request: msgspec.Struct) -> list[MultipartFormField]: + """Extract fields from the request.""" + fields = [] + for field in msgspec.structs.fields(request): + value = getattr(request, field.name) + if value is None: + continue + metadata = get_metadata(field).get("multipart_form", {}) + if metadata: + fields.append(MultipartFormField(field.name, value, metadata)) + return fields + + def _get_serializer(self, field: MultipartFormField) -> FieldSerializer: + """Get the appropriate serializer for a field.""" + if field.metadata.get("file") is True: + return self.serializers["file"] + if field.metadata.get("json") is True: + return self.serializers["json"] + return self.serializers["regular"] + + +def serialize_multipart_form( + _media_type: str, + request: msgspec.Struct, +) -> tuple[str, Any, list[list[Any]]]: + """Serialize a multipart form.""" + serializer = MultipartFormSerializer() + return serializer.serialize(request) + + def serialize_form_data(field_name: str, data: Any) -> dict[str, Any]: """Serialize form data.""" form: dict[str, list[str]] = {}