Skip to content

Commit

Permalink
Rust: address review
Browse files Browse the repository at this point in the history
  • Loading branch information
Paolo Tranquilli committed Oct 15, 2024
1 parent 248eb7f commit bd08bc7
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 10 deletions.
6 changes: 3 additions & 3 deletions misc/codegen/lib/schemadefs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import (
Callable as _Callable,
Dict as _Dict,
List as _List,
Iterable as _Iterable,
ClassVar as _ClassVar,
)
from misc.codegen.lib import schema as _schema
Expand Down Expand Up @@ -279,7 +279,7 @@ def __or__(self, other: _schema.PropertyModifier):
drop = object()


def annotate(annotated_cls: type, add_bases: _List[type] | None = None, replace_bases: _Dict[type, type] | None = None) -> _Callable[[type], _PropertyAnnotation]:
def annotate(annotated_cls: type, add_bases: _Iterable[type] | None = None, replace_bases: _Dict[type, type] | None = None) -> _Callable[[type], _PropertyAnnotation]:
"""
Add or modify schema annotations after a class has been defined previously.
Expand All @@ -297,7 +297,7 @@ def decorator(cls: type) -> _PropertyAnnotation:
if replace_bases:
annotated_cls.__bases__ = tuple(replace_bases.get(b, b) for b in annotated_cls.__bases__)
if add_bases:
annotated_cls.__bases__ = tuple(annotated_cls.__bases__) + tuple(add_bases)
annotated_cls.__bases__ += tuple(add_bases)
for a in dir(cls):
if a.startswith(_schema.inheritable_pragma_prefix):
setattr(annotated_cls, a, getattr(cls, a))
Expand Down
30 changes: 30 additions & 0 deletions misc/codegen/test/test_schemaloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,36 @@ class _:
}


def test_annotate_add_bases():
@load
class data:
class Root:
pass

class A(Root):
pass

class B(Root):
pass

class C(Root):
pass

class Derived(A):
pass

@defs.annotate(Derived, add_bases=(B, C))
class _:
pass
assert data.classes == {
"Root": schema.Class("Root", derived={"A", "B", "C"}),
"A": schema.Class("A", bases=["Root"], derived={"Derived"}),
"B": schema.Class("B", bases=["Root"], derived={"Derived"}),
"C": schema.Class("C", bases=["Root"], derived={"Derived"}),
"Derived": schema.Class("Derived", bases=["A", "B", "C"]),
}


def test_annotate_drop_field():
@load
class data:
Expand Down
7 changes: 0 additions & 7 deletions rust/schema/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1741,13 +1741,6 @@ class _:
```
"""

class Callable(AstNode):
"""
A callable. Either a `Function` or a `ClosureExpr`.
"""
param_list: optional["ParamList"] | child
attrs: list["Attr"] | child

@annotate(Function, add_bases=[Callable])
class _:
param_list: drop
Expand Down
8 changes: 8 additions & 0 deletions rust/schema/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,11 @@ class Unimplemented(Unextracted):
The base class for unimplemented nodes. This is used to mark nodes that are not yet extracted.
"""
pass


class Callable(AstNode):
"""
A callable. Either a `Function` or a `ClosureExpr`.
"""
param_list: optional["ParamList"] | child
attrs: list["Attr"] | child

0 comments on commit bd08bc7

Please sign in to comment.