From 59472ca28899d0d8263fffd1a93599cf1f67ecb5 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Wed, 31 Jan 2024 18:45:34 -0500 Subject: [PATCH] fix: fix connect_setattr on dataclass field signals (#258) * fix: fix connect_setattr on dataclass field signals * add comment * fix lint --- src/psygnal/_group_descriptor.py | 26 ++++++++++++++++++++++---- src/psygnal/_signal.py | 3 ++- tests/test_group_descriptor.py | 26 ++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 5 deletions(-) diff --git a/src/psygnal/_group_descriptor.py b/src/psygnal/_group_descriptor.py index b2181250..7aac8a0d 100644 --- a/src/psygnal/_group_descriptor.py +++ b/src/psygnal/_group_descriptor.py @@ -20,12 +20,13 @@ from ._dataclass_utils import iter_fields from ._group import SignalGroup -from ._signal import Signal +from ._signal import Signal, SignalInstance if TYPE_CHECKING: + from _weakref import ref as ref from typing_extensions import Literal - from ._signal import SignalInstance + from psygnal._weak_callback import RefErrorChoice, WeakCallback __all__ = ["is_evented", "get_evented_namespace", "SignalGroupDescriptor"] @@ -125,6 +126,18 @@ def _pick_equality_operator(type_: type | None) -> EqOperator: return operator.eq +class _DataclassFieldSignalInstance(SignalInstance): + def connect_setattr( + self, + obj: ref | object, + attr: str, + maxargs: int | None = 1, + *, + on_ref_error: RefErrorChoice = "warn", + ) -> WeakCallback[None]: + return super().connect_setattr(obj, attr, maxargs, on_ref_error=on_ref_error) + + @lru_cache(maxsize=None) def _build_dataclass_signal_group( cls: type, equality_operators: Iterable[tuple[str, EqOperator]] | None = None @@ -133,6 +146,7 @@ def _build_dataclass_signal_group( _equality_operators = dict(equality_operators) if equality_operators else {} signals = {} eq_map = _get_eq_operator_map(cls) + # create a Signal for each field in the dataclass for name, type_ in iter_fields(cls): if name in _equality_operators: if not callable(_equality_operators[name]): # pragma: no cover @@ -140,7 +154,10 @@ def _build_dataclass_signal_group( eq_map[name] = _equality_operators[name] else: eq_map[name] = _pick_equality_operator(type_) - signals[name] = Signal(object if type_ is None else type_) + field_type = object if type_ is None else type_ + signals[name] = sig = Signal(field_type) + # patch in our custom SignalInstance class with maxargs=1 on connect_setattr + sig._signal_instance_class = _DataclassFieldSignalInstance return type(f"{cls.__name__}SignalGroup", (SignalGroup,), signals) @@ -380,7 +397,8 @@ def _do_patch_setattr(self, owner: type) -> None: try: # assign a new __setattr__ method to the class owner.__setattr__ = evented_setattr( # type: ignore - self._name, owner.__setattr__ # type: ignore + cast(str, self._name), + owner.__setattr__, # type: ignore ) except Exception as e: # pragma: no cover # not sure what might cause this ... but it will have consequences diff --git a/src/psygnal/_signal.py b/src/psygnal/_signal.py index 3c74763a..489fafde 100644 --- a/src/psygnal/_signal.py +++ b/src/psygnal/_signal.py @@ -116,6 +116,7 @@ def __init__( self.description = description self._check_nargs_on_connect = check_nargs_on_connect self._check_types_on_connect = check_types_on_connect + self._signal_instance_class: type[SignalInstance] = SignalInstance if types and isinstance(types[0], Signature): self._signature = types[0] @@ -171,7 +172,7 @@ class Emitter: if instance is None: return self name = cast("str", self._name) - signal_instance = SignalInstance( + signal_instance = self._signal_instance_class( self.signature, instance=instance, name=name, diff --git a/tests/test_group_descriptor.py b/tests/test_group_descriptor.py index b1d83205..808b86a0 100644 --- a/tests/test_group_descriptor.py +++ b/tests/test_group_descriptor.py @@ -186,3 +186,29 @@ def b(self, value: int) -> None: b_mock.assert_not_called() # getter shouldn't have been called assert foo.b == 1 b_mock.assert_called_once_with(1) # getter should have been called only once + + +def test_evented_field_connect_setattr() -> None: + """Test that using connect_setattr""" + + @dataclass + class Foo: + a: int + events: ClassVar = SignalGroupDescriptor() + + class Bar: + x = 1 + y = 1 + + foo = Foo(a=1) + bar = Bar() + + foo.events.a.connect_setattr(bar, "x") + foo.events.a.connect_setattr(bar, "y", maxargs=None) + foo.events.a.emit(2, 1) + + assert bar.x == 2 # this is likely the desired outcome + # this is a bit of a gotcha, but it's the expected behavior + # when using connect_setattr with maxargs=None + # remove this test if/when we change maxargs to default to 1 on SignalInstance + assert bar.y == (2, 1) # type: ignore