Skip to content

Commit

Permalink
♻️ refactor: serialize multipart form
Browse files Browse the repository at this point in the history
  • Loading branch information
ljnsn committed Oct 12, 2024
1 parent 5a07802 commit 4f4fabd
Showing 1 changed file with 122 additions and 56 deletions.
178 changes: 122 additions & 56 deletions src/coinapi/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,62 +687,6 @@ def serialize_content_type(
raise ValueError(msg)


def serialize_multipart_form( # noqa: PLR0912, C901
media_type: str,
request: msgspec.Struct,
) -> tuple[str, Any, list[list[Any]]]:
"""Serialize a multipart form."""
form: list[list[Any]] = []
request_fields = msgspec.structs.fields(request)

for field in request_fields:
val = getattr(request, field.name)
if val is None:
continue

field_metadata = get_metadata(field).get("multipart_form")
if not field_metadata:
continue

if field_metadata.get("file") is True:
file_fields = msgspec.structs.fields(val)

file_name = ""
field_name = ""
content = b""

for file_field in file_fields:
file_metadata = get_metadata(file_field).get("multipart_form")
if file_metadata is None:
continue

if file_metadata.get("content") is True:
content = getattr(val, file_field.name)
else:
field_name = file_metadata.get("field_name", file_field.name)
file_name = getattr(val, file_field.name)
if field_name == "" or file_name == "" or content == b"":
raise ValueError("invalid multipart/form-data file")

form.append([field_name, [file_name, content]])
elif field_metadata.get("json") is True:
to_append = [
field_metadata.get("field_name", field.name),
[None, marshal_json(val, field.type), "application/json"],
]
form.append(to_append)
else:
field_name = field_metadata.get("field_name", field.name)
if isinstance(val, list):
for value in val:
if value is None:
continue
form.append([field_name + "[]", [None, _val_to_string(value)]])
else:
form.append([field_name, [None, _val_to_string(val)]])
return media_type, None, form


def serialize_dict(
original: dict[str, Any],
explode: bool, # noqa: FBT001
Expand All @@ -769,6 +713,128 @@ def serialize_dict(
return existing


class MultipartFormField:
"""Represents a field in a multipart form."""

def __init__(self, name: str, value: Any, metadata: dict[str, Any]) -> None:
self.name = name
self.value = value
self.metadata = metadata


class FieldSerializer(Protocol):
"""Protocol for field serializers."""

def serialize(self, field: MultipartFormField) -> list[Any]:
"""Serialize a field."""
...


class FileFieldSerializer:
"""Serializer for file fields."""

def serialize(self, field: MultipartFormField) -> list[Any]:
"""Serialize a file field."""
file_fields = msgspec.structs.fields(field.value)
file_name = ""
content = b""

for file_field in file_fields:
file_metadata = get_metadata(file_field).get("multipart_form")
if file_metadata is None:
continue

if file_metadata.get("content") is True:
content = getattr(field.value, file_field.name)
else:
file_name = getattr(field.value, file_field.name)

if not file_name or not content:
raise ValueError("Invalid multipart/form-data file")

return [[field.name, [file_name, content]]]


class JsonFieldSerializer:
"""Serializer for JSON fields."""

def serialize(self, field: MultipartFormField) -> list[Any]:
"""Serialize a JSON field."""
return [
[
field.metadata.get("field_name", field.name),
[
None,
marshal_json(field.value, type(field.value)),
"application/json",
],
],
]


class RegularFieldSerializer:
"""Serializer for regular fields."""

def serialize(self, field: MultipartFormField) -> list[Any]:
"""Serialize a regular field."""
field_name = field.metadata.get("field_name", field.name)
if isinstance(field.value, list):
return [
[f"{field_name}[]", [None, _val_to_string(value)]]
for value in field.value
if value is not None
]
return [[field_name, [None, _val_to_string(field.value)]]]


class MultipartFormSerializer:
"""Serializes a multipart form."""

def __init__(self) -> None:
self.serializers: dict[str, FieldSerializer] = {
"file": FileFieldSerializer(),
"json": JsonFieldSerializer(),
"regular": RegularFieldSerializer(),
}

def serialize(self, request: msgspec.Struct) -> tuple[str, Any, list[list[Any]]]:
"""Serialize the entire multipart form."""
form: list[list[Any]] = []
for field in self._get_fields(request):
serializer = self._get_serializer(field)
form.extend(serializer.serialize(field))
return "multipart/form-data", None, form

def _get_fields(self, request: msgspec.Struct) -> list[MultipartFormField]:
"""Extract fields from the request."""
fields = []
for field in msgspec.structs.fields(request):
value = getattr(request, field.name)
if value is None:
continue
metadata = get_metadata(field).get("multipart_form", {})
if metadata:
fields.append(MultipartFormField(field.name, value, metadata))
return fields

def _get_serializer(self, field: MultipartFormField) -> FieldSerializer:
"""Get the appropriate serializer for a field."""
if field.metadata.get("file") is True:
return self.serializers["file"]
if field.metadata.get("json") is True:
return self.serializers["json"]
return self.serializers["regular"]


def serialize_multipart_form(
_media_type: str,
request: msgspec.Struct,
) -> tuple[str, Any, list[list[Any]]]:
"""Serialize a multipart form."""
serializer = MultipartFormSerializer()
return serializer.serialize(request)


def serialize_form_data(field_name: str, data: Any) -> dict[str, Any]:
"""Serialize form data."""
form: dict[str, list[str]] = {}
Expand Down

0 comments on commit 4f4fabd

Please sign in to comment.