From 913bc833d5b2eb8c6374218d927d8a917dabd164 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez-Mondrag=C3=B3n?= Date: Tue, 16 Jul 2024 18:32:28 -0600 Subject: [PATCH] refactor: Implement msgspec encoding --- noxfile.py | 10 +- poetry.lock | 55 ++++++- pyproject.toml | 2 + singer_sdk/_singerlib/encoding/_msgspec.py | 172 +++++++++++++++++++++ tests/_singerlib/_encoding/test_msgspec.py | 41 +++++ tests/core/test_io.py | 20 ++- 6 files changed, 293 insertions(+), 7 deletions(-) create mode 100644 singer_sdk/_singerlib/encoding/_msgspec.py create mode 100644 tests/_singerlib/_encoding/test_msgspec.py diff --git a/noxfile.py b/noxfile.py index 0b47fa639..f1b5c7172 100644 --- a/noxfile.py +++ b/noxfile.py @@ -51,7 +51,7 @@ def mypy(session: Session) -> None: """Check types with mypy.""" args = session.posargs or ["singer_sdk"] - session.install(".[faker,jwt,parquet,s3,testing]") + session.install(".[faker,jwt,msgspec,parquet,s3,testing]") session.install(*typing_dependencies) session.run("mypy", *args) if not session.posargs: @@ -61,7 +61,7 @@ def mypy(session: Session) -> None: @session(python=python_versions) def tests(session: Session) -> None: """Execute pytest tests and compute coverage.""" - session.install(".[faker,jwt,parquet,s3]") + session.install(".[faker,jwt,msgspec,parquet,s3]") session.install(*test_dependencies) sqlalchemy_version = os.environ.get("SQLALCHEMY_VERSION") @@ -94,7 +94,7 @@ def tests(session: Session) -> None: @session(python=main_python_version) def benches(session: Session) -> None: """Run benchmarks.""" - session.install(".[jwt,s3]") + session.install(".[jwt,msgspec,s3]") session.install(*test_dependencies) sqlalchemy_version = os.environ.get("SQLALCHEMY_VERSION") if sqlalchemy_version: @@ -114,7 +114,7 @@ def benches(session: Session) -> None: @session(name="deps", python=python_versions) def dependencies(session: Session) -> None: """Check issues with dependencies.""" - session.install(".[s3,testing]") + session.install(".[msgspec,s3,testing]") session.install("deptry") session.run("deptry", "singer_sdk", *session.posargs) @@ -124,7 +124,7 @@ def update_snapshots(session: Session) -> None: """Update pytest snapshots.""" args = session.posargs or ["-m", "snapshot"] - session.install(".[faker,jwt,parquet]") + session.install(".[faker,jwt,msgspec,parquet]") session.install(*test_dependencies) session.run("pytest", "--snapshot-update", *args) diff --git a/poetry.lock b/poetry.lock index 4cea6f749..19b89a6a9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1116,6 +1116,58 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "msgspec" +version = "0.18.6" +description = "A fast serialization and validation library, with builtin support for JSON, MessagePack, YAML, and TOML." +optional = true +python-versions = ">=3.8" +files = [ + {file = "msgspec-0.18.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:77f30b0234eceeff0f651119b9821ce80949b4d667ad38f3bfed0d0ebf9d6d8f"}, + {file = "msgspec-0.18.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1a76b60e501b3932782a9da039bd1cd552b7d8dec54ce38332b87136c64852dd"}, + {file = "msgspec-0.18.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06acbd6edf175bee0e36295d6b0302c6de3aaf61246b46f9549ca0041a9d7177"}, + {file = "msgspec-0.18.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40a4df891676d9c28a67c2cc39947c33de516335680d1316a89e8f7218660410"}, + {file = "msgspec-0.18.6-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:a6896f4cd5b4b7d688018805520769a8446df911eb93b421c6c68155cdf9dd5a"}, + {file = "msgspec-0.18.6-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3ac4dd63fd5309dd42a8c8c36c1563531069152be7819518be0a9d03be9788e4"}, + {file = "msgspec-0.18.6-cp310-cp310-win_amd64.whl", hash = "sha256:fda4c357145cf0b760000c4ad597e19b53adf01382b711f281720a10a0fe72b7"}, + {file = "msgspec-0.18.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e77e56ffe2701e83a96e35770c6adb655ffc074d530018d1b584a8e635b4f36f"}, + {file = "msgspec-0.18.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d5351afb216b743df4b6b147691523697ff3a2fc5f3d54f771e91219f5c23aaa"}, + {file = "msgspec-0.18.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3232fabacef86fe8323cecbe99abbc5c02f7698e3f5f2e248e3480b66a3596b"}, + {file = "msgspec-0.18.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b524df6ea9998bbc99ea6ee4d0276a101bcc1aa8d14887bb823914d9f60d07"}, + {file = "msgspec-0.18.6-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:37f67c1d81272131895bb20d388dd8d341390acd0e192a55ab02d4d6468b434c"}, + {file = "msgspec-0.18.6-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d0feb7a03d971c1c0353de1a8fe30bb6579c2dc5ccf29b5f7c7ab01172010492"}, + {file = "msgspec-0.18.6-cp311-cp311-win_amd64.whl", hash = "sha256:41cf758d3f40428c235c0f27bc6f322d43063bc32da7b9643e3f805c21ed57b4"}, + {file = "msgspec-0.18.6-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d86f5071fe33e19500920333c11e2267a31942d18fed4d9de5bc2fbab267d28c"}, + {file = "msgspec-0.18.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce13981bfa06f5eb126a3a5a38b1976bddb49a36e4f46d8e6edecf33ccf11df1"}, + {file = "msgspec-0.18.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e97dec6932ad5e3ee1e3c14718638ba333befc45e0661caa57033cd4cc489466"}, + {file = "msgspec-0.18.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad237100393f637b297926cae1868b0d500f764ccd2f0623a380e2bcfb2809ca"}, + {file = "msgspec-0.18.6-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:db1d8626748fa5d29bbd15da58b2d73af25b10aa98abf85aab8028119188ed57"}, + {file = "msgspec-0.18.6-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:d70cb3d00d9f4de14d0b31d38dfe60c88ae16f3182988246a9861259c6722af6"}, + {file = "msgspec-0.18.6-cp312-cp312-win_amd64.whl", hash = "sha256:1003c20bfe9c6114cc16ea5db9c5466e49fae3d7f5e2e59cb70693190ad34da0"}, + {file = "msgspec-0.18.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f7d9faed6dfff654a9ca7d9b0068456517f63dbc3aa704a527f493b9200b210a"}, + {file = "msgspec-0.18.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9da21f804c1a1471f26d32b5d9bc0480450ea77fbb8d9db431463ab64aaac2cf"}, + {file = "msgspec-0.18.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46eb2f6b22b0e61c137e65795b97dc515860bf6ec761d8fb65fdb62aa094ba61"}, + {file = "msgspec-0.18.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8355b55c80ac3e04885d72db515817d9fbb0def3bab936bba104e99ad22cf46"}, + {file = "msgspec-0.18.6-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9080eb12b8f59e177bd1eb5c21e24dd2ba2fa88a1dbc9a98e05ad7779b54c681"}, + {file = "msgspec-0.18.6-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cc001cf39becf8d2dcd3f413a4797c55009b3a3cdbf78a8bf5a7ca8fdb76032c"}, + {file = "msgspec-0.18.6-cp38-cp38-win_amd64.whl", hash = "sha256:fac5834e14ac4da1fca373753e0c4ec9c8069d1fe5f534fa5208453b6065d5be"}, + {file = "msgspec-0.18.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:974d3520fcc6b824a6dedbdf2b411df31a73e6e7414301abac62e6b8d03791b4"}, + {file = "msgspec-0.18.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fd62e5818731a66aaa8e9b0a1e5543dc979a46278da01e85c3c9a1a4f047ef7e"}, + {file = "msgspec-0.18.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7481355a1adcf1f08dedd9311193c674ffb8bf7b79314b4314752b89a2cf7f1c"}, + {file = "msgspec-0.18.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6aa85198f8f154cf35d6f979998f6dadd3dc46a8a8c714632f53f5d65b315c07"}, + {file = "msgspec-0.18.6-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0e24539b25c85c8f0597274f11061c102ad6b0c56af053373ba4629772b407be"}, + {file = "msgspec-0.18.6-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c61ee4d3be03ea9cd089f7c8e36158786cd06e51fbb62529276452bbf2d52ece"}, + {file = "msgspec-0.18.6-cp39-cp39-win_amd64.whl", hash = "sha256:b5c390b0b0b7da879520d4ae26044d74aeee5144f83087eb7842ba59c02bc090"}, + {file = "msgspec-0.18.6.tar.gz", hash = "sha256:a59fc3b4fcdb972d09138cb516dbde600c99d07c38fd9372a6ef500d2d031b4e"}, +] + +[package.extras] +dev = ["attrs", "coverage", "furo", "gcovr", "ipython", "msgpack", "mypy", "pre-commit", "pyright", "pytest", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "tomli", "tomli-w"] +doc = ["furo", "ipython", "sphinx", "sphinx-copybutton", "sphinx-design"] +test = ["attrs", "msgpack", "mypy", "pyright", "pytest", "pyyaml", "tomli", "tomli-w"] +toml = ["tomli", "tomli-w"] +yaml = ["pyyaml"] + [[package]] name = "mypy" version = "1.10.1" @@ -2667,6 +2719,7 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", docs = ["furo", "myst-parser", "pytest", "sphinx", "sphinx-autobuild", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-notfound-page", "sphinx-reredirects"] faker = ["faker"] jwt = ["PyJWT", "cryptography"] +msgspec = ["msgspec"] parquet = ["numpy", "numpy", "pyarrow"] s3 = ["fs-s3fs"] testing = ["pytest", "pytest-durations"] @@ -2674,4 +2727,4 @@ testing = ["pytest", "pytest-durations"] [metadata] lock-version = "2.0" python-versions = ">=3.8" -content-hash = "82c4b9443a3fed513d597831da8a953b3b2989e4859895939a31db30959a19d9" +content-hash = "9e1b4fa3605c59d019e3343d695ced3b7e8da28342955279f2919b305f3faeba" diff --git a/pyproject.toml b/pyproject.toml index 990fb97f0..64f9087ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ inflection = ">=0.5.1" joblib = ">=1.3.0" jsonpath-ng = ">=1.5.3" jsonschema = ">=4.16.0" +msgspec = { version = ">=0.18.6", optional = true } packaging = ">=23.1" pendulum = ">=2.1.0,<4" python-dateutil = ">=2.8.2" @@ -111,6 +112,7 @@ docs = [ "sphinx-notfound-page", "sphinx-reredirects", ] +msgspec = ["msgspec"] s3 = ["fs-s3fs"] testing = [ "pytest", diff --git a/singer_sdk/_singerlib/encoding/_msgspec.py b/singer_sdk/_singerlib/encoding/_msgspec.py new file mode 100644 index 000000000..55a3c42d6 --- /dev/null +++ b/singer_sdk/_singerlib/encoding/_msgspec.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +import datetime +import decimal +import logging +import sys +import typing as t + +import msgspec + +from singer_sdk._singerlib.exceptions import InvalidInputLine + +from ._base import GenericSingerReader, GenericSingerWriter + +logger = logging.getLogger(__name__) + + +class Message(msgspec.Struct, tag_field="type", tag=str.upper): + """Singer base message.""" + + def to_dict(self): # noqa: ANN202 + return {f: getattr(self, f) for f in self.__struct_fields__} + + +class RecordMessage(Message, tag="RECORD"): + """Singer RECORD message.""" + + stream: str + record: t.Dict[str, t.Any] # noqa: UP006 + version: t.Union[int, None] = None # noqa: UP007 + time_extracted: t.Union[datetime.datetime, None] = None # noqa: UP007 + + def __post_init__(self) -> None: + """Post-init processing. + + Raises: + ValueError: If the time_extracted is not timezone-aware. + """ + if self.time_extracted and not self.time_extracted.tzinfo: + msg = ( + "'time_extracted' must be either None or an aware datetime (with a " + "time zone)" + ) + raise ValueError(msg) + + if self.time_extracted: + self.time_extracted = self.time_extracted.astimezone(datetime.timezone.utc) + + +class SchemaMessage(Message, tag="SCHEMA"): + """Singer SCHEMA message.""" + + stream: str + schema: t.Dict[str, t.Any] # noqa: UP006 + key_properties: t.List[str] # noqa: UP006 + bookmark_properties: t.Union[t.List[str], None] = None # noqa: UP006, UP007 + + def __post_init__(self) -> None: + """Post-init processing. + + Raises: + ValueError: If bookmark_properties is not a string or list of strings. + """ + if isinstance(self.bookmark_properties, (str, bytes)): + self.bookmark_properties = [self.bookmark_properties] + if self.bookmark_properties and not isinstance(self.bookmark_properties, list): + msg = "bookmark_properties must be a string or list of strings" + raise ValueError(msg) + + +class StateMessage(Message, tag="STATE"): + """Singer state message.""" + + value: t.Dict[str, t.Any] # noqa: UP006 + """The state value.""" + + +class ActivateVersionMessage(Message, tag="ACTIVATE_VERSION"): + """Singer activate version message.""" + + stream: str + """The stream name.""" + + version: int + """The version to activate.""" + + +def enc_hook(obj: t.Any) -> t.Any: # noqa: ANN401 + """Encoding type helper for non native types. + + Args: + obj: the item to be encoded + + Returns: + The object converted to the appropriate type, default is str + """ + return obj.isoformat(sep="T") if isinstance(obj, datetime.datetime) else str(obj) + + +def dec_hook(type: type, obj: t.Any) -> t.Any: # noqa: ARG001, A002, ANN401 + """Decoding type helper for non native types. + + Args: + type: the type given + obj: the item to be decoded + + Returns: + The object converted to the appropriate type, default is str. + """ + return str(obj) + + +encoder = msgspec.json.Encoder(enc_hook=enc_hook, decimal_format="number") +decoder = msgspec.json.Decoder( + t.Union[ + RecordMessage, + SchemaMessage, + StateMessage, + ActivateVersionMessage, + ], + dec_hook=dec_hook, + float_hook=decimal.Decimal, +) + + +class MsgSpecReader(GenericSingerReader[str]): + """Base class for all plugins reading Singer messages as strings from stdin.""" + + default_input = sys.stdin + + def deserialize_json(self, line: str) -> dict: # noqa: PLR6301 + """Deserialize a line of json. + + Args: + line: A single line of json. + + Returns: + A dictionary of the deserialized json. + + Raises: + InvalidInputLine: If the line cannot be parsed + """ + try: + return decoder.decode(line).to_dict() # type: ignore[no-any-return] + except msgspec.DecodeError as exc: + logger.exception("Unable to parse:\n%s", line) + msg = f"Unable to parse line as JSON: {line}" + raise InvalidInputLine(msg) from exc + + +class MsgSpecWriter(GenericSingerWriter[bytes, Message]): + """Interface for all plugins writing Singer messages to stdout.""" + + def serialize_message(self, message: Message) -> bytes: # noqa: PLR6301 + """Serialize a dictionary into a line of json. + + Args: + message: A Singer message object. + + Returns: + A string of serialized json. + """ + return encoder.encode(message) + + def write_message(self, message: Message) -> None: + """Write a message to stdout. + + Args: + message: The message to write. + """ + sys.stdout.buffer.write(self.format_message(message) + b"\n") + sys.stdout.flush() diff --git a/tests/_singerlib/_encoding/test_msgspec.py b/tests/_singerlib/_encoding/test_msgspec.py new file mode 100644 index 000000000..3fca4278c --- /dev/null +++ b/tests/_singerlib/_encoding/test_msgspec.py @@ -0,0 +1,41 @@ +from __future__ import annotations # noqa: INP001 + +import pytest + +from singer_sdk._singerlib.encoding._msgspec import dec_hook, enc_hook + + +@pytest.mark.parametrize( + "test_type,test_value,expected_value,expected_type", + [ + pytest.param( + int, + 1, + "1", + str, + id="int-to-str", + ), + ], +) +def test_dec_hook(test_type, test_value, expected_value, expected_type): + returned = dec_hook(type=test_type, obj=test_value) + returned_type = type(returned) + + assert returned == expected_value + assert returned_type == expected_type + + +@pytest.mark.parametrize( + "test_value,expected_value", + [ + pytest.param( + 1, + "1", + id="int-to-str", + ), + ], +) +def test_enc_hook(test_value, expected_value): + returned = enc_hook(obj=test_value) + + assert returned == expected_value diff --git a/tests/core/test_io.py b/tests/core/test_io.py index a48a785df..56ac80dd2 100644 --- a/tests/core/test_io.py +++ b/tests/core/test_io.py @@ -12,6 +12,7 @@ import pytest from singer_sdk._singerlib import RecordMessage +from singer_sdk._singerlib.encoding._msgspec import MsgSpecReader, MsgSpecWriter from singer_sdk._singerlib.exceptions import InvalidInputLine from singer_sdk.io_base import SingerReader, SingerWriter @@ -104,6 +105,7 @@ def test_write_message(): def bench_record(): return { "stream": "users", + "type": "RECORD", "record": { "Id": 1, "created_at": "2021-01-01T00:08:00-07:00", @@ -131,7 +133,7 @@ def test_bench_format_message(benchmark, bench_record_message): """Run benchmark for Sink._validator method validate.""" number_of_runs = 1000 - writer = SingerWriter() + writer = MsgSpecWriter() def run_format_message(): for record in itertools.repeat(bench_record_message, number_of_runs): @@ -144,6 +146,22 @@ def test_bench_deserialize_json(benchmark, bench_encoded_record): """Run benchmark for Sink._validator method validate.""" number_of_runs = 1000 + class DummyReader(MsgSpecReader): + def _process_activate_version_message(self, message_dict: dict) -> None: + pass + + def _process_batch_message(self, message_dict: dict) -> None: + pass + + def _process_record_message(self, message_dict: dict) -> None: + pass + + def _process_schema_message(self, message_dict: dict) -> None: + pass + + def _process_state_message(self, message_dict: dict) -> None: + pass + reader = DummyReader() def run_deserialize_json():