diff --git a/pynamodb_attributes/__init__.py b/pynamodb_attributes/__init__.py index 765f8c7..fc6d21e 100644 --- a/pynamodb_attributes/__init__.py +++ b/pynamodb_attributes/__init__.py @@ -12,6 +12,7 @@ from .unicode_datetime import UnicodeDatetimeAttribute from .unicode_delimited_tuple import UnicodeDelimitedTupleAttribute from .unicode_enum import UnicodeEnumAttribute +from .unicode_protobuf_enum import UnicodeProtobufEnumAttribute from .uuid import UUIDAttribute __all__ = [ @@ -22,6 +23,7 @@ "IntegerEnumAttribute", "UnicodeDelimitedTupleAttribute", "UnicodeEnumAttribute", + "UnicodeProtobufEnumAttribute", "TimedeltaAttribute", "TimedeltaMsAttribute", "TimedeltaUsAttribute", diff --git a/pynamodb_attributes/unicode_protobuf_enum.py b/pynamodb_attributes/unicode_protobuf_enum.py index 66ee104..7d8517a 100644 --- a/pynamodb_attributes/unicode_protobuf_enum.py +++ b/pynamodb_attributes/unicode_protobuf_enum.py @@ -50,7 +50,7 @@ def __init__( enum_type: Type[_TProtobufEnum], *, unknown_value: Optional[_TProtobufEnum] = _fail, - prefix: str, + prefix: str = "", lower: bool = True, **kwargs: Any, ) -> None: diff --git a/setup.cfg b/setup.cfg index 5585f89..a8c5b1a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,7 +4,7 @@ universal = 1 [metadata] license_file = LICENSE name = pynamodb-attributes -version = 0.5.0 +version = 0.5.1 description = Common attributes for PynamoDB long_description = file:README.md long_description_content_type = text/markdown diff --git a/tests/unicode_protobuf_enum_attribute_test.py b/tests/unicode_protobuf_enum_attribute_test.py index 7c2882f..071ad5d 100644 --- a/tests/unicode_protobuf_enum_attribute_test.py +++ b/tests/unicode_protobuf_enum_attribute_test.py @@ -10,7 +10,7 @@ from pynamodb.models import Model from typing_extensions import assert_type -from pynamodb_attributes.unicode_protobuf_enum import UnicodeProtobufEnumAttribute +from pynamodb_attributes import UnicodeProtobufEnumAttribute from tests.connection import _connection from tests.meta import dynamodb_table_meta @@ -85,6 +85,12 @@ class MyModel(Model): prefix="SHAKE_FLAVOR_", null=True, ) + value_with_prefix = UnicodeProtobufEnumAttribute( + diner_pb2.ShakeFlavor, + unknown_value=diner_pb2.SHAKE_FLAVOR_UNKNOWN, + null=True, + lower=False, + ) map_attr = MyMapAttr(null=True) @@ -140,6 +146,7 @@ def test_serialization_unknown_value_success(uuid_key): "value": {"S": "vanilla"}, "value_upper": {"S": "VANILLA"}, "value_with_unknown": {"S": "vanilla"}, + "value_with_prefix": {"S": "SHAKE_FLAVOR_VANILLA"}, }, ), ( @@ -148,6 +155,7 @@ def test_serialization_unknown_value_success(uuid_key): "value": {"S": "chocolate"}, "value_upper": {"S": "CHOCOLATE"}, "value_with_unknown": {"S": "chocolate"}, + "value_with_prefix": {"S": "SHAKE_FLAVOR_CHOCOLATE"}, }, ), ], @@ -162,6 +170,7 @@ def test_serialization( model.value = value model.value_upper = value model.value_with_unknown = value + model.value_with_prefix = value model.save() # verify underlying storage @@ -173,6 +182,7 @@ def test_serialization( assert model.value == value assert model.value_upper == value assert model.value_with_unknown == value + assert model.value_with_prefix == value def test_map_attribute( # exercises the __deepcopy__ method