diff --git a/pyproject.toml b/pyproject.toml
index 7082f2e4..059f76b1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,7 +51,7 @@ namespaces = false
 
 [tool.ruff.lint]
 # Enable the isort rules.
-extend-select = ["I"]
+extend-select = ["I", "UP"]
 
 [tool.ruff.lint.isort]
 known-first-party = ["spox"]
diff --git a/src/spox/_adapt.py b/src/spox/_adapt.py
index 5ac9db00..32017dab 100644
--- a/src/spox/_adapt.py
+++ b/src/spox/_adapt.py
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import warnings
-from typing import Dict, List, Optional
+from typing import Optional
 
 import numpy as np
 import onnx
@@ -23,8 +23,8 @@ def adapt_node(
     proto: onnx.NodeProto,
     source_version: int,
     target_version: int,
-    var_names: Dict[Var, str],
-) -> Optional[List[onnx.NodeProto]]:
+    var_names: dict[Var, str],
+) -> Optional[list[onnx.NodeProto]]:
     if source_version == target_version:
         return None
 
@@ -69,11 +69,11 @@ def adapt_node(
 
 def adapt_inline(
     node: _Inline,
-    protos: List[onnx.NodeProto],
-    target_opsets: Dict[str, int],
-    var_names: Dict[Var, str],
+    protos: list[onnx.NodeProto],
+    target_opsets: dict[str, int],
+    var_names: dict[Var, str],
     node_name: str,
-) -> List[onnx.NodeProto]:
+) -> list[onnx.NodeProto]:
     source_version = max({v for d, v in node.opset_req if d in ("", "ai.onnx")})
     target_version = target_opsets[""]
 
@@ -97,11 +97,11 @@ def adapt_inline(
 
 def adapt_best_effort(
     node: Node,
-    protos: List[onnx.NodeProto],
-    opsets: Dict[str, int],
-    var_names: Dict[Var, str],
-    node_names: Dict[Node, str],
-) -> Optional[List[onnx.NodeProto]]:
+    protos: list[onnx.NodeProto],
+    opsets: dict[str, int],
+    var_names: dict[Var, str],
+    node_names: dict[Node, str],
+) -> Optional[list[onnx.NodeProto]]:
     if isinstance(node, _Inline):
         return adapt_inline(
             node,
diff --git a/src/spox/_attributes.py b/src/spox/_attributes.py
index bc38bbb0..7e97ff1c 100644
--- a/src/spox/_attributes.py
+++ b/src/spox/_attributes.py
@@ -3,7 +3,8 @@
 
 import abc
 from abc import ABC
-from typing import Any, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union
+from collections.abc import Iterable
+from typing import Any, Generic, Optional, TypeVar, Union
 
 import numpy as np
 import numpy.typing as npt
@@ -45,7 +46,7 @@ def deref(self) -> "Attr":
         return self
 
     @classmethod
-    def maybe(cls: Type[AttrT], value: Optional[T], name: str) -> Optional[AttrT]:
+    def maybe(cls: type[AttrT], value: Optional[T], name: str) -> Optional[AttrT]:
         return cls(value, name) if value is not None else None
 
     @property
@@ -200,15 +201,15 @@ def _to_onnx_deref(self) -> AttributeProto:
         )
 
 
-class _AttrIterable(Attr[Tuple[S, ...]], ABC):
-    def __init__(self, value: Union[Iterable[S], _Ref[Tuple[S, ...]]], name: str):
+class _AttrIterable(Attr[tuple[S, ...]], ABC):
+    def __init__(self, value: Union[Iterable[S], _Ref[tuple[S, ...]]], name: str):
         super().__init__(
             value=value if isinstance(value, _Ref) else tuple(value), name=name
         )
 
     @classmethod
     def maybe(
-        cls: Type[AttrIterableT],
+        cls: type[AttrIterableT],
         value: Optional[Iterable[S]],
         name: str,
     ) -> Optional[AttrIterableT]:
diff --git a/src/spox/_build.py b/src/spox/_build.py
index 5e579360..60c544a0 100644
--- a/src/spox/_build.py
+++ b/src/spox/_build.py
@@ -7,12 +7,8 @@
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
-    Dict,
     Generic,
-    List,
     Optional,
-    Set,
-    Tuple,
     TypeVar,
 )
 
@@ -61,12 +57,12 @@ class BuildResult:
     """
 
     scope: Scope
-    nodes: Dict[Node, Tuple[onnx.NodeProto, ...]]
-    arguments: Tuple[Var, ...]
-    results: Tuple[Var, ...]
-    opset_req: Set[Tuple[str, int]]
-    functions: Tuple["_function.Function", ...]
-    initializers: Dict[Var, np.ndarray]
+    nodes: dict[Node, tuple[onnx.NodeProto, ...]]
+    arguments: tuple[Var, ...]
+    results: tuple[Var, ...]
+    opset_req: set[tuple[str, int]]
+    functions: tuple["_function.Function", ...]
+    initializers: dict[Var, np.ndarray]
 
 
 class Builder:
@@ -125,8 +121,8 @@ class ScopeTree:
     (lowest common ancestor), which is a common operation on trees.
     """
 
-    subgraph_owner: Dict["Graph", Node]
-    scope_of: Dict[Node, "Graph"]
+    subgraph_owner: dict["Graph", Node]
+    scope_of: dict[Node, "Graph"]
 
     def __init__(self):
         self.subgraph_owner = {}
         self.scope_of = {}
@@ -165,18 +161,18 @@ def lca(self, a: "Graph", b: "Graph") -> "Graph":
 
     # Graphs needed in the build
     main: "Graph"
-    graphs: Set["Graph"]
-    graph_topo: List["Graph"]
+    graphs: set["Graph"]
+    graph_topo: list["Graph"]
     # Arguments, results
-    arguments_of: Dict["Graph", List[Var]]
-    results_of: Dict["Graph", List[Var]]
-    source_of: Dict["Graph", Node]
+    arguments_of: dict["Graph", list[Var]]
+    results_of: dict["Graph", list[Var]]
+    source_of: dict["Graph", Node]
     # Arguments found by traversal
-    all_arguments_in: Dict["Graph", Set[Var]]
-    claimed_arguments_in: Dict["Graph", Set[Var]]
+    all_arguments_in: dict["Graph", set[Var]]
+    claimed_arguments_in: dict["Graph", set[Var]]
     # Scopes
     scope_tree: ScopeTree
-    scope_own: Dict["Graph", List[Node]]
+    scope_own: dict["Graph", list[Node]]
 
     def __init__(self, main: "Graph"):
         self.main = main
@@ -207,8 +203,8 @@ def build_main(self) -> BuildResult:
 
     @staticmethod
     def get_intro_results(
-        request_results: Dict[str, Var], set_names: bool
-    ) -> List[Var]:
+        request_results: dict[str, Var], set_names: bool
+    ) -> list[Var]:
         """
         Helper method for wrapping all requested results into a single Introduce and possibly naming them.
@@ -222,7 +218,7 @@ def get_intro_results(
                 var._rename(key)
         return vars
 
-    def discover(self, graph: "Graph") -> Tuple[Set[Var], Set[Var]]:
+    def discover(self, graph: "Graph") -> tuple[set[Var], set[Var]]:
         """
         Run the discovery step of the build process. Resolves arguments and results for the involved graphs.
         Finds the topological ordering between (sub)graphs and sets their owners (nodes of which they are attributes).
@@ -359,7 +355,7 @@ def resolve_scopes(self) -> None:
        - this is slightly higher quality than a normal topological sorting which attempts to be "parallel",
        while a DFS' postorder is more "localised".
        """
-        graph_scope_set: Dict[Any, Set[Node]] = {ctx: set() for ctx in self.graphs}
+        graph_scope_set: dict[Any, set[Node]] = {ctx: set() for ctx in self.graphs}
         for node, owner in self.scope_tree.scope_of.items():
             graph_scope_set[owner].add(node)
 
@@ -392,7 +388,7 @@ def resolve_scopes(self) -> None:
 
     def get_build_subgraph_callback(
         self, scope: Scope
-    ) -> Tuple[Callable, Set[Tuple[str, int]]]:
+    ) -> tuple[Callable, set[tuple[str, int]]]:
         """Create a callback for building subgraphs for ``Node.to_onnx``."""
         subgraph_opset_req = set()  # Keeps track of all opset imports in subgraphs
@@ -432,11 +428,11 @@ def compile_graph(
         See the definition for the exact contents of the BuildResult dataclass.
         Used to build GraphProto/ModelProto from a Spox Graph.
""" - nodes: Dict[Node, Tuple[onnx.NodeProto, ...]] = {} + nodes: dict[Node, tuple[onnx.NodeProto, ...]] = {} # A bunch of model metadata we're collecting - opset_req: Set[Tuple[str, int]] = set() - functions: List[_function.Function] = [] - initializers: Dict[Var, np.ndarray] = {} + opset_req: set[tuple[str, int]] = set() + functions: list[_function.Function] = [] + initializers: dict[Var, np.ndarray] = {} # Add arguments to our scope for arg in self.arguments_of[graph]: diff --git a/src/spox/_fields.py b/src/spox/_fields.py index ecb28815..d02ca742 100644 --- a/src/spox/_fields.py +++ b/src/spox/_fields.py @@ -3,8 +3,9 @@ import dataclasses import enum +from collections.abc import Iterable, Iterator, Sequence from dataclasses import dataclass -from typing import Any, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Union from ._attributes import Attr from ._var import Var @@ -17,7 +18,7 @@ class BaseFields: @dataclass class BaseAttributes(BaseFields): - def get_fields(self) -> Dict[str, Union[None, Attr]]: + def get_fields(self) -> dict[str, Union[None, Attr]]: """Return a mapping of all fields stored in this object by name.""" return self.__dict__.copy() @@ -68,7 +69,7 @@ def _get_field_type(cls, field) -> VarFieldKind: return VarFieldKind.VARIADIC raise ValueError(f"Bad field type: '{field.type}'.") - def _flatten(self) -> Iterable[Tuple[str, Optional[Var]]]: + def _flatten(self) -> Iterable[tuple[str, Optional[Var]]]: """Iterate over the pairs of names and values of fields in this object.""" for key, value in self.__dict__.items(): if value is None or isinstance(value, Var): @@ -84,11 +85,11 @@ def __len__(self) -> int: """Count the number of fields in this object (should be same as declared in the class).""" return sum(1 for _ in self) - def get_vars(self) -> Dict[str, Var]: + def get_vars(self) -> dict[str, Var]: """Return a flat mapping by name of all the Vars in this object.""" return {key: var for key, var in self._flatten() if var is not None} - def get_fields(self) -> Dict[str, Union[None, Var, Sequence[Var]]]: + def get_fields(self) -> dict[str, Union[None, Var, Sequence[Var]]]: """Return a mapping of all fields stored in this object by name.""" return self.__dict__.copy() diff --git a/src/spox/_function.py b/src/spox/_function.py index ed31f1f0..79db4b9b 100644 --- a/src/spox/_function.py +++ b/src/spox/_function.py @@ -3,8 +3,9 @@ import inspect import itertools +from collections.abc import Iterable from dataclasses import dataclass, make_dataclass -from typing import TYPE_CHECKING, Callable, Dict, Iterable, Tuple, TypeVar +from typing import TYPE_CHECKING, Callable, TypeVar import onnx @@ -41,8 +42,8 @@ class Function(_InternalNode): via the ``to_onnx_function`` method. """ - func_args: Dict[str, Var] - func_attrs: Dict[str, _attributes.Attr] + func_args: dict[str, Var] + func_attrs: dict[str, _attributes.Attr] func_inputs: BaseInputs func_outputs: BaseOutputs func_graph: "_graph.Graph" @@ -60,7 +61,7 @@ def constructor(self, attrs, inputs): f"Function {type(self).__name__} does not implement a constructor." ) - def infer_output_types(self) -> Dict[str, Type]: + def infer_output_types(self) -> dict[str, Type]: from . 
 
         self.func_args = _graph.arguments_dict(
@@ -98,7 +99,7 @@ def update_metadata(self, opset_req, initializers, functions):
         functions.extend(self.func_graph._get_build_result().functions)
 
     def to_onnx_function(
-        self, *, extra_opset_req: Iterable[Tuple[str, int]] = ()
+        self, *, extra_opset_req: Iterable[tuple[str, int]] = ()
     ) -> onnx.FunctionProto:
         """
         Translate self into an ONNX FunctionProto, based on the ``func_*`` attributes set when this operator
diff --git a/src/spox/_future.py b/src/spox/_future.py
index 3075b50b..ecaa5c7b 100644
--- a/src/spox/_future.py
+++ b/src/spox/_future.py
@@ -3,8 +3,9 @@
 
 """Module containing experimental Spox features that may be standard in the future."""
 
+from collections.abc import Iterable
 from contextlib import contextmanager
-from typing import Iterable, List, Optional, Union
+from typing import Optional, Union
 
 import numpy as np
 import numpy.typing as npt
@@ -84,7 +85,7 @@ def _promote(
         Apply constant promotion and type promotion to given parameters, creating constants and/or casting.
         """
-        targets: List[Union[np.dtype, np.generic, int, float]] = [
+        targets: list[Union[np.dtype, np.generic, int, float]] = [
             x.type.dtype if isinstance(x, Var) and isinstance(x.type, Tensor) else x  # type: ignore
             for x in args
         ]
diff --git a/src/spox/_graph.py b/src/spox/_graph.py
index d487ed19..33369fd7 100644
--- a/src/spox/_graph.py
+++ b/src/spox/_graph.py
@@ -5,8 +5,9 @@
 
 import dataclasses
 import itertools
+from collections.abc import Iterable
 from dataclasses import dataclass, replace
-from typing import Callable, Dict, Iterable, List, Literal, Optional, Set, Tuple, Union
+from typing import Callable, Literal, Optional, Union
 
 import numpy as np
 import onnx
@@ -24,7 +25,7 @@
 from ._var import Var
 
 
-def arguments_dict(**kwargs: Optional[Union[Type, np.ndarray]]) -> Dict[str, Var]:
+def arguments_dict(**kwargs: Optional[Union[Type, np.ndarray]]) -> dict[str, Var]:
     """
     Parameters
     ----------
@@ -65,14 +66,14 @@ def arguments_dict(**kwargs: Optional[Union[Type, np.ndarray]]) -> dict[str, Var
     return result
 
 
-def arguments(**kwargs: Optional[Union[Type, np.ndarray]]) -> Tuple[Var, ...]:
+def arguments(**kwargs: Optional[Union[Type, np.ndarray]]) -> tuple[Var, ...]:
     """This function is a shorthand for a respective call to ``arguments_dict``, unpacking the Vars from the dict."""
     return tuple(arguments_dict(**kwargs).values())
 
 
 def enum_arguments(
     *infos: Union[Type, np.ndarray], prefix: str = "in"
-) -> Tuple[Var, ...]:
+) -> tuple[Var, ...]:
     """
     Convenience function for creating an enumeration of arguments, prefixed with ``prefix``.
     Calls ``arguments`` internally.
@@ -132,11 +133,11 @@ class Graph:
     Note: building a Graph is cached, so changing it in-place without the setters will invalidate the build.
     """
 
-    _results: Dict[str, Var]
+    _results: dict[str, Var]
     _name: Optional[str] = None
     _doc_string: Optional[str] = None
-    _arguments: Optional[Tuple[Var, ...]] = None
-    _extra_opset_req: Optional[Set[Tuple[str, int]]] = None
+    _arguments: Optional[tuple[Var, ...]] = None
+    _extra_opset_req: Optional[set[tuple[str, int]]] = None
     _constructor: Optional[Callable[..., Iterable[Var]]] = None
     _build_result: "_build.Cached[_build.BuildResult]" = dataclasses.field(
         default_factory=_build.Cached
@@ -150,7 +151,7 @@ def __repr__(self):
             else "..."
         )
         res_repr = f"{', '.join(f'{k}: {a}' for k, a in self._results.items())}"
-        comments: List[str] = []
+        comments: list[str] = []
         if self._doc_string is not None:
             comments.append(f'"{self._doc_string[:10]}..."')
         if self._extra_opset_req is not None:
@@ -182,7 +183,7 @@ def with_arguments(self, *args: Var) -> "Graph":
         """
         return replace(self, _arguments=args)
 
-    def with_opset(self, *args: Tuple[str, int]) -> "Graph":
+    def with_opset(self, *args: tuple[str, int]) -> "Graph":
         """
         Add the given minimum opset requirements to the graph.
         Useful when the graph is using legacy nodes, but Spox should attempt to convert them to a required version.
@@ -217,11 +218,11 @@ def requested_arguments(self) -> Optional[Iterable[Var]]:
         return self._arguments
 
     @property
-    def requested_results(self) -> Dict[str, Var]:
+    def requested_results(self) -> dict[str, Var]:
         """Results (named) requested by this Graph (for building)."""
         return self._results
 
-    def get_arguments(self) -> Dict[str, Var]:
+    def get_arguments(self) -> dict[str, Var]:
         """
         Get the effective named arguments (after build) of this Graph.
 
@@ -232,7 +233,7 @@ def get_arguments(self) -> dict[str, Var]:
             for var in self._get_build_result().arguments
         }
 
-    def get_results(self) -> Dict[str, Var]:
+    def get_results(self) -> dict[str, Var]:
         """
         Get the effective named results (after build) of this Graph.
 
@@ -243,7 +244,7 @@ def get_results(self) -> dict[str, Var]:
             for var in self._get_build_result().results
         }
 
-    def get_opsets(self) -> Dict[str, int]:
+    def get_opsets(self) -> dict[str, int]:
         """
         Get the effective opsets used by this Graph. The used policy for mixed versions is maximum-requested.
 
@@ -257,20 +258,20 @@ def _get_build_result(self) -> "_build.BuildResult":
             self._build_result.value = _build.Builder(self).build_main()
         return self._build_result.value
 
-    def _get_opset_req(self) -> Set[Tuple[str, int]]:
+    def _get_opset_req(self) -> set[tuple[str, int]]:
         """Internal function for accessing the opset requirements, including extras requested by the Graph itself."""
         return self._get_build_result().opset_req | (
             self._extra_opset_req if self._extra_opset_req is not None else set()
         )
 
-    def _get_initializers_by_name(self) -> Dict[str, np.ndarray]:
+    def _get_initializers_by_name(self) -> dict[str, np.ndarray]:
         """Internal function for accessing the initializers by name in the build."""
         return {
             self._get_build_result().scope.var[var]: init
             for var, init in self._get_build_result().initializers.items()
         }
 
-    def get_adapted_nodes(self) -> Dict[Node, Tuple[onnx.NodeProto, ...]]:
+    def get_adapted_nodes(self) -> dict[Node, tuple[onnx.NodeProto, ...]]:
         """
         Do a best-effort at generating NodeProtos of consistent versions, matching ``self.opsets``.
         In essence, the policy is to upgrade to the highest used version.
@@ -398,8 +399,8 @@ def to_onnx_model(
                 "Consider adding an Identity operator if you are just copying arguments."
             )
 
-        opset_req: List[tuple[str, int]] = list(opsets.items())  # type: ignore
-        function_protos: Dict[Tuple[str, str], onnx.FunctionProto] = {}
+        opset_req: list[tuple[str, int]] = list(opsets.items())  # type: ignore
+        function_protos: dict[tuple[str, str], onnx.FunctionProto] = {}
         for fun in self._get_build_result().functions:
             proto = fun.to_onnx_function(extra_opset_req=opset_req)
             if proto is None:
diff --git a/src/spox/_inline.py b/src/spox/_inline.py
index d3d7749d..f3077f7d 100644
--- a/src/spox/_inline.py
+++ b/src/spox/_inline.py
@@ -2,8 +2,9 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import itertools
+from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple
+from typing import Callable, Optional
 
 import onnx
 
@@ -23,7 +24,7 @@ def rename_in_graph(
     rename: Callable[[str], str],
     *,
     rename_node: Optional[Callable[[str], str]] = None,
-    rename_op: Optional[Callable[[str, str], Tuple[str, str]]] = None,
+    rename_op: Optional[Callable[[str, str], tuple[str, str]]] = None,
 ) -> onnx.GraphProto:
     def rename_in_subgraph(subgraph):
         return rename_in_graph(
@@ -105,12 +106,12 @@ def graph(self) -> onnx.GraphProto:
         return self.model.graph
 
     @property
-    def opset_req(self) -> Set[Tuple[str, int]]:
+    def opset_req(self) -> set[tuple[str, int]]:
         return {(imp.domain, imp.version) for imp in self.model.opset_import} | {
             ("", INTERNAL_MIN_OPSET)
         }
 
-    def infer_output_types(self) -> Dict[str, Type]:
+    def infer_output_types(self) -> dict[str, Type]:
         # First, type check that we match the ModelProto type requirements
         for i, var in zip(self.graph.input, self.inputs.inputs):
             if var.type is not None and not (
@@ -126,7 +127,7 @@ def infer_output_types(self) -> dict[str, Type]:
             for k, o in enumerate(self.graph.output)
         }
 
-    def propagate_values(self) -> Dict[str, _value_prop.PropValueType]:
+    def propagate_values(self) -> dict[str, _value_prop.PropValueType]:
         if any(
             var.type is None or var._value is None
             for var in self.inputs.get_vars().values()
@@ -146,15 +147,15 @@ def propagate_values(self) -> dict[str, _value_prop.PropValueType]:
 
     def to_onnx(
         self, scope: Scope, doc_string: Optional[str] = None, build_subgraph=None
-    ) -> List[onnx.NodeProto]:
-        input_names: Dict[str, int] = {
+    ) -> list[onnx.NodeProto]:
+        input_names: dict[str, int] = {
             p.name: i for i, p in enumerate(self.graph.input)
         }
-        output_names: Dict[str, int] = {
+        output_names: dict[str, int] = {
             p.name: i for i, p in enumerate(self.graph.output)
         }
-        inner_renames: Dict[str, str] = {}
-        inner_node_renames: Dict[str, str] = {}
+        inner_renames: dict[str, str] = {}
+        inner_node_renames: dict[str, str] = {}
 
         def reserve_prefixed(name: str) -> str:
             if not name:
@@ -183,5 +184,5 @@ def apply_node_rename(name: str) -> str:
                 raise BuildError(
                     "Inlined graph initializers should be handled beforehand and be removed from the graph."
                 )
-        nodes: List[onnx.NodeProto] = list(graph.node)
+        nodes: list[onnx.NodeProto] = list(graph.node)
         return nodes
diff --git a/src/spox/_internal_op.py b/src/spox/_internal_op.py
index cdb54f62..f51f2579 100644
--- a/src/spox/_internal_op.py
+++ b/src/spox/_internal_op.py
@@ -7,8 +7,9 @@
 """
 
 from abc import ABC
+from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Sequence, Set, Tuple
+from typing import Optional
 
 import onnx
 
@@ -49,7 +50,7 @@ class _InternalNode(Node, ABC):
 
     @property
-    def opset_req(self) -> Set[Tuple[str, int]]:
+    def opset_req(self) -> set[tuple[str, int]]:
         return set()
 
 
@@ -87,7 +88,7 @@ def post_init(self, **kwargs):
         if self.attrs.name is not None:
             self.outputs.arg._rename(self.attrs.name.value)
 
-    def infer_output_types(self) -> Dict[str, Type]:
+    def infer_output_types(self) -> dict[str, Type]:
         # Output type is based on the value of the type attribute
         return {"arg": self.attrs.type.value}
 
@@ -99,7 +100,7 @@ def update_metadata(self, opset_req, initializers, functions):
 
     def to_onnx(
         self, scope: "Scope", doc_string: Optional[str] = None, build_subgraph=None
-    ) -> List[onnx.NodeProto]:
+    ) -> list[onnx.NodeProto]:
         return []
 
 
@@ -120,12 +121,12 @@ class Outputs(BaseOutputs):
     inputs: BaseInputs
     outputs: Outputs
 
-    def infer_output_types(self) -> Dict[str, Type]:
+    def infer_output_types(self) -> dict[str, Type]:
         # Output type is based on the value of the type attribute
         arr = self.attrs.value.value
         return {"arg": Tensor(arr.dtype, arr.shape)}
 
-    def propagate_values(self) -> Dict[str, PropValueType]:
+    def propagate_values(self) -> dict[str, PropValueType]:
         return {"arg": self.attrs.value.value}
 
     def update_metadata(self, opset_req, initializers, functions):
@@ -134,7 +135,7 @@ def update_metadata(self, opset_req, initializers, functions):
 
     def to_onnx(
         self, scope: "Scope", doc_string: Optional[str] = None, build_subgraph=None
-    ) -> List[onnx.NodeProto]:
+    ) -> list[onnx.NodeProto]:
         # Initializers are added via update_metadata and don't affect the nodes proto list
         return []
 
@@ -160,7 +161,7 @@ class Outputs(BaseOutputs):
     inputs: Inputs
     outputs: Outputs
 
-    def infer_output_types(self) -> Dict[str, Type]:
+    def infer_output_types(self) -> dict[str, Type]:
         return {
             f"outputs_{i}": arr.type
             for i, arr in enumerate(self.inputs.inputs)
@@ -168,12 +169,12 @@ def infer_output_types(self) -> dict[str, Type]:
         }
 
     @property
-    def opset_req(self) -> Set[Tuple[str, int]]:
+    def opset_req(self) -> set[tuple[str, int]]:
         return {("", INTERNAL_MIN_OPSET)}
 
     def to_onnx(
         self, scope: Scope, doc_string: Optional[str] = None, build_subgraph=None
-    ) -> List[onnx.NodeProto]:
+    ) -> list[onnx.NodeProto]:
         assert len(self.inputs.inputs) == len(self.outputs.outputs)
         # Just create a renaming identity from what we forwarded into our actual output
         protos = []
diff --git a/src/spox/_node.py b/src/spox/_node.py
index 678ba209..322a09a9 100644
--- a/src/spox/_node.py
+++ b/src/spox/_node.py
@@ -8,8 +8,9 @@
 import typing
 import warnings
 from abc import ABC
+from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
-from typing import ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
+from typing import ClassVar, Optional, Union
 
 import onnx
 
@@ -74,16 +75,16 @@ class Node(ABC):
 
     op_type: ClassVar[OpType] = OpType("", "", 0)
 
-    Attributes: ClassVar[typing.Type[BaseAttributes]]
-    Inputs: ClassVar[typing.Type[BaseInputs]]
-    Outputs: ClassVar[typing.Type[BaseOutputs]]
+    Attributes: ClassVar[type[BaseAttributes]]
+    Inputs: ClassVar[type[BaseInputs]]
+    Outputs: ClassVar[type[BaseOutputs]]
 
     attrs: BaseAttributes
     inputs: BaseInputs
     outputs: BaseOutputs
 
     out_variadic: Optional[int]
-    _traceback: Union[List[str], None]
+    _traceback: Union[list[str], None]
 
     def __init__(
         self,
@@ -143,7 +144,7 @@ def __init__(
         self.post_init(**kwargs)
 
     @property
-    def opset_req(self) -> Set[Tuple[str, int]]:
+    def opset_req(self) -> set[tuple[str, int]]:
         """
         Set of the opset requirements -- (domain, version) -- brought in by this node.
         Does not include subgraphs.
@@ -211,7 +212,7 @@ def pre_init(self, **kwargs):
 
     def post_init(self, **kwargs):
         """Post-initialization hook. Called at the end of ``__init__`` after other default fields are set."""
 
-    def propagate_values(self) -> Dict[str, PropValueType]:
+    def propagate_values(self) -> dict[str, PropValueType]:
         """
         Propagate values from inputs, and, if possible, compute values for outputs as well.
         This method is used to implement ONNX partial data propagation - for example so that
@@ -219,7 +220,7 @@ def propagate_values(self) -> dict[str, PropValueType]:
         """
         return {}
 
-    def infer_output_types(self) -> Dict[str, Type]:
+    def infer_output_types(self) -> dict[str, Type]:
         """
         Inference routine for output types. Often overriden by inheriting Node types.
 
@@ -308,7 +309,7 @@ def _init_output_vars(self) -> BaseOutputs:
             (variadic,) = variadics
         else:
             variadic = None
-        outputs: Dict[str, Union[Var, Sequence[Var]]] = {
+        outputs: dict[str, Union[Var, Sequence[Var]]] = {
             field.name: Var(self, None, None)
             for field in dataclasses.fields(self.Outputs)
             if field.name != variadic
@@ -351,7 +352,7 @@ def to_onnx(
         build_subgraph: Optional[
             typing.Callable[["Node", str, "Graph"], onnx.GraphProto]
         ] = None,
-    ) -> List[onnx.NodeProto]:
+    ) -> list[onnx.NodeProto]:
         """Translates self into an ONNX NodeProto."""
         assert self.op_type.identifier
         input_names = [scope.var[var] if var is not None else "" for var in self.inputs]
diff --git a/src/spox/_public.py b/src/spox/_public.py
index 3157f4a0..101d8d40 100644
--- a/src/spox/_public.py
+++ b/src/spox/_public.py
@@ -5,7 +5,7 @@
 
 import contextlib
 import itertools
-from typing import Dict, List, Optional, Protocol
+from typing import Optional, Protocol
 
 import numpy as np
 import onnx
@@ -46,7 +46,7 @@ def _temporary_renames(**kwargs: Var):
     # not just ``Var._name``. So we set names here and reset them
     # afterwards.
     name: Optional[str]
-    pre: Dict[Var, Optional[str]] = {}
+    pre: dict[Var, Optional[str]] = {}
     try:
         for name, arg in kwargs.items():
             pre[arg] = arg._name
@@ -58,7 +58,7 @@ def _temporary_renames(**kwargs: Var):
 
 
 def build(
-    inputs: Dict[str, Var], outputs: Dict[str, Var], *, drop_unused_inputs=False
+    inputs: dict[str, Var], outputs: dict[str, Var], *, drop_unused_inputs=False
 ) -> onnx.ModelProto:
     """
     Builds an ONNX Model with given model inputs and outputs.
@@ -146,7 +146,7 @@ class _InlineCall(Protocol):
     (``str``) into ``Var``.
""" - def __call__(self, *args: Var, **kwargs: Var) -> Dict[str, Var]: + def __call__(self, *args: Var, **kwargs: Var) -> dict[str, Var]: """ Parameters ---------- @@ -254,7 +254,7 @@ def inline(model: onnx.ModelProto) -> _InlineCall: ) # We handle everything related to initializers here, as currently build does not support them too well # Overridable initializers are saved to in_defaults, non-overridable replaced with Constant - preamble: List[onnx.NodeProto] = [] + preamble: list[onnx.NodeProto] = [] input_names = {i.name for i in model.graph.input} preamble.extend( onnx.helper.make_node("Constant", [], [i.name], value=i) @@ -275,7 +275,7 @@ def inline(model: onnx.ModelProto) -> _InlineCall: model.graph.node.reverse() # Now we can assume the graph has no initializers - def inline_inner(*args: Var, **kwargs: Var) -> Dict[str, Var]: + def inline_inner(*args: Var, **kwargs: Var) -> dict[str, Var]: for name, arg in zip(in_names, args): if name in kwargs: raise TypeError( diff --git a/src/spox/_schemas.py b/src/spox/_schemas.py index 6fa90c01..8f69622f 100644 --- a/src/spox/_schemas.py +++ b/src/spox/_schemas.py @@ -4,15 +4,11 @@ """Exposes information related to reference ONNX operator schemas, used by StandardOpNode.""" import itertools +from collections.abc import Iterable from typing import ( Callable, - Dict, - Iterable, - List, Optional, Protocol, - Set, - Tuple, TypeVar, ) @@ -52,8 +48,8 @@ def _current_schema( def _get_schemas_versioned( - all_schemas: List[OpSchema], -) -> Dict[str, Dict[str, List[OpSchema]]]: + all_schemas: list[OpSchema], +) -> dict[str, dict[str, list[OpSchema]]]: """Get a map into a list of schemas for all domain/names.""" return { domain: { @@ -65,9 +61,9 @@ def _get_schemas_versioned( def _get_schemas_map( - schemas_ver_lists: Dict[str, Dict[str, List[OpSchema]]], - domain_versions: Dict[str, Set[int]], -) -> Dict[str, Dict[int, Dict[str, OpSchema]]]: + schemas_ver_lists: dict[str, dict[str, list[OpSchema]]], + domain_versions: dict[str, set[int]], +) -> dict[str, dict[int, dict[str, OpSchema]]]: """Get a map into a schema for every domain/version/name.""" return { domain: { @@ -84,12 +80,12 @@ def _get_schemas_map( } -ALL_SCHEMAS: List[OpSchema] = get_all_schemas_with_history() # type: ignore +ALL_SCHEMAS: list[OpSchema] = get_all_schemas_with_history() # type: ignore -DOMAINS: Set[str] = {s.domain for s in ALL_SCHEMAS} +DOMAINS: set[str] = {s.domain for s in ALL_SCHEMAS} # Assumes that each version does change at least one of the operators from the available schemes. 
-DOMAIN_VERSIONS: Dict[str, Set[int]] = {
+DOMAIN_VERSIONS: dict[str, set[int]] = {
     domain: {s.since_version for s in ALL_SCHEMAS if s.domain == domain}
     for domain in DOMAINS
 }
@@ -101,7 +97,7 @@ def _get_schemas_map(
 SCHEMAS = _get_schemas_map(SCHEMAS_VER_LISTS, DOMAIN_VERSIONS)
 
 
-def max_opset_policy(opset_req: Set[Tuple[str, int]]) -> Dict[str, int]:
+def max_opset_policy(opset_req: set[tuple[str, int]]) -> dict[str, int]:
     """Use the highest required version for every opset."""
     opset_req = {(k if k != "ai.onnx" else "", v) for k, v in opset_req}
     grouping = itertools.groupby(sorted(opset_req), key=lambda x: x[0])
diff --git a/src/spox/_scope.py b/src/spox/_scope.py
index 6fdf594e..50b1ad52 100644
--- a/src/spox/_scope.py
+++ b/src/spox/_scope.py
@@ -1,7 +1,8 @@
 # Copyright (c) QuantCo 2023-2024
 # SPDX-License-Identifier: BSD-3-Clause
 
-from typing import Dict, Generic, Hashable, Optional, Set, TypeVar, Union, overload
+from collections.abc import Hashable
+from typing import Generic, Optional, TypeVar, Union, overload
 
 from ._node import Node
 from ._var import Var
@@ -23,17 +24,17 @@ class ScopeSpace(Generic[H]):
     So ``__getitem__`` (``ScopeSpace[item]``) may be used for both the name of an object and the object of a name.
     """
 
-    name_of: Dict[H, str]
-    of_name: Dict[str, H]
-    reserved: Set[str]
-    base_name_counters: Dict[str, int]
+    name_of: dict[H, str]
+    of_name: dict[str, H]
+    reserved: set[str]
+    base_name_counters: dict[str, int]
     parent: "Optional[ScopeSpace[H]]"
 
     def __init__(
         self,
-        name_of: Optional[Dict[H, str]] = None,
-        of_name: Optional[Dict[str, H]] = None,
-        reserved: Optional[Set[str]] = None,
+        name_of: Optional[dict[H, str]] = None,
+        of_name: Optional[dict[str, H]] = None,
+        reserved: Optional[set[str]] = None,
         parent: "Optional[ScopeSpace[H]]" = None,
     ):
         """
diff --git a/src/spox/_shape.py b/src/spox/_shape.py
index db2467ee..c48308bf 100644
--- a/src/spox/_shape.py
+++ b/src/spox/_shape.py
@@ -15,14 +15,13 @@
 """
 
 import abc
-import typing
 from dataclasses import dataclass
-from typing import Optional, Tuple, TypeVar, Union
+from typing import Optional, TypeVar, Union
 
 import onnx
 
 SimpleShapeElem = Union[str, int, None]
-SimpleShape = Optional[Tuple[SimpleShapeElem, ...]]
+SimpleShape = Optional[tuple[SimpleShapeElem, ...]]
 
 
 class ShapeError(TypeError):
@@ -126,22 +125,20 @@ def __le__(self, other: Natural) -> bool:
 class Shape:
     """Type representing a static Tensor shape."""
 
-    dims: Optional[Tuple[Natural, ...]]
+    dims: Optional[tuple[Natural, ...]]
 
     def __bool__(self):
         return self.dims is not None
 
     @classmethod
-    def from_simple(cls: typing.Type[ShapeT], shape: SimpleShape) -> ShapeT:
+    def from_simple(cls: type[ShapeT], shape: SimpleShape) -> ShapeT:
         """Translate into a Shape from the simplified representation."""
         return cls(
             tuple(Natural.from_simple(v) for v in shape) if shape is not None else None
         )
 
     @classmethod
-    def from_onnx(
-        cls: typing.Type[ShapeT], proto: Optional[onnx.TensorShapeProto]
-    ) -> ShapeT:
+    def from_onnx(cls: type[ShapeT], proto: Optional[onnx.TensorShapeProto]) -> ShapeT:
         """Translate into a Shape from ONNX shape."""
         return (
             cls(tuple(Natural.from_onnx(dim) for dim in proto.dim))
diff --git a/src/spox/_standard.py b/src/spox/_standard.py
index 54e27df3..ac519875 100644
--- a/src/spox/_standard.py
+++ b/src/spox/_standard.py
@@ -3,7 +3,7 @@
 
 """Module implementing a base for standard ONNX operators, which use the functionality of ONNX node-level inference."""
 
-from typing import TYPE_CHECKING, Callable, Dict, Tuple
+from typing import TYPE_CHECKING, Callable
 
 import numpy as np
 import onnx
@@ -51,7 +51,7 @@ def min_output(self) -> int:
 
     def to_singleton_onnx_model(
         self, *, dummy_outputs: bool = True, with_dummy_subgraphs: bool = True
-    ) -> Tuple[onnx.ModelProto, Scope]:
+    ) -> tuple[onnx.ModelProto, Scope]:
         """
         Build a singleton model consisting of just this StandardNode. Used for type inference.
         Dummy subgraphs are typed, but have no graph body, so that we can avoid the build cost.
@@ -123,7 +123,7 @@ def out_value_info(curr_key, curr_var):
         )
         return model, scope
 
-    def infer_output_types_onnx(self) -> Dict[str, Type]:
+    def infer_output_types_onnx(self) -> dict[str, Type]:
         """Execute type & shape inference with ``onnx.shape_inference.infer_node_outputs``."""
         # Check that all (specified) inputs have known types, as otherwise we fail
         if any(var.type is None for var in self.inputs.get_vars().values()):
@@ -153,7 +153,7 @@ def infer_output_types_onnx(self) -> dict[str, Type]:
             for key, type_ in results.items()
         }
 
-    def propagate_values_onnx(self) -> Dict[str, PropValueType]:
+    def propagate_values_onnx(self) -> dict[str, PropValueType]:
         """Perform value propagation by evaluating singleton model.
 
         The backend used for the propagation can be configured with the `spox._standard.ValuePropBackend` variable.
@@ -185,10 +185,10 @@ def propagate_values_onnx(self) -> dict[str, PropValueType]:
         }
         return {k: v for k, v in results.items() if k is not None}
 
-    def infer_output_types(self) -> Dict[str, Type]:
+    def infer_output_types(self) -> dict[str, Type]:
         return self.infer_output_types_onnx()
 
-    def propagate_values(self) -> Dict[str, PropValueType]:
+    def propagate_values(self) -> dict[str, PropValueType]:
         if _value_prop._VALUE_PROP_BACKEND != _value_prop.ValuePropBackend.NONE:
             return self.propagate_values_onnx()
         return {}
diff --git a/src/spox/_traverse.py b/src/spox/_traverse.py
index 214f3522..1d46bb54 100644
--- a/src/spox/_traverse.py
+++ b/src/spox/_traverse.py
@@ -1,7 +1,8 @@
 # Copyright (c) QuantCo 2023-2024
 # SPDX-License-Identifier: BSD-3-Clause
 
-from typing import Callable, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
+from collections.abc import Iterable, Iterator
+from typing import Callable, Optional, TypeVar
 
 V = TypeVar("V")
 
@@ -11,7 +12,7 @@ def iterative_dfs(
     adj: Callable[[V], Iterable[V]],
     post_callback: Optional[Callable[[V], None]] = None,
     raise_on_cycle: bool = True,
-) -> List[V]:
+) -> list[V]:
     """
     Performs a depth-first search and returns the postorder of the traversal. Throws if the graph contains a cycle.
     The topological sorting returned is the postorder of the DFS.
@@ -55,12 +56,12 @@ def dfs(u: V):
     for s in sources:
         dfs(s)
     """
-    postorder: List[V] = []
-    visited: Set[V] = set()
-    stack: Set[V] = set()
+    postorder: list[V] = []
+    visited: set[V] = set()
+    stack: set[V] = set()
 
     # Recursion stack - the state of a DFS is described with a stack of (vertex, nodes left to visit).
-    recursion: List[Tuple[V, Iterator[V]]] = []
+    recursion: list[tuple[V, Iterator[V]]] = []
 
     def call(w: V):
         """Helper called when we attempt to enter a node ``w``."""
diff --git a/src/spox/_type_system.py b/src/spox/_type_system.py
index 8e883a67..27c2695b 100644
--- a/src/spox/_type_system.py
+++ b/src/spox/_type_system.py
@@ -1,7 +1,6 @@
 # Copyright (c) QuantCo 2023-2024
 # SPDX-License-Identifier: BSD-3-Clause
 
-import typing
 from dataclasses import dataclass
 from typing import TypeVar
 
@@ -170,7 +169,7 @@ class Tensor(Type):
         enforced.
""" - _elem_type: typing.Type[np.generic] + _elem_type: type[np.generic] _shape: Shape def __init__( diff --git a/src/spox/_value_prop.py b/src/spox/_value_prop.py index e781e077..300abccc 100644 --- a/src/spox/_value_prop.py +++ b/src/spox/_value_prop.py @@ -5,7 +5,7 @@ import logging import warnings from dataclasses import dataclass -from typing import Callable, Dict, List, Union +from typing import Callable, Union import numpy as np import numpy.typing as npt @@ -24,7 +24,7 @@ - PropValue -> Optional, Some (has value) - None -> Optional, Nothing (no value) """ -PropValueType = Union[np.ndarray, List["PropValue"], "PropValue", None] +PropValueType = Union[np.ndarray, list["PropValue"], "PropValue", None] ORTValue = Union[np.ndarray, list, None] RefValue = Union[np.ndarray, list, float, None] @@ -159,8 +159,8 @@ def to_ort_value(self) -> ORTValue: def _run_reference_implementation( - model: onnx.ModelProto, input_feed: Dict[str, RefValue] -) -> Dict[str, RefValue]: + model: onnx.ModelProto, input_feed: dict[str, RefValue] +) -> dict[str, RefValue]: try: session = onnx.reference.ReferenceEvaluator(model) output_feed = dict(zip(session.output_names, session.run(None, input_feed))) @@ -175,8 +175,8 @@ def _run_reference_implementation( def _run_onnxruntime( - model: onnx.ModelProto, input_feed: Dict[str, ORTValue] -) -> Dict[str, ORTValue]: + model: onnx.ModelProto, input_feed: dict[str, ORTValue] +) -> dict[str, ORTValue]: import onnxruntime # Silence possible warnings during execution (especially constant folding) @@ -196,7 +196,7 @@ def _run_onnxruntime( def get_backend_calls(): - run: Callable[..., Dict[str, npt.ArrayLike]] + run: Callable[..., dict[str, npt.ArrayLike]] unwrap_feed: Callable[..., PropValue] if _VALUE_PROP_BACKEND == ValuePropBackend.REFERENCE: wrap_feed = PropValue.to_ref_value diff --git a/src/spox/_var.py b/src/spox/_var.py index 4225ecc3..15dd2186 100644 --- a/src/spox/_var.py +++ b/src/spox/_var.py @@ -200,7 +200,7 @@ def __rxor__(self, other) -> "Var": def result_type( *types: Union[Var, np.generic, int, float], -) -> typing.Type[np.generic]: +) -> type[np.generic]: """Promote type for all given element types/values using ``np.result_type``.""" return np.dtype( np.result_type( diff --git a/src/spox/opset/ai/onnx/ml/v3.py b/src/spox/opset/ai/onnx/ml/v3.py index b34f1a46..b34d4020 100644 --- a/src/spox/opset/ai/onnx/ml/v3.py +++ b/src/spox/opset/ai/onnx/ml/v3.py @@ -2,13 +2,10 @@ # SPDX-License-Identifier: BSD-3-Clause # ruff: noqa: E741 -- Allow ambiguous variable name +from collections.abc import Iterable, Sequence from dataclasses import dataclass from typing import ( - Dict, - Iterable, Optional, - Sequence, - Tuple, ) import numpy as np @@ -43,7 +40,7 @@ class Inputs(BaseInputs): class Outputs(BaseOutputs): Z: Var - def infer_output_types(self) -> Dict[str, Type]: + def infer_output_types(self) -> dict[str, Type]: if not self.inputs.fully_typed: return {} xt, yt = self.inputs.X.unwrap_tensor(), self.inputs.Y.unwrap_tensor() @@ -78,7 +75,7 @@ class Inputs(BaseInputs): class Outputs(BaseOutputs): Y: Var - def infer_output_types(self) -> Dict[str, Type]: + def infer_output_types(self) -> dict[str, Type]: return {"Y": self.inputs.X.type} if self.inputs.X.type is not None else {} op_type = OpType("Binarizer", "ai.onnx.ml", 1) @@ -126,7 +123,7 @@ class Inputs(BaseInputs): class Outputs(BaseOutputs): Y: Var - def infer_output_types(self) -> Dict[str, Type]: + def infer_output_types(self) -> dict[str, Type]: if not self.inputs.fully_typed: return {} cats1, cats2 = 
@@ -202,7 +199,7 @@ class Inputs(BaseInputs):
     class Outputs(BaseOutputs):
         Y: Var
 
-    def infer_output_types(self) -> Dict[str, Type]:
+    def infer_output_types(self) -> dict[str, Type]:
         if not self.inputs.fully_typed:
             return {}
         t = self.inputs.X.unwrap_tensor()
@@ -314,7 +311,7 @@ class Inputs(BaseInputs):
     class Outputs(BaseOutputs):
         Y: Var
 
-    def infer_output_types(self) -> Dict[str, Type]:
+    def infer_output_types(self) -> dict[str, Type]:
         if not self.inputs.fully_typed:
             return {}
         sim = self.inputs.X.unwrap_tensor().shape
@@ -348,7 +345,7 @@ class Inputs(BaseInputs):
     class Outputs(BaseOutputs):
         Y: Var
 
-    def infer_output_types(self) -> Dict[str, Type]:
+    def infer_output_types(self) -> dict[str, Type]:
         if self.attrs.norm.value not in ("MAX", "L1", "L2"):
             raise InferenceError(
                 f"Unknown normalisation method `{self.attrs.norm.value}`"
@@ -377,7 +374,7 @@ class Inputs(BaseInputs):
     class Outputs(BaseOutputs):
         Y: Var
 
-    def infer_output_types(self) -> Dict[str, Type]:
+    def infer_output_types(self) -> dict[str, Type]:
         if not self.inputs.fully_typed:
             return {}
         if self.attrs.cats_int64s:
@@ -470,7 +467,7 @@ class Inputs(BaseInputs):
     class Outputs(BaseOutputs):
         Y: Var
 
-    def infer_output_types(self) -> Dict[str, Type]:
+    def infer_output_types(self) -> dict[str, Type]:
         if self.inputs.X.type is None:
             return {}
         sc, off = self.attrs.scale, self.attrs.offset
@@ -530,7 +527,7 @@ class Outputs(BaseOutputs):
         Y: Var
         Z: Var
 
-    def infer_output_types(self) -> Dict[str, Type]:
+    def infer_output_types(self) -> dict[str, Type]:
         e = (
             len(self.attrs.class_ids.value)
             if self.attrs.class_ids is not None
@@ -594,7 +591,7 @@ class Inputs(BaseInputs):
     class Outputs(BaseOutputs):
         Y: Var
 
-    def infer_output_types(self) -> Dict[str, Type]:
+    def infer_output_types(self) -> dict[str, Type]:
         if self.inputs.fully_typed:
             shape = self.inputs.X.unwrap_tensor().shape
             assert shape is not None  # already checked with fully_typed
@@ -1134,7 +1131,7 @@ def linear_classifier(
    intercepts: Optional[Iterable[float]] = None,
    multi_class: int = 0,
    post_transform: str = "NONE",
-) -> Tuple[Var, Var]:
+) -> tuple[Var, Var]:
     r"""
     Linear classifier
 
@@ -1384,7 +1381,7 @@ def svmclassifier(
    rho: Optional[Iterable[float]] = None,
    support_vectors: Optional[Iterable[float]] = None,
    vectors_per_class: Optional[Iterable[int]] = None,
-) -> Tuple[Var, Var]:
+) -> tuple[Var, Var]:
     r"""
     Support Vector Machine classifier
 
@@ -1630,7 +1627,7 @@ def tree_ensemble_classifier(
    nodes_values: Optional[Iterable[float]] = None,
    nodes_values_as_tensor: Optional[np.ndarray] = None,
    post_transform: str = "NONE",
-) -> Tuple[Var, Var]:
+) -> tuple[Var, Var]:
     r"""
     Tree Ensemble classifier. Returns the top class for each of N inputs.
     The attributes named 'nodes_X' form a sequence of tuples, associated by
diff --git a/src/spox/opset/ai/onnx/ml/v4.py b/src/spox/opset/ai/onnx/ml/v4.py
index e20c3529..9e51382c 100644
--- a/src/spox/opset/ai/onnx/ml/v4.py
+++ b/src/spox/opset/ai/onnx/ml/v4.py
@@ -2,9 +2,9 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 # ruff: noqa: E741 -- Allow ambiguous variable name
+from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import (
-    Iterable,
     Optional,
 )
 
diff --git a/src/spox/opset/ai/onnx/ml/v5.py b/src/spox/opset/ai/onnx/ml/v5.py
index fd4be62f..100bf179 100644
--- a/src/spox/opset/ai/onnx/ml/v5.py
+++ b/src/spox/opset/ai/onnx/ml/v5.py
@@ -2,9 +2,9 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 # ruff: noqa: E741 -- Allow ambiguous variable name
+from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import (
-    Iterable,
     Optional,
 )
 
diff --git a/src/spox/opset/ai/onnx/v17.py b/src/spox/opset/ai/onnx/v17.py
index 71c0d5bd..f9cbd0bf 100644
--- a/src/spox/opset/ai/onnx/v17.py
+++ b/src/spox/opset/ai/onnx/v17.py
@@ -2,15 +2,11 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 # ruff: noqa: E741 -- Allow ambiguous variable name
+from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from typing import (
     Callable,
-    Dict,
-    Iterable,
-    List,
     Optional,
-    Sequence,
-    Tuple,
 )
 from typing import cast as typing_cast
 
@@ -498,7 +494,7 @@ class Inputs(BaseInputs):
     class Outputs(BaseOutputs):
         output: Var
 
-    def infer_output_types(self) -> Dict[str, Type]:
+    def infer_output_types(self) -> dict[str, Type]:
         self.infer_output_types_onnx()
         inp, cond = (
             self.inputs.input.unwrap_tensor(),
@@ -589,7 +585,7 @@ class Attributes(BaseAttributes):
     class Outputs(BaseOutputs):
         output: Var
 
-    def propagate_values(self) -> Dict[str, PropValueType]:
+    def propagate_values(self) -> dict[str, PropValueType]:
         ((key, raw),) = (
             (k, v.value) for k, v in self.attrs.get_fields().items() if v is not None
         )
@@ -1778,7 +1774,7 @@ class Inputs(BaseInputs):
     class Outputs(BaseOutputs):
         v_final_and_scan_outputs: Sequence[Var]
 
-    def infer_output_types(self) -> Dict[str, Type]:
+    def infer_output_types(self) -> dict[str, Type]:
         output_types = super().infer_output_types()
 
         body = self.attrs.body.value
@@ -4533,7 +4529,7 @@ def batch_normalization(
    epsilon: float = 9.999999747378752e-06,
    momentum: float = 0.8999999761581421,
    training_mode: int = 0,
-) -> Tuple[Var, Var, Var]:
+) -> tuple[Var, Var, Var]:
     r"""
     Carries out batch normalization as described in the paper
     https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,
@@ -6202,7 +6198,7 @@ def dropout(
    training_mode: Optional[Var] = None,
    *,
    seed: Optional[int] = None,
-) -> Tuple[Var, Var]:
+) -> tuple[Var, Var]:
     r"""
     Dropout takes an input floating-point tensor, an optional input ratio
     (floating-point scalar) and an optional input training_mode (boolean
@@ -6286,7 +6282,7 @@ def dropout(
 
 def dynamic_quantize_linear(
     x: Var,
-) -> Tuple[Var, Var, Var]:
+) -> tuple[Var, Var, Var]:
     r"""
     A Function to fuse calculation for Scale, Zero Point and FP32->8Bit
     conversion of FP32 Input data. Outputs Scale, ZeroPoint and Quantized
@@ -6793,7 +6789,7 @@ def gru(
    hidden_size: Optional[int] = None,
    layout: int = 0,
    linear_before_reset: int = 0,
-) -> Tuple[Var, Var]:
+) -> tuple[Var, Var]:
     r"""
     Computes an one-layer GRU. This operator is usually supported via some
     custom implementation such as CuDNN.
@@ -8300,7 +8296,7 @@ def lstm(
    hidden_size: Optional[int] = None,
    input_forget: int = 0,
    layout: int = 0,
-) -> Tuple[Var, Var, Var]:
+) -> tuple[Var, Var, Var]:
     r"""
     Computes an one-layer LSTM. This operator is usually supported via some
     custom implementation such as CuDNN.
@@ -8510,7 +8506,7 @@ def layer_normalization(
    axis: int = -1,
    epsilon: float = 9.999999747378752e-06,
    stash_type: int = 1,
-) -> Tuple[Var, Var, Var]:
+) -> tuple[Var, Var, Var]:
     r"""
     This is layer normalization defined in ONNX as function. The overall
     computation can be split into two stages. The first stage is
@@ -8990,7 +8986,7 @@ def loop(
    - V: `optional(seq(tensor(bfloat16)))`, `optional(seq(tensor(bool)))`, `optional(seq(tensor(complex128)))`, `optional(seq(tensor(complex64)))`, `optional(seq(tensor(double)))`, `optional(seq(tensor(float)))`, `optional(seq(tensor(float16)))`, `optional(seq(tensor(int16)))`, `optional(seq(tensor(int32)))`, `optional(seq(tensor(int64)))`, `optional(seq(tensor(int8)))`, `optional(seq(tensor(string)))`, `optional(seq(tensor(uint16)))`, `optional(seq(tensor(uint32)))`, `optional(seq(tensor(uint64)))`, `optional(seq(tensor(uint8)))`, `optional(tensor(bfloat16))`, `optional(tensor(bool))`, `optional(tensor(complex128))`, `optional(tensor(complex64))`, `optional(tensor(double))`, `optional(tensor(float))`, `optional(tensor(float16))`, `optional(tensor(int16))`, `optional(tensor(int32))`, `optional(tensor(int64))`, `optional(tensor(int8))`, `optional(tensor(string))`, `optional(tensor(uint16))`, `optional(tensor(uint32))`, `optional(tensor(uint64))`, `optional(tensor(uint8))`, `seq(tensor(bfloat16))`, `seq(tensor(bool))`, `seq(tensor(complex128))`, `seq(tensor(complex64))`, `seq(tensor(double))`, `seq(tensor(float))`, `seq(tensor(float16))`, `seq(tensor(int16))`, `seq(tensor(int32))`, `seq(tensor(int64))`, `seq(tensor(int8))`, `seq(tensor(string))`, `seq(tensor(uint16))`, `seq(tensor(uint32))`, `seq(tensor(uint64))`, `seq(tensor(uint8))`, `tensor(bfloat16)`, `tensor(bool)`, `tensor(complex128)`, `tensor(complex64)`, `tensor(double)`, `tensor(float)`, `tensor(float16)`, `tensor(int16)`, `tensor(int32)`, `tensor(int64)`, `tensor(int8)`, `tensor(string)`, `tensor(uint16)`, `tensor(uint32)`, `tensor(uint64)`, `tensor(uint8)`
     """
     _body_subgraph: Graph = subgraph(
-        typing_cast(List[Type], [Tensor(np.int64, (1,)), Tensor(np.bool_, (1,))])
+        typing_cast(list[Type], [Tensor(np.int64, (1,)), Tensor(np.bool_, (1,))])
         + [var.unwrap_type() for var in v_initial],
         body,
     )
@@ -9286,7 +9282,7 @@ def max_pool(
    pads: Optional[Iterable[int]] = None,
    storage_order: int = 0,
    strides: Optional[Iterable[int]] = None,
-) -> Tuple[Var, Var]:
+) -> tuple[Var, Var]:
     r"""
     MaxPool consumes an input tensor X and applies max pooling across the
     tensor according to kernel sizes, stride sizes, and pad lengths. max
@@ -11104,7 +11100,7 @@ def rnn(
    direction: str = "forward",
    hidden_size: Optional[int] = None,
    layout: int = 0,
-) -> Tuple[Var, Var]:
+) -> tuple[Var, Var]:
     r"""
     Computes an one-layer simple RNN. This operator is usually supported
     via some custom implementation such as CuDNN.
@@ -14097,7 +14093,7 @@ def softmax_cross_entropy_loss(
    *,
    ignore_index: Optional[int] = None,
    reduction: str = "mean",
-) -> Tuple[Var, Var]:
+) -> tuple[Var, Var]:
     r"""
     Loss function that measures the softmax cross entropy between 'scores'
     and 'labels'.
     This operator first computes a loss tensor whose shape is
@@ -14982,7 +14978,7 @@ def top_k(
    axis: int = -1,
    largest: int = 1,
    sorted: int = 1,
-) -> Tuple[Var, Var]:
+) -> tuple[Var, Var]:
     r"""
     Retrieve the top-K largest or smallest elements along a specified axis.
     Given an input tensor of shape [a_0, a_1, ..., a\_{n-1}] and integer
@@ -15175,7 +15171,7 @@ def unique(
    *,
    axis: Optional[int] = None,
    sorted: int = 1,
-) -> Tuple[Var, Var, Var, Var]:
+) -> tuple[Var, Var, Var, Var]:
     r"""
     Find the unique elements of a tensor. When an optional attribute 'axis'
     is provided, unique subtensors sliced along the 'axis' are returned.
diff --git a/src/spox/opset/ai/onnx/v18.py b/src/spox/opset/ai/onnx/v18.py
index 59d70884..028c0775 100644
--- a/src/spox/opset/ai/onnx/v18.py
+++ b/src/spox/opset/ai/onnx/v18.py
@@ -2,11 +2,10 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 # ruff: noqa: E741 -- Allow ambiguous variable name
+from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from typing import (
-    Iterable,
     Optional,
-    Sequence,
 )
 
 import numpy as np
diff --git a/src/spox/opset/ai/onnx/v19.py b/src/spox/opset/ai/onnx/v19.py
index 8de500db..6c14823d 100644
--- a/src/spox/opset/ai/onnx/v19.py
+++ b/src/spox/opset/ai/onnx/v19.py
@@ -2,14 +2,11 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 # ruff: noqa: E741 -- Allow ambiguous variable name
+from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from typing import (
     Callable,
-    Dict,
-    Iterable,
-    List,
     Optional,
-    Sequence,
 )
 from typing import cast as typing_cast
 
@@ -459,7 +456,7 @@ class Attributes(BaseAttributes):
     class Outputs(BaseOutputs):
         output: Var
 
-    def propagate_values(self) -> Dict[str, PropValueType]:
+    def propagate_values(self) -> dict[str, PropValueType]:
         ((key, raw),) = (
             (k, v.value) for k, v in self.attrs.get_fields().items() if v is not None
         )
@@ -1694,7 +1691,7 @@ def loop(
    - V: `optional(seq(tensor(bfloat16)))`, `optional(seq(tensor(bool)))`, `optional(seq(tensor(complex128)))`, `optional(seq(tensor(complex64)))`, `optional(seq(tensor(double)))`, `optional(seq(tensor(float)))`, `optional(seq(tensor(float16)))`, `optional(seq(tensor(int16)))`, `optional(seq(tensor(int32)))`, `optional(seq(tensor(int64)))`, `optional(seq(tensor(int8)))`, `optional(seq(tensor(string)))`, `optional(seq(tensor(uint16)))`, `optional(seq(tensor(uint32)))`, `optional(seq(tensor(uint64)))`, `optional(seq(tensor(uint8)))`, `optional(tensor(bfloat16))`, `optional(tensor(bool))`, `optional(tensor(complex128))`, `optional(tensor(complex64))`, `optional(tensor(double))`, `optional(tensor(float))`, `optional(tensor(float16))`, `optional(tensor(float8e4m3fn))`, `optional(tensor(float8e4m3fnuz))`, `optional(tensor(float8e5m2))`, `optional(tensor(float8e5m2fnuz))`, `optional(tensor(int16))`, `optional(tensor(int32))`, `optional(tensor(int64))`, `optional(tensor(int8))`, `optional(tensor(string))`, `optional(tensor(uint16))`, `optional(tensor(uint32))`, `optional(tensor(uint64))`, `optional(tensor(uint8))`, `seq(tensor(bfloat16))`, `seq(tensor(bool))`, `seq(tensor(complex128))`, `seq(tensor(complex64))`, `seq(tensor(double))`, `seq(tensor(float))`, `seq(tensor(float16))`, `seq(tensor(float8e4m3fn))`, `seq(tensor(float8e4m3fnuz))`, `seq(tensor(float8e5m2))`, `seq(tensor(float8e5m2fnuz))`, `seq(tensor(int16))`, `seq(tensor(int32))`, `seq(tensor(int64))`, `seq(tensor(int8))`, `seq(tensor(string))`, `seq(tensor(uint16))`, `seq(tensor(uint32))`, `seq(tensor(uint64))`, `seq(tensor(uint8))`, `tensor(bfloat16)`, `tensor(bool)`, `tensor(complex128)`, `tensor(complex64)`, `tensor(double)`, `tensor(float)`, `tensor(float16)`, `tensor(float8e4m3fn)`, `tensor(float8e4m3fnuz)`, `tensor(float8e5m2)`, `tensor(float8e5m2fnuz)`, `tensor(int16)`, `tensor(int32)`, `tensor(int64)`, `tensor(int8)`, `tensor(string)`, `tensor(uint16)`, `tensor(uint32)`, `tensor(uint64)`, `tensor(uint8)`
     """
     _body_subgraph: Graph = subgraph(
-        typing_cast(List[Type], [Tensor(np.int64, (1,)), Tensor(np.bool_, (1,))])
+        typing_cast(list[Type], [Tensor(np.int64, (1,)), Tensor(np.bool_, (1,))])
         + [var.unwrap_type() for var in v_initial],
         body,
     )
diff --git a/src/spox/opset/ai/onnx/v20.py b/src/spox/opset/ai/onnx/v20.py
index 10b95b5a..fa5a4c42 100644
--- a/src/spox/opset/ai/onnx/v20.py
+++ b/src/spox/opset/ai/onnx/v20.py
@@ -5,7 +5,6 @@
 from dataclasses import dataclass
 from typing import (
     Optional,
-    Tuple,
 )
 
 import numpy as np
@@ -1443,7 +1442,7 @@ def string_split(
    *,
    delimiter: Optional[str] = None,
    maxsplit: Optional[int] = None,
-) -> Tuple[Var, Var]:
+) -> tuple[Var, Var]:
     r"""
     StringSplit splits a string tensor's elements into substrings based on
     a delimiter attribute and a maxsplit attribute.
diff --git a/src/spox/opset/ai/onnx/v21.py b/src/spox/opset/ai/onnx/v21.py
index 72839830..f4f027cc 100644
--- a/src/spox/opset/ai/onnx/v21.py
+++ b/src/spox/opset/ai/onnx/v21.py
@@ -2,14 +2,11 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 # ruff: noqa: E741 -- Allow ambiguous variable name
+from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from typing import (
     Callable,
-    Dict,
-    Iterable,
-    List,
     Optional,
-    Sequence,
 )
 from typing import cast as typing_cast
 
@@ -439,7 +436,7 @@ class Attributes(BaseAttributes):
     class Outputs(BaseOutputs):
         output: Var
 
-    def propagate_values(self) -> Dict[str, PropValueType]:
+    def propagate_values(self) -> dict[str, PropValueType]:
         ((key, raw),) = (
             (k, v.value) for k, v in self.attrs.get_fields().items() if v is not None
         )
@@ -1647,7 +1644,7 @@ def loop(
    - V: `optional(seq(tensor(bfloat16)))`, `optional(seq(tensor(bool)))`, `optional(seq(tensor(complex128)))`, `optional(seq(tensor(complex64)))`, `optional(seq(tensor(double)))`, `optional(seq(tensor(float)))`, `optional(seq(tensor(float16)))`, `optional(seq(tensor(int16)))`, `optional(seq(tensor(int32)))`, `optional(seq(tensor(int64)))`, `optional(seq(tensor(int8)))`, `optional(seq(tensor(string)))`, `optional(seq(tensor(uint16)))`, `optional(seq(tensor(uint32)))`, `optional(seq(tensor(uint64)))`, `optional(seq(tensor(uint8)))`, `optional(tensor(bfloat16))`, `optional(tensor(bool))`, `optional(tensor(complex128))`, `optional(tensor(complex64))`, `optional(tensor(double))`, `optional(tensor(float))`, `optional(tensor(float16))`, `optional(tensor(float8e4m3fn))`, `optional(tensor(float8e4m3fnuz))`, `optional(tensor(float8e5m2))`, `optional(tensor(float8e5m2fnuz))`, `optional(tensor(int16))`, `optional(tensor(int32))`, `optional(tensor(int4))`, `optional(tensor(int64))`, `optional(tensor(int8))`, `optional(tensor(string))`, `optional(tensor(uint16))`, `optional(tensor(uint32))`, `optional(tensor(uint4))`, `optional(tensor(uint64))`, `optional(tensor(uint8))`, `seq(tensor(bfloat16))`, `seq(tensor(bool))`, `seq(tensor(complex128))`, `seq(tensor(complex64))`, `seq(tensor(double))`, `seq(tensor(float))`, `seq(tensor(float16))`, `seq(tensor(float8e4m3fn))`, `seq(tensor(float8e4m3fnuz))`, `seq(tensor(float8e5m2))`, `seq(tensor(float8e5m2fnuz))`, `seq(tensor(int16))`, `seq(tensor(int32))`, `seq(tensor(int4))`, `seq(tensor(int64))`, `seq(tensor(int8))`, `seq(tensor(string))`, `seq(tensor(uint16))`, `seq(tensor(uint32))`, `seq(tensor(uint4))`, `seq(tensor(uint64))`, `seq(tensor(uint8))`, `tensor(bfloat16)`, `tensor(bool)`, `tensor(complex128)`, `tensor(complex64)`, `tensor(double)`, `tensor(float)`, `tensor(float16)`, `tensor(float8e4m3fn)`, `tensor(float8e4m3fnuz)`, `tensor(float8e5m2)`, `tensor(float8e5m2fnuz)`, `tensor(int16)`, `tensor(int32)`, `tensor(int4)`, `tensor(int64)`, `tensor(int8)`, `tensor(string)`, `tensor(uint16)`, `tensor(uint32)`, `tensor(uint4)`, `tensor(uint64)`, `tensor(uint8)`
     """
     _body_subgraph: Graph = subgraph(
-        typing_cast(List[Type], [Tensor(np.int64, (1,)), Tensor(np.bool_, (1,))])
+        typing_cast(list[Type], [Tensor(np.int64, (1,)), Tensor(np.bool_, (1,))])
         + [var.unwrap_type() for var in v_initial],
         body,
     )
diff --git a/tests/conftest.py b/tests/conftest.py
index d0275148..cb294166 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import sys
-from typing import Dict, Optional
+from typing import Optional
 
 import numpy as np
 import onnxruntime
@@ -19,7 +19,7 @@
 
 
 class ONNXRuntimeHelper:
-    _build_cache: Dict[Graph, bytes]
+    _build_cache: dict[Graph, bytes]
     _last_graph: Optional[Graph]
     _last_session: Optional[onnxruntime.InferenceSession]
 
diff --git a/tests/full/conftest.py b/tests/full/conftest.py
index 28be8733..a2ae005d 100644
--- a/tests/full/conftest.py
+++ b/tests/full/conftest.py
@@ -1,8 +1,6 @@
 # Copyright (c) QuantCo 2023-2024
 # SPDX-License-Identifier: BSD-3-Clause
 
-from typing import Tuple
-
 import numpy as np
 import pytest
 
@@ -68,7 +66,7 @@ def empty(s: Var) -> Var:
 def match_brackets(ext, xs: Var) -> Var:
     def bracket_matcher_step(
         i: Var, _cond: Var, stack: Var, result: Var, _: Var
-    ) -> Tuple[Var, Var, Var, Var]:
+    ) -> tuple[Var, Var, Var, Var]:
         closing = op.less(ext.at(xs, i), op.const(0))
         ignore = op.equal(ext.at(xs, i), op.const(0))
         pair = op.concat([ext.top(stack), i], axis=-1)
diff --git a/tests/test_adapt.py b/tests/test_adapt.py
index 693e9a10..3b0884ad 100644
--- a/tests/test_adapt.py
+++ b/tests/test_adapt.py
@@ -1,8 +1,8 @@
 # Copyright (c) QuantCo 2023-2024
 # SPDX-License-Identifier: BSD-3-Clause
 
+from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Iterable
 
 import numpy as np
 import onnx
diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py
index 3ab9fd96..6d975185 100644
--- a/tests/test_broadcast.py
+++ b/tests/test_broadcast.py
@@ -1,8 +1,6 @@
 # Copyright (c) QuantCo 2023-2024
 # SPDX-License-Identifier: BSD-3-Clause
 
-from typing import Tuple
-
 import pytest
 
 from spox._shape import Shape
@@ -46,16 +44,16 @@ def no_broadcast_shapes(request):
     return tuple(Shape.from_simple(sh) for sh in request.param)
 
 
-def test_can_broadcast_true(broadcast_shapes: Tuple[Shape, Shape, Shape]):
+def test_can_broadcast_true(broadcast_shapes: tuple[Shape, Shape, Shape]):
     first, second, _ = broadcast_shapes
     assert first.can_broadcast(second)
 
 
-def test_can_broadcast_false(no_broadcast_shapes: Tuple[Shape, Shape]):
+def test_can_broadcast_false(no_broadcast_shapes: tuple[Shape, Shape]):
     first, second = no_broadcast_shapes
     assert not first.can_broadcast(second)
 
 
-def test_broadcast(broadcast_shapes: Tuple[Shape, Shape, Shape]):
+def test_broadcast(broadcast_shapes: tuple[Shape, Shape, Shape]):
     first, second, result = broadcast_shapes
     assert first.broadcast(second) == result
 
diff --git a/tests/test_custom_operator.py b/tests/test_custom_operator.py
index b8a940cb..1c3c195c 100644
--- a/tests/test_custom_operator.py
a/tests/test_custom_operator.py +++ b/tests/test_custom_operator.py @@ -10,7 +10,6 @@ """ from dataclasses import dataclass -from typing import Dict import numpy as np @@ -44,7 +43,7 @@ class Outputs(BaseOutputs): inputs: Inputs outputs: Outputs - def infer_output_types(self) -> Dict[str, Type]: + def infer_output_types(self) -> dict[str, Type]: # This is technically optional, but using an operator without type inference may be inconvenient. if self.inputs.X.type is None: return {} @@ -55,7 +54,7 @@ def infer_output_types(self) -> Dict[str, Type]: ) return {"Y": t} - def propagate_values(self) -> Dict[str, np.ndarray]: + def propagate_values(self) -> dict[str, np.ndarray]: # This is optional and implements value propagation ('partial data propagation' in ONNX). # In essence constant folding carried through for purposes of type inference. return ( diff --git a/tests/test_equiv_types.py b/tests/test_equiv_types.py index 0f157308..ccf3b4a9 100644 --- a/tests/test_equiv_types.py +++ b/tests/test_equiv_types.py @@ -1,8 +1,6 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause -from typing import List - import numpy as np import pytest @@ -17,7 +15,7 @@ [("x", "y", "z"), ("x", "y", None), (None, None, None), None], ] ) -def shape_clique(request) -> List[Shape]: +def shape_clique(request) -> list[Shape]: return [Shape.from_simple(sh) for sh in request.param] @@ -26,7 +24,7 @@ def shape_clique(request) -> List[Shape]: [Tensor(np.int32, (2, 3, 4)), Tensor(np.int32, None), Type()], ] ) -def weak_type_clique(request) -> List[Type]: +def weak_type_clique(request) -> list[Type]: return request.param diff --git a/tests/test_function.py b/tests/test_function.py index 443146ee..fd03d1b1 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -3,7 +3,7 @@ import functools from dataclasses import dataclass -from typing import Dict, List, Union +from typing import Union import numpy as np import onnx @@ -44,7 +44,7 @@ class Outputs(BaseOutputs): inputs: Inputs outputs: Outputs - def constructor(self, attrs: Dict[str, Attr], inputs: Inputs) -> Outputs: + def constructor(self, attrs: dict[str, Attr], inputs: Inputs) -> Outputs: # FIXME: At some point, attribute references should be properly type-hinted. 
a = op.constant( value_float=_Ref( @@ -95,7 +95,7 @@ class Outputs(BaseOutputs): inputs: Inputs outputs: Outputs - def constructor(self, attrs: Dict[str, Attr], inputs: Inputs) -> Outputs: + def constructor(self, attrs: dict[str, Attr], inputs: Inputs) -> Outputs: return self.Outputs( linear( inputs.X, @@ -141,7 +141,7 @@ class Outputs(BaseOutputs): inputs: Inputs outputs: Outputs - def constructor(self, attrs: Dict[str, Attr], inputs: Inputs) -> Outputs: + def constructor(self, attrs: dict[str, Attr], inputs: Inputs) -> Outputs: x = inputs.X a = op.mul( linear( @@ -234,7 +234,7 @@ def isnan_graph(): ) @to_function("IsNaN", "spox.test") - def isnan(v: Var) -> List[Var]: + def isnan(v: Var) -> list[Var]: return [op.not_(op.equal(v, v))] return results( diff --git a/tests/test_initializer.py b/tests/test_initializer.py index 80907b87..7a53dde0 100644 --- a/tests/test_initializer.py +++ b/tests/test_initializer.py @@ -2,14 +2,14 @@ # SPDX-License-Identifier: BSD-3-Clause import itertools -from typing import Any, List +from typing import Any import numpy as np import pytest from spox._future import initializer -TESTED_INITIALIZER_ROWS: List[List[Any]] = [ +TESTED_INITIALIZER_ROWS: list[list[Any]] = [ [0, 1, 2], [0.0, 1.0, 2.0], [np.float16(3.14), np.float16(5.3)], diff --git a/tests/test_inline.py b/tests/test_inline.py index 3252bb00..91ffb5ce 100644 --- a/tests/test_inline.py +++ b/tests/test_inline.py @@ -1,8 +1,6 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause -from typing import Dict - import numpy as np import onnx import onnx.parser @@ -328,7 +326,7 @@ def _duplicate_subgraphs_to_list( def test_subgraph_list_rename(relu_proto): # This is a simple property test that ensures renaming # in lists of subgraphs is the same as in just subgraphs - renames: Dict[str, str] = {} + renames: dict[str, str] = {} def example_rename(n: str) -> str: if n not in renames: diff --git a/tests/test_type_translation.py b/tests/test_type_translation.py index eb952d5c..46abee7b 100644 --- a/tests/test_type_translation.py +++ b/tests/test_type_translation.py @@ -1,8 +1,6 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause -from typing import List, Tuple - import numpy as np import onnx import pytest @@ -42,7 +40,7 @@ def tensor_shape_pairs(): @pytest.fixture -def type_pairs() -> List[Tuple[Type, onnx.TypeProto]]: +def type_pairs() -> list[tuple[Type, onnx.TypeProto]]: tensor_f32 = tensor_type_proto(np.float32, None) seq_tensor_f32 = onnx.helper.make_sequence_type_proto(tensor_f32) opt_seq_tensor_f32 = onnx.helper.make_optional_type_proto(seq_tensor_f32) diff --git a/tools/generate_opset.py b/tools/generate_opset.py index 1abaa52c..399a311f 100644 --- a/tools/generate_opset.py +++ b/tools/generate_opset.py @@ -3,9 +3,10 @@ import re import subprocess +from collections.abc import Iterable, Sequence from dataclasses import dataclass from pathlib import Path -from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union +from typing import Optional, Union import jinja2 import onnx @@ -69,7 +70,7 @@ IF16_SUBGRAPH_SOLUTION = {"else_branch": "()", "then_branch": "()"} LOOP16_SUBGRAPH_SOLUTION = { - "body": "typing_cast(List[Type], [Tensor(np.int64, (1,)), Tensor(np.bool_, (1,))])" + "body": "typing_cast(list[Type], [Tensor(np.int64, (1,)), Tensor(np.bool_, (1,))])" "+ [var.unwrap_type() for var in v_initial]" } SCAN16_SUBGRAPH_SOLUTION = { @@ -157,9 +158,9 @@ def attr_constructor(self) -> str: def get_attributes( schema: onnx.defs.OpSchema, 
attr_type_overrides, - subgraph_solutions: Dict[str, str], + subgraph_solutions: dict[str, str], allow_extra: bool, -) -> List[Attribute]: +) -> list[Attribute]: out = [] for name, attr in schema.attributes.items(): default = _get_default_value(attr, attr_type_overrides) @@ -244,7 +245,7 @@ def get_constructor_return(schema: onnx.defs.OpSchema) -> str: if not schema.outputs: return "None" if len(schema.outputs) > 1: - return f"Tuple[{', '.join('Sequence[Var]' if is_variadic(out) else 'Var' for out in schema.outputs)}]" + return f"tuple[{', '.join('Sequence[Var]' if is_variadic(out) else 'Var' for out in schema.outputs)}]" (out,) = schema.outputs if is_variadic(out): return "Sequence[Var]" @@ -252,7 +253,7 @@ def get_constructor_return(schema: onnx.defs.OpSchema) -> str: _PANDOC_SEP = "\U0001f6a7" # U+1F6A7 CONSTRUCTION SIGN -_PANDOC_GFM_TO_RST_CACHE: Dict[str, str] = {} +_PANDOC_GFM_TO_RST_CACHE: dict[str, str] = {} def _pandoc_run(text: str): @@ -261,7 +262,7 @@ def _pandoc_run(text: str): ).stdout.decode() -def _pandoc_gfm_to_rst_run(*args: str) -> Tuple[str, ...]: +def _pandoc_gfm_to_rst_run(*args: str) -> tuple[str, ...]: if not args: return () @@ -278,7 +279,7 @@ def _pandoc_gfm_to_rst_run(*args: str) -> Tuple[str, ...]: return results -def _pandoc_gfm_to_rst(*args: str) -> Tuple[str, ...]: +def _pandoc_gfm_to_rst(*args: str) -> tuple[str, ...]: args = tuple(arg.strip() for arg in args) if any(_PANDOC_SEP in arg for arg in args): raise ValueError( @@ -290,7 +291,7 @@ def _pandoc_gfm_to_rst(*args: str) -> Tuple[str, ...]: if not (arg in _PANDOC_GFM_TO_RST_CACHE or not arg) ] results = _pandoc_gfm_to_rst_run(*[args[i] for i in valid]) - sub: List[Optional[str]] = [None] * len(args) + sub: list[Optional[str]] = [None] * len(args) for i, result in zip(valid, results): sub[i] = result for i, arg in enumerate(args): @@ -308,7 +309,7 @@ def pandoc_gfm_to_rst(doc: str) -> str: return result -def format_github_markdown(doc: str, *, to_batch: Optional[List[str]] = None) -> str: +def format_github_markdown(doc: str, *, to_batch: Optional[list[str]] = None) -> str: """Jinja filter. Makes some attempt at fixing "Markdown" into RST.""" # Sometimes Tensor is used in the docs (~17 instances at 1.13) # and is treated as invalid HTML tags by pandoc. 
@@ -358,14 +359,14 @@ def get_env():
 def write_schemas_code(
     file,
     domain: str,
-    schemas: List[onnx.defs.OpSchema],
-    type_inference: Dict[str, str],
-    value_propagation: Dict[str, str],
-    out_variadic_solutions: Dict[str, str],
-    subgraphs_solutions: Dict[str, Dict[str, str]],
-    attr_type_overrides: List[Tuple[Optional[str], str, Tuple[str, str]]],
-    allow_extra_constructor_arguments: Set[str],
-    inherited_schemas: Dict[onnx.defs.OpSchema, str],
+    schemas: list[onnx.defs.OpSchema],
+    type_inference: dict[str, str],
+    value_propagation: dict[str, str],
+    out_variadic_solutions: dict[str, str],
+    subgraphs_solutions: dict[str, dict[str, str]],
+    attr_type_overrides: list[tuple[Optional[str], str, tuple[str, str]]],
+    allow_extra_constructor_arguments: set[str],
+    inherited_schemas: dict[onnx.defs.OpSchema, str],
     extras: Sequence[str],
     gen_docstrings: bool,
 ):
@@ -390,9 +391,9 @@ def write_schemas_code(
         end="\n",
     )
 
-    built_schemas: Set[onnx.defs.OpSchema] = set()
+    built_schemas: set[onnx.defs.OpSchema] = set()
 
-    pandoc_batch: List[str] = []
+    pandoc_batch: list[str] = []
     for schema in schemas:
         if schema in inherited_schemas:
             continue
@@ -527,20 +528,20 @@ def run_pre_commit_hooks(filenames: Union[str, Iterable[str]]):
 def main(
     domain: str,
     version: Optional[int] = None,
-    type_inference: Optional[Dict[str, str]] = None,
-    value_propagation: Optional[Dict[str, str]] = None,
-    out_variadic_solutions: Optional[Dict[str, str]] = None,
-    subgraphs_solutions: Optional[Dict[str, Dict[str, str]]] = None,
+    type_inference: Optional[dict[str, str]] = None,
+    value_propagation: Optional[dict[str, str]] = None,
+    out_variadic_solutions: Optional[dict[str, str]] = None,
+    subgraphs_solutions: Optional[dict[str, dict[str, str]]] = None,
     attr_type_overrides: Optional[
-        List[Tuple[Optional[str], str, Tuple[str, str]]]
+        list[tuple[Optional[str], str, tuple[str, str]]]
     ] = None,
     allow_extra_constructor_arguments: Iterable[str] = (),
-    inherited_schemas: Optional[Dict[onnx.defs.OpSchema, str]] = None,
+    inherited_schemas: Optional[dict[onnx.defs.OpSchema, str]] = None,
     extras: Sequence[str] = (),
     target: str = "src/spox/opset/",
     pre_commit_hooks: bool = True,
     gen_docstrings: bool = True,
-) -> Tuple[List[onnx.defs.OpSchema], str]:
+) -> tuple[list[onnx.defs.OpSchema], str]:
     """
     Generate opset module code and save it in a `.py` source code file.
diff --git a/tools/templates/class.jinja2 b/tools/templates/class.jinja2
index 5300709f..b2553675 100644
--- a/tools/templates/class.jinja2
+++ b/tools/templates/class.jinja2
@@ -6,7 +6,7 @@ class _{{ schema.name }}(StandardNode):
     {% else %}{% for attr in attributes %}
     {{ attr.name }}: {{ attr.member_type }}
     {% endfor %}
-    {% endif %} 
+    {% endif %}
 
     {% if schema.inputs %}
     @dataclass
@@ -47,14 +47,14 @@ class _{{ schema.name }}(StandardNode):
     {% endif %}
 
     {% if type_inference %}
-    def infer_output_types(self) -> Dict[str, Type]:
+    def infer_output_types(self) -> dict[str, Type]:
     {% filter indent(width=8) %}
     {%+ include type_inference %}
     {% endfilter %}
     {% endif %}
 
     {% if value_propagation %}
-    def propagate_values(self) -> Dict[str, PropValueType]:
+    def propagate_values(self) -> dict[str, PropValueType]:
     {% filter indent(width=8) %}
     {%+ include value_propagation %}
     {% endfilter %}
@@ -66,4 +66,3 @@ class _{{ schema.name }}(StandardNode):
 
     inputs: {% if schema.inputs %}Inputs{% else %}BaseInputs{% endif %}
     outputs: {% if schema.outputs %}Outputs{% else %}BaseOutputs{% endif %}
-
diff --git a/tools/templates/extras/promote.jinja2 b/tools/templates/extras/promote.jinja2
index 69b28422..332229eb 100644
--- a/tools/templates/extras/promote.jinja2
+++ b/tools/templates/extras/promote.jinja2
@@ -1,6 +1,6 @@
 def promote(
     *types: Union[Var, np.generic, int, float, None]
-) -> Tuple[Optional[Var], ...]:
+) -> tuple[Optional[Var], ...]:
     """
     Apply constant promotion and type promotion to given parameters,
     creating constants and/or casting.
diff --git a/tools/templates/preamble.jinja2 b/tools/templates/preamble.jinja2
index 44f8ba6b..e4e320f3 100644
--- a/tools/templates/preamble.jinja2
+++ b/tools/templates/preamble.jinja2
@@ -5,12 +5,9 @@
 from dataclasses import dataclass
 from typing import (
     Any,
     Callable,
-    Dict,
     Iterable,
-    List,
     Optional,
     Sequence,
-    Tuple,
     Union,
 )
 from typing import cast as typing_cast
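-- 
Note (after the signature delimiter, not part of the patch): every hunk above
applies the same mechanical rewrite — PEP 585 builtin generics (`dict`, `list`,
`tuple`, `set`, `type`) replace their deprecated `typing` aliases, and container
ABCs such as `Iterable` and `Sequence` move to `collections.abc` — the fix
automated by ruff's pyupgrade (`UP`) rules. A minimal before/after sketch of the
pattern; `summarize` is a made-up name used only for illustration:

    # Before: pre-3.9 typing aliases
    from typing import Dict, List, Optional

    def summarize(rows: List[Dict[str, int]]) -> Optional[int]:
        # Sum all values across all rows; None signals "no data".
        return sum(sum(r.values()) for r in rows) if rows else None

    # After: builtin generics (Python >= 3.9); note that Optional is kept,
    # matching what the patch does.
    from typing import Optional

    def summarize(rows: list[dict[str, int]]) -> Optional[int]:
        return sum(sum(r.values()) for r in rows) if rows else None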