Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add simple validator metadata for use in annotated #316

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/psygnal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"SignalGroupDescriptor",
"SignalInstance",
"throttled",
"Validator",
]


Expand All @@ -51,7 +52,12 @@
from ._evented_decorator import evented
from ._exceptions import EmitLoopError
from ._group import EmissionInfo, SignalGroup
from ._group_descriptor import SignalGroupDescriptor, get_evented_namespace, is_evented
from ._group_descriptor import (
SignalGroupDescriptor,
Validator,
get_evented_namespace,
is_evented,
)
from ._queue import emit_queued
from ._signal import Signal, SignalInstance, _compiled
from ._throttler import debounced, throttled
Expand Down
1 change: 1 addition & 0 deletions src/psygnal/_evented_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def _decorate(cls: T) -> T:
# as a decorator, this will have already been called
descriptor.__set_name__(cls, events_namespace)
setattr(cls, events_namespace, descriptor)
descriptor._do_patch_setattr(cls)
return cls

return _decorate(cls) if cls is not None else _decorate
98 changes: 97 additions & 1 deletion src/psygnal/_group_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,29 @@

import contextlib
import copy
import inspect
import operator
import sys
import warnings
import weakref
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
ClassVar,
ForwardRef,
Iterable,
Literal,
Mapping,
Optional,
Sequence,
Type,
TypeVar,
cast,
get_args,
get_origin,
overload,
)

Expand Down Expand Up @@ -325,7 +332,10 @@ def __exit__(self, *args: Any) -> None:

@overload
def evented_setattr(
signal_group_name: str, super_setattr: SetAttr, with_aliases: bool = ...
signal_group_name: str,
super_setattr: SetAttr,
with_aliases: bool = ...,
validators: Mapping[str, Sequence[Validator]] | None = ...,
) -> SetAttr: ...


Expand All @@ -334,13 +344,15 @@ def evented_setattr(
signal_group_name: str,
super_setattr: Literal[None] | None = ...,
with_aliases: bool = ...,
validators: Mapping[str, Sequence[Validator]] | None = ...,
) -> Callable[[SetAttr], SetAttr]: ...


def evented_setattr(
signal_group_name: str,
super_setattr: SetAttr | None = None,
with_aliases: bool = True,
validators: Mapping[str, Sequence[Validator]] | None = None,
) -> SetAttr | Callable[[SetAttr], SetAttr]:
"""Create a new __setattr__ method that emits events when fields change.

Expand Down Expand Up @@ -374,7 +386,11 @@ def __getattr__(self, name: str) -> SignalInstanceProtocol: ...
Whether to lookup the signal name in the signal aliases mapping,
by default True. This is slightly slower, and so can be set to False if you
know you don't have any signal aliases.
validators: Mapping[str, Sequence[Validator]] | None
A mapping of field name to a sequence of validators to run on the value before
setting it. If None, no validators are run. Default to None
"""
validators = validators or {}

def _inner(super_setattr: SetAttr) -> SetAttr:
# don't patch twice
Expand All @@ -391,6 +407,9 @@ def _setattr_and_emit_(self: object, name: str, value: Any) -> None:
if name == signal_group_name:
return super_setattr(self, name, value)

for validator in validators.get(name, ()):
value = validator(value, name=name, owner=self)

group = cast(SignalGroup, getattr(self, signal_group_name))
if not with_aliases and name not in group:
return super_setattr(self, name, value)
Expand Down Expand Up @@ -488,6 +507,12 @@ def __setattr__(self, name: str, value: Any) -> None:
field name. If the output is None, no signal is created for this field.
If None, defaults to an empty dict, no aliases.
Default to None
eager: bool | None, optional
If True, the SignalGroup will be created when the descriptor is set on the
class. If False, the SignalGroup will not be created until the first access of
the descriptor on an instance. If None, the SignalGroup will be created when
the descriptor is set on the class only if validators are found in the class
annotations. By default None

