diff --git a/src/marshmallow/context.py b/src/marshmallow/context.py index a21a027c3..45c03e49f 100644 --- a/src/marshmallow/context.py +++ b/src/marshmallow/context.py @@ -17,5 +17,7 @@ def __exit__(self, *args, **kwargs): self._current_context.reset(self.token) @classmethod - def get(cls): + def get(cls, default=...): + if default is not ...: + return cls._current_context.get(default) return cls._current_context.get() diff --git a/tests/test_schema.py b/tests/test_schema.py index d97cfa831..d6241c724 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -2167,21 +2167,13 @@ class TestContext: def test_context_load_dump(self): class ContextField(fields.Integer): def _serialize(self, value, attr, obj, **kwargs): - try: - context = Context.get() - except LookupError: - pass - else: + if (context := Context.get(None)) is not None: value *= context.get("factor", 1) return super()._serialize(value, attr, obj, **kwargs) def _deserialize(self, value, attr, data, **kwargs): val = super()._deserialize(value, attr, data, **kwargs) - try: - context = Context.get() - except LookupError: - pass - else: + if (context := Context.get(None)) is not None: val *= context.get("factor", 1) return val