Skip to content

Commit

Permalink
Allow passing a default to Context.get
Browse files Browse the repository at this point in the history
  • Loading branch information
lafrech committed Jan 1, 2025
1 parent feffc2e commit 6fac57d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 11 deletions.
4 changes: 3 additions & 1 deletion src/marshmallow/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,7 @@ def __exit__(self, *args, **kwargs):
self._current_context.reset(self.token)

@classmethod
def get(cls):
def get(cls, default=...):
if default is not ...:
return cls._current_context.get(default)
return cls._current_context.get()
12 changes: 2 additions & 10 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2167,21 +2167,13 @@ class TestContext:
def test_context_load_dump(self):
class ContextField(fields.Integer):
def _serialize(self, value, attr, obj, **kwargs):
try:
context = Context.get()
except LookupError:
pass
else:
if (context := Context.get(None)) is not None:
value *= context.get("factor", 1)
return super()._serialize(value, attr, obj, **kwargs)

def _deserialize(self, value, attr, data, **kwargs):
val = super()._deserialize(value, attr, data, **kwargs)
try:
context = Context.get()
except LookupError:
pass
else:
if (context := Context.get(None)) is not None:
val *= context.get("factor", 1)
return val

Expand Down

0 comments on commit 6fac57d

Please sign in to comment.