Examples
--------
Expand Down Expand Up @@ -524,6 +549,7 @@ def __init__(
signal_group_class: type[SignalGroup] | None = None,
collect_fields: bool = True,
signal_aliases: Mapping[str, str | None] | FieldAliasFunc | None = None,
eager: bool | None = None,
):
grp_cls = signal_group_class or SignalGroup
if not (isinstance(grp_cls, type) and issubclass(grp_cls, SignalGroup)):
Expand Down Expand Up @@ -552,6 +578,7 @@ def __init__(
self._signal_group_class: type[SignalGroup] = grp_cls
self._collect_fields = collect_fields
self._signal_aliases = signal_aliases
self._eager = eager

self._signal_groups: dict[int, type[SignalGroup]] = {}

Expand All @@ -561,6 +588,27 @@ def __set_name__(self, owner: type, name: str) -> None:
with contextlib.suppress(AttributeError):
# This is the flag that identifies this object as evented
setattr(owner, PSYGNAL_GROUP_NAME, name)
if self._eager is not False:
if self._find_validators(owner):
self._get_signal_group(owner)

def _find_validators(self, owner: type) -> dict[str, list[Validator]]:
validators: dict[str, list[Validator]] = {}
for field, annotation in owner.__annotations__.items():
try:
annotation = _resolve(annotation, owner)
if get_origin(annotation) is Annotated:
for item in get_args(annotation)[1:]:
if isinstance(item, Validator):
validators.setdefault(field, []).append(item)
except Exception:
warnings.warn(
f"Unable to resolve type annotation {annotation}"
"Psygnal Validator will not work",
stacklevel=2,
)

return validators

def _do_patch_setattr(self, owner: type, with_aliases: bool = True) -> None:
"""Patch the owner class's __setattr__ method to emit events."""
Expand All @@ -581,6 +629,7 @@ def _do_patch_setattr(self, owner: type, with_aliases: bool = True) -> None:
name,
owner.__setattr__, # type: ignore
with_aliases=with_aliases,
validators=self._find_validators(owner),
)
except Exception as e: # pragma: no cover
# not sure what might cause this ... but it will have consequences
Expand Down Expand Up @@ -656,3 +705,50 @@ def _create_group(self, owner: type) -> type[SignalGroup]:

self._do_patch_setattr(owner, with_aliases=bool(Group._psygnal_aliases))
return Group


@dataclass
class Validator:
"""Annotated metadata marking that a function validates a value before setting.

Examples
--------
```python
from psygnal import Validator, evented


def is_positive(value: int) -> int:
if not value > 0:
raise ValueError("Value must be positive")
return value


@evented
@dataclass
class Foo:
x: Annotated[int, Validator(is_positive)]
```

"""

func: Callable[[Any], Any]

def __call__(self, value: Any, *, name: str, owner: Any) -> Any:
"""Validate the input."""
try:
return self.func(value)
except Exception as e:
raise ValueError(
f"Error setting value {value!r} for field {name!r} "
f"on type {type(owner)}: {e}"
) from e


def _resolve(annotation: Any, owner: Any) -> Any:
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
if isinstance(annotation, ForwardRef):
guard: frozenset = frozenset()
_globals = inspect.getmodule(owner).__dict__
annotation = annotation._evaluate(_globals, {}, guard)
return annotation
50 changes: 50 additions & 0 deletions tests/test_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from dataclasses import dataclass
from typing import Annotated, Any

import pytest

from psygnal import Validator, evented


def _is_positive(value: Any) -> int:
try:
_value = int(value)
except (ValueError, TypeError):
raise ValueError("Value must be an integer") from None
if not _value > 0:
raise ValueError("Value must be positive")
return _value


def test_validator():
@evented
@dataclass
class Foo:
x: Annotated[int, Validator(_is_positive)]

with pytest.raises(ValueError, match="Value must be positive"):
Foo(x=-1)
foo = Foo(x="1") # type: ignore
assert isinstance(foo.x, int)
with pytest.raises(ValueError):
foo.x = -1


def test_validator_resolution():
@evented
@dataclass
class Bar:
x: "Annotated[int, Validator(_is_positive)]"

with pytest.raises(ValueError, match="Value must be positive"):
Bar(x=-1)

def _local_func(value: Any) -> Any:
return value

with pytest.warns(UserWarning, match="Unable to resolve type"):

@evented
@dataclass
class Baz:
x: "Annotated[int, Validator(_local_func)]"
Loading