Skip to content

Commit

Permalink
fix: add and fix copy operators (#255)
Browse files Browse the repository at this point in the history
* fix: add and fix copy operators

* test: add tests

* fix: fix type annotation

* fix: fix import

* fix: add import of annotations from __future__
  • Loading branch information
Czaki authored Jan 31, 2024
1 parent b0d2057 commit cfd3e5b
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 28 deletions.
26 changes: 15 additions & 11 deletions src/psygnal/containers/_evented_dict.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
"""Dict that emits events when altered."""
from __future__ import annotations

from typing import (
Dict,
TYPE_CHECKING,
Iterable,
Iterator,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)

if TYPE_CHECKING:
from typing import Self

from psygnal._group import SignalGroup
from psygnal._signal import Signal

Expand All @@ -38,13 +41,13 @@ class TypedMutableMapping(MutableMapping[_K, _V]):

def __init__(
self,
data: Optional[DictArg] = None,
data: DictArg | None = None,
*,
basetype: TypeOrSequenceOfTypes = (),
**kwargs: _V,
):
self._dict: Dict[_K, _V] = {}
self._basetypes: Tuple[Type[_V], ...] = (
self._dict: dict[_K, _V] = {}
self._basetypes: tuple[type[_V], ...] = (
tuple(basetype) if isinstance(basetype, Sequence) else (basetype,)
)
self.update({} if data is None else data, **kwargs)
Expand Down Expand Up @@ -76,19 +79,20 @@ def _type_check(self, value: _V) -> _V:
)
return value

def __newlike__(
self, mapping: MutableMapping[_K, _V]
) -> "TypedMutableMapping[_K, _V]":
def __newlike__(self, mapping: MutableMapping[_K, _V]) -> Self:
new = self.__class__()
# separating this allows subclasses to omit these from their `__init__`
new._basetypes = self._basetypes
new.update(mapping)
return new

def copy(self) -> "TypedMutableMapping[_K, _V]":
def copy(self) -> Self:
"""Return a shallow copy of the dictionary."""
return self.__newlike__(self)

def __copy__(self) -> Self:
return self.copy()


class DictEvents(SignalGroup):
"""Events available on [EventedDict][psygnal.containers.EventedDict].
Expand Down Expand Up @@ -145,7 +149,7 @@ class EventedDict(TypedMutableMapping[_K, _V]):

def __init__(
self,
data: Optional[DictArg] = None,
data: DictArg | None = None,
*,
basetype: TypeOrSequenceOfTypes = (),
**kwargs: _V,
Expand All @@ -172,4 +176,4 @@ def __delitem__(self, key: _K) -> None:
self.events.removed.emit(key, item)

def __repr__(self) -> str:
return f"EventedDict({super().__repr__()})"
return f"{self.__class__.__name__}({super().__repr__()})"
29 changes: 22 additions & 7 deletions src/psygnal/containers/_evented_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,16 @@
"""
from __future__ import annotations # pragma: no cover

from typing import Any, Iterable, MutableSequence, TypeVar, Union, cast, overload
from typing import (
TYPE_CHECKING,
Any,
Iterable,
MutableSequence,
TypeVar,
Union,
cast,
overload,
)

from psygnal._group import EmissionInfo, SignalGroup
from psygnal._signal import Signal, SignalInstance
Expand All @@ -32,6 +41,9 @@
_T = TypeVar("_T")
Index = Union[int, slice]

if TYPE_CHECKING:
from typing import Self


class ListEvents(SignalGroup):
"""Events available on [EventedList][psygnal.containers.EventedList].
Expand Down Expand Up @@ -133,10 +145,10 @@ def __getitem__(self, key: int) -> _T:
...

@overload
def __getitem__(self, key: slice) -> EventedList[_T]:
def __getitem__(self, key: slice) -> Self:
...

def __getitem__(self, key: Index) -> _T | EventedList[_T]:
def __getitem__(self, key: Index) -> _T | Self:
"""Return self[key]."""
result = self._data[key]
return self.__newlike__(result) if isinstance(result, list) else result
Expand Down Expand Up @@ -200,21 +212,24 @@ def _pre_remove(self, index: int) -> None:
if self._child_events:
self._disconnect_child_emitters(self[index])

def __newlike__(self, iterable: Iterable[_T]) -> EventedList[_T]:
def __newlike__(self, iterable: Iterable[_T]) -> Self:
"""Return new instance of same class."""
return self.__class__(iterable)

def copy(self) -> EventedList[_T]:
def copy(self) -> Self:
"""Return a shallow copy of the list."""
return self.__newlike__(self)

def __add__(self, other: Iterable[_T]) -> EventedList[_T]:
def __copy__(self) -> Self:
return self.copy()

def __add__(self, other: Iterable[_T]) -> Self:
"""Add other to self, return new object."""
copy = self.copy()
copy.extend(other)
return copy

def __iadd__(self, other: Iterable[_T]) -> EventedList[_T]:
def __iadd__(self, other: Iterable[_T]) -> Self:
"""Add other to self in place (self += other)."""
self.extend(other)
return self
Expand Down
19 changes: 9 additions & 10 deletions src/psygnal/containers/_evented_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
from psygnal import Signal, SignalGroup

if TYPE_CHECKING:
from typing import Self

from typing_extensions import Final

_T = TypeVar("_T")
_Cls = TypeVar("_Cls", bound="_BaseMutableSet")


class BailType:
Expand Down Expand Up @@ -89,15 +90,13 @@ def _do_discard(self, item: _T) -> None:

# -------- To match set API

def __copy__(self: _Cls) -> _Cls:
inst = self.__class__.__new__(self.__class__)
inst.__dict__.update(self.__dict__)
return inst
def __copy__(self) -> Self:
return self.copy()

def copy(self: _Cls) -> _Cls:
def copy(self) -> Self:
return self.__class__(self)

def difference(self: _Cls, *s: Iterable[_T]) -> _Cls:
def difference(self, *s: Iterable[_T]) -> Self:
"""Return the difference of two or more sets as a new set.
(i.e. all elements that are in this set but not the others.)
Expand All @@ -110,7 +109,7 @@ def difference_update(self, *s: Iterable[_T]) -> None:
for i in chain(*s):
self.discard(i)

def intersection(self: _Cls, *s: Iterable[_T]) -> _Cls:
def intersection(self, *s: Iterable[_T]) -> Self:
"""Return the intersection of two sets as a new set.
(i.e. all elements that are in both sets.)
Expand All @@ -133,7 +132,7 @@ def issuperset(self, __s: Iterable[Any]) -> bool:
"""Report whether this set contains another set."""
return set(self).issuperset(__s)

def symmetric_difference(self: _Cls, __s: Iterable[_T]) -> _Cls:
def symmetric_difference(self, __s: Iterable[_T]) -> Self:
"""Return the symmetric difference of two sets as a new set.
(i.e. all elements that are in exactly one of the sets.)
Expand All @@ -150,7 +149,7 @@ def symmetric_difference_update(self, __s: Iterable[_T]) -> None:
for i in __s:
self.discard(i) if i in self else self.add(i)

def union(self: _Cls, *s: Iterable[_T]) -> _Cls:
def union(self, *s: Iterable[_T]) -> Self:
"""Return the union of sets as a new set.
(i.e. all elements that are in either set.)
Expand Down
10 changes: 10 additions & 0 deletions tests/containers/test_evented_dict.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import copy
from unittest.mock import Mock

import pytest
Expand Down Expand Up @@ -111,3 +112,12 @@ def test_dict_remove_events(test_dict):
test_dict.pop("C")
test_dict.events.removing.emit.assert_called_with("C")
test_dict.events.removed.emit.assert_called_with("C", 3)


def test_copy_no_sync():
d1 = EventedDict({1: 1, 2: 2, 3: 3})
d2 = copy(d1)
d1[4] = 4
d1[3] = 4
assert len(d2) == 3
assert d2[3] == 3
8 changes: 8 additions & 0 deletions tests/containers/test_evented_list.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from copy import copy
from typing import List, cast
from unittest.mock import Mock, call

Expand Down Expand Up @@ -369,3 +370,10 @@ def __init__(self):
# attribute on signal instances.
assert e_obj.events.test2.instance.instance == e_obj
mock.assert_has_calls(expected)


def test_copy_no_sync():
l1 = EventedList([1, 2, 3])
l2 = copy(l1)
l1.append(4)
assert len(l2) == 3
8 changes: 8 additions & 0 deletions tests/containers/test_evented_set.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import copy
from unittest.mock import Mock, call

import pytest
Expand Down Expand Up @@ -133,3 +134,10 @@ def test_repr(test_set):
assert repr(test_set) == "EventedOrderedSet((0, 1, 2, 3, 4))"
else:
assert repr(test_set) == "EventedSet({0, 1, 2, 3, 4})"


def test_copy_no_sync():
s1 = EventedSet([1, 2, 3])
s2 = copy(s1)
s1.add(4)
assert len(s2) == 3

0 comments on commit cfd3e5b

Please sign in to comment.