diff --git a/misc/codegen/lib/schemadefs.py b/misc/codegen/lib/schemadefs.py index 199043de7e6f..997b85b4ca6a 100644 --- a/misc/codegen/lib/schemadefs.py +++ b/misc/codegen/lib/schemadefs.py @@ -1,7 +1,7 @@ from typing import ( Callable as _Callable, Dict as _Dict, - List as _List, + Iterable as _Iterable, ClassVar as _ClassVar, ) from misc.codegen.lib import schema as _schema @@ -279,7 +279,7 @@ def __or__(self, other: _schema.PropertyModifier): drop = object() -def annotate(annotated_cls: type, add_bases: _List[type] | None = None, replace_bases: _Dict[type, type] | None = None) -> _Callable[[type], _PropertyAnnotation]: +def annotate(annotated_cls: type, add_bases: _Iterable[type] | None = None, replace_bases: _Dict[type, type] | None = None) -> _Callable[[type], _PropertyAnnotation]: """ Add or modify schema annotations after a class has been defined previously. @@ -297,7 +297,7 @@ def decorator(cls: type) -> _PropertyAnnotation: if replace_bases: annotated_cls.__bases__ = tuple(replace_bases.get(b, b) for b in annotated_cls.__bases__) if add_bases: - annotated_cls.__bases__ = tuple(annotated_cls.__bases__) + tuple(add_bases) + annotated_cls.__bases__ += tuple(add_bases) for a in dir(cls): if a.startswith(_schema.inheritable_pragma_prefix): setattr(annotated_cls, a, getattr(cls, a)) diff --git a/misc/codegen/test/test_schemaloader.py b/misc/codegen/test/test_schemaloader.py index 0c9128c72727..6c6fccfb3eac 100644 --- a/misc/codegen/test/test_schemaloader.py +++ b/misc/codegen/test/test_schemaloader.py @@ -914,6 +914,36 @@ class _: } +def test_annotate_add_bases(): + @load + class data: + class Root: + pass + + class A(Root): + pass + + class B(Root): + pass + + class C(Root): + pass + + class Derived(A): + pass + + @defs.annotate(Derived, add_bases=(B, C)) + class _: + pass + assert data.classes == { + "Root": schema.Class("Root", derived={"A", "B", "C"}), + "A": schema.Class("A", bases=["Root"], derived={"Derived"}), + "B": schema.Class("B", bases=["Root"], derived={"Derived"}), + "C": schema.Class("C", bases=["Root"], derived={"Derived"}), + "Derived": schema.Class("Derived", bases=["A", "B", "C"]), + } + + def test_annotate_drop_field(): @load class data: diff --git a/rust/schema/annotations.py b/rust/schema/annotations.py index fe220233b571..f5e91c9928f8 100644 --- a/rust/schema/annotations.py +++ b/rust/schema/annotations.py @@ -1741,13 +1741,6 @@ class _: ``` """ -class Callable(AstNode): - """ - A callable. Either a `Function` or a `ClosureExpr`. - """ - param_list: optional["ParamList"] | child - attrs: list["Attr"] | child - @annotate(Function, add_bases=[Callable]) class _: param_list: drop diff --git a/rust/schema/prelude.py b/rust/schema/prelude.py index de905eb5b346..4f001ed2b5b8 100644 --- a/rust/schema/prelude.py +++ b/rust/schema/prelude.py @@ -63,3 +63,11 @@ class Unimplemented(Unextracted): The base class for unimplemented nodes. This is used to mark nodes that are not yet extracted. """ pass + + +class Callable(AstNode): + """ + A callable. Either a `Function` or a `ClosureExpr`. + """ + param_list: optional["ParamList"] | child + attrs: list["Attr"] | child