Skip to content

Commit

Permalink
fix: fix connect_setattr on dataclass field signals (#258)
Browse files Browse the repository at this point in the history
* fix: fix connect_setattr on dataclass field signals

* add comment

* fix lint
  • Loading branch information
tlambert03 authored Jan 31, 2024
1 parent 8fdc1ea commit 59472ca
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 5 deletions.
26 changes: 22 additions & 4 deletions src/psygnal/_group_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@

from ._dataclass_utils import iter_fields
from ._group import SignalGroup
from ._signal import Signal
from ._signal import Signal, SignalInstance

if TYPE_CHECKING:
from _weakref import ref as ref
from typing_extensions import Literal

from ._signal import SignalInstance
from psygnal._weak_callback import RefErrorChoice, WeakCallback


__all__ = ["is_evented", "get_evented_namespace", "SignalGroupDescriptor"]
Expand Down Expand Up @@ -125,6 +126,18 @@ def _pick_equality_operator(type_: type | None) -> EqOperator:
return operator.eq


class _DataclassFieldSignalInstance(SignalInstance):
def connect_setattr(
self,
obj: ref | object,
attr: str,
maxargs: int | None = 1,
*,
on_ref_error: RefErrorChoice = "warn",
) -> WeakCallback[None]:
return super().connect_setattr(obj, attr, maxargs, on_ref_error=on_ref_error)


@lru_cache(maxsize=None)
def _build_dataclass_signal_group(
cls: type, equality_operators: Iterable[tuple[str, EqOperator]] | None = None
Expand All @@ -133,14 +146,18 @@ def _build_dataclass_signal_group(
_equality_operators = dict(equality_operators) if equality_operators else {}
signals = {}
eq_map = _get_eq_operator_map(cls)
# create a Signal for each field in the dataclass
for name, type_ in iter_fields(cls):
if name in _equality_operators:
if not callable(_equality_operators[name]): # pragma: no cover
raise TypeError("EqOperator must be callable")
eq_map[name] = _equality_operators[name]
else:
eq_map[name] = _pick_equality_operator(type_)
signals[name] = Signal(object if type_ is None else type_)
field_type = object if type_ is None else type_
signals[name] = sig = Signal(field_type)
# patch in our custom SignalInstance class with maxargs=1 on connect_setattr
sig._signal_instance_class = _DataclassFieldSignalInstance

return type(f"{cls.__name__}SignalGroup", (SignalGroup,), signals)

Expand Down Expand Up @@ -380,7 +397,8 @@ def _do_patch_setattr(self, owner: type) -> None:
try:
# assign a new __setattr__ method to the class
owner.__setattr__ = evented_setattr( # type: ignore
self._name, owner.__setattr__ # type: ignore
cast(str, self._name),
owner.__setattr__, # type: ignore
)
except Exception as e: # pragma: no cover
# not sure what might cause this ... but it will have consequences
Expand Down
3 changes: 2 additions & 1 deletion src/psygnal/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def __init__(
self.description = description
self._check_nargs_on_connect = check_nargs_on_connect
self._check_types_on_connect = check_types_on_connect
self._signal_instance_class: type[SignalInstance] = SignalInstance

if types and isinstance(types[0], Signature):
self._signature = types[0]
Expand Down Expand Up @@ -171,7 +172,7 @@ class Emitter:
if instance is None:
return self
name = cast("str", self._name)
signal_instance = SignalInstance(
signal_instance = self._signal_instance_class(
self.signature,
instance=instance,
name=name,
Expand Down
26 changes: 26 additions & 0 deletions tests/test_group_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,29 @@ def b(self, value: int) -> None:
b_mock.assert_not_called() # getter shouldn't have been called
assert foo.b == 1
b_mock.assert_called_once_with(1) # getter should have been called only once


def test_evented_field_connect_setattr() -> None:
"""Test that using connect_setattr"""

@dataclass
class Foo:
a: int
events: ClassVar = SignalGroupDescriptor()

class Bar:
x = 1
y = 1

foo = Foo(a=1)
bar = Bar()

foo.events.a.connect_setattr(bar, "x")
foo.events.a.connect_setattr(bar, "y", maxargs=None)
foo.events.a.emit(2, 1)

assert bar.x == 2 # this is likely the desired outcome
# this is a bit of a gotcha, but it's the expected behavior
# when using connect_setattr with maxargs=None
# remove this test if/when we change maxargs to default to 1 on SignalInstance
assert bar.y == (2, 1) # type: ignore

0 comments on commit 59472ca

Please sign in to comment.