
Commit

Enable pyupgrade (#185)
adityagoel4512 authored Oct 15, 2024
1 parent 411f9af commit bb337b7
Showing 42 changed files with 256 additions and 287 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -51,7 +51,7 @@ namespaces = false

[tool.ruff.lint]
# Enable the isort rules.
extend-select = ["I"]
extend-select = ["I", "UP"]

[tool.ruff.lint.isort]
known-first-party = ["spox"]
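
Adding "UP" to extend-select enables Ruff's pyupgrade rule family, which drives the mechanical changes in the rest of this diff: typing.Dict/List/Tuple/Set annotations become the PEP 585 builtin generics, and abstract types such as Iterable move from typing to collections.abc. A minimal before/after sketch of that rewrite (illustrative names only, not taken from the repository; the new spelling needs Python 3.9+ unless annotation evaluation is deferred):

# Before: typing aliases deprecated by PEP 585
from typing import Dict, List, Optional

def resolve_old(names: Dict[str, int]) -> Optional[List[str]]:
    return sorted(names) or None

# After: builtin generics, with the typing import trimmed to what is still needed
from typing import Optional

def resolve_new(names: dict[str, int]) -> Optional[list[str]]:
    return sorted(names) or None
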
24 changes: 12 additions & 12 deletions src/spox/_adapt.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import warnings
from typing import Dict, List, Optional
from typing import Optional

import numpy as np
import onnx
@@ -23,8 +23,8 @@ def adapt_node(
proto: onnx.NodeProto,
source_version: int,
target_version: int,
var_names: Dict[Var, str],
) -> Optional[List[onnx.NodeProto]]:
var_names: dict[Var, str],
) -> Optional[list[onnx.NodeProto]]:
if source_version == target_version:
return None

@@ -69,11 +69,11 @@ def adapt_node(

def adapt_inline(
node: _Inline,
protos: List[onnx.NodeProto],
target_opsets: Dict[str, int],
var_names: Dict[Var, str],
protos: list[onnx.NodeProto],
target_opsets: dict[str, int],
var_names: dict[Var, str],
node_name: str,
) -> List[onnx.NodeProto]:
) -> list[onnx.NodeProto]:
source_version = max({v for d, v in node.opset_req if d in ("", "ai.onnx")})
target_version = target_opsets[""]

@@ -97,11 +97,11 @@ def adapt_inline(

def adapt_best_effort(
node: Node,
protos: List[onnx.NodeProto],
opsets: Dict[str, int],
var_names: Dict[Var, str],
node_names: Dict[Node, str],
) -> Optional[List[onnx.NodeProto]]:
protos: list[onnx.NodeProto],
opsets: dict[str, int],
var_names: dict[Var, str],
node_names: dict[Node, str],
) -> Optional[list[onnx.NodeProto]]:
if isinstance(node, _Inline):
return adapt_inline(
node,
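
A note on the runtime requirement implied above: parameter and return annotations such as dict[Var, str] are evaluated at function definition time, and subscripting the builtin dict/list/tuple/set only works on Python 3.9+ (PEP 585). Since the file does not add a __future__ import for annotations, the change implies the project targets Python 3.9 or newer (an inference; the excerpt does not show the requires-python setting). On older interpreters the same spelling is only valid with deferred annotation evaluation, as in this sketch (hypothetical function, not from the repository):

from __future__ import annotations

from typing import Optional

def lookup(var_names: dict[str, int], key: str) -> Optional[int]:
    # With deferred evaluation the annotation dict[str, int] is stored as a
    # string and never evaluated, so this also imports cleanly on Python 3.8.
    return var_names.get(key)
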
11 changes: 6 additions & 5 deletions src/spox/_attributes.py
@@ -3,7 +3,8 @@

import abc
from abc import ABC
from typing import Any, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union
from collections.abc import Iterable
from typing import Any, Generic, Optional, TypeVar, Union

import numpy as np
import numpy.typing as npt
@@ -45,7 +46,7 @@ def deref(self) -> "Attr":
return self

@classmethod
def maybe(cls: Type[AttrT], value: Optional[T], name: str) -> Optional[AttrT]:
def maybe(cls: type[AttrT], value: Optional[T], name: str) -> Optional[AttrT]:
return cls(value, name) if value is not None else None

@property
@@ -200,15 +201,15 @@ def _to_onnx_deref(self) -> AttributeProto:
)


class _AttrIterable(Attr[Tuple[S, ...]], ABC):
def __init__(self, value: Union[Iterable[S], _Ref[Tuple[S, ...]]], name: str):
class _AttrIterable(Attr[tuple[S, ...]], ABC):
def __init__(self, value: Union[Iterable[S], _Ref[tuple[S, ...]]], name: str):
super().__init__(
value=value if isinstance(value, _Ref) else tuple(value), name=name
)

@classmethod
def maybe(
cls: Type[AttrIterableT],
cls: type[AttrIterableT],
value: Optional[Iterable[S]],
name: str,
) -> Optional[AttrIterableT]:
54 changes: 25 additions & 29 deletions src/spox/_build.py
@@ -7,12 +7,8 @@
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
List,
Optional,
Set,
Tuple,
TypeVar,
)

@@ -61,12 +57,12 @@ class BuildResult:
"""

scope: Scope
nodes: Dict[Node, Tuple[onnx.NodeProto, ...]]
arguments: Tuple[Var, ...]
results: Tuple[Var, ...]
opset_req: Set[Tuple[str, int]]
functions: Tuple["_function.Function", ...]
initializers: Dict[Var, np.ndarray]
nodes: dict[Node, tuple[onnx.NodeProto, ...]]
arguments: tuple[Var, ...]
results: tuple[Var, ...]
opset_req: set[tuple[str, int]]
functions: tuple["_function.Function", ...]
initializers: dict[Var, np.ndarray]


class Builder:
@@ -125,8 +121,8 @@ class ScopeTree:
(lowest common ancestor), which is a common operation on trees.
"""

subgraph_owner: Dict["Graph", Node]
scope_of: Dict[Node, "Graph"]
subgraph_owner: dict["Graph", Node]
scope_of: dict[Node, "Graph"]

def __init__(self):
self.subgraph_owner = {}
@@ -165,18 +161,18 @@ def lca(self, a: "Graph", b: "Graph") -> "Graph":

# Graphs needed in the build
main: "Graph"
graphs: Set["Graph"]
graph_topo: List["Graph"]
graphs: set["Graph"]
graph_topo: list["Graph"]
# Arguments, results
arguments_of: Dict["Graph", List[Var]]
results_of: Dict["Graph", List[Var]]
source_of: Dict["Graph", Node]
arguments_of: dict["Graph", list[Var]]
results_of: dict["Graph", list[Var]]
source_of: dict["Graph", Node]
# Arguments found by traversal
all_arguments_in: Dict["Graph", Set[Var]]
claimed_arguments_in: Dict["Graph", Set[Var]]
all_arguments_in: dict["Graph", set[Var]]
claimed_arguments_in: dict["Graph", set[Var]]
# Scopes
scope_tree: ScopeTree
scope_own: Dict["Graph", List[Node]]
scope_own: dict["Graph", list[Node]]

def __init__(self, main: "Graph"):
self.main = main
@@ -207,8 +203,8 @@ def build_main(self) -> BuildResult:

@staticmethod
def get_intro_results(
request_results: Dict[str, Var], set_names: bool
) -> List[Var]:
request_results: dict[str, Var], set_names: bool
) -> list[Var]:
"""
Helper method for wrapping all requested results into a single Introduce and possibly naming them.
@@ -222,7 +218,7 @@ def get_intro_results(
var._rename(key)
return vars

def discover(self, graph: "Graph") -> Tuple[Set[Var], Set[Var]]:
def discover(self, graph: "Graph") -> tuple[set[Var], set[Var]]:
"""
Run the discovery step of the build process. Resolves arguments and results for the involved graphs.
Finds the topological ordering between (sub)graphs and sets their owners (nodes of which they are attributes).
@@ -359,7 +355,7 @@ def resolve_scopes(self) -> None:
- this is slightly higher quality than a normal topological sorting which attempts to be "parallel",
while a DFS' postorder is more "localised".
"""
graph_scope_set: Dict[Any, Set[Node]] = {ctx: set() for ctx in self.graphs}
graph_scope_set: dict[Any, set[Node]] = {ctx: set() for ctx in self.graphs}
for node, owner in self.scope_tree.scope_of.items():
graph_scope_set[owner].add(node)

@@ -392,7 +388,7 @@ def resolve_scopes(self) -> None:

def get_build_subgraph_callback(
self, scope: Scope
) -> Tuple[Callable, Set[Tuple[str, int]]]:
) -> tuple[Callable, set[tuple[str, int]]]:
"""Create a callback for building subgraphs for ``Node.to_onnx``."""

subgraph_opset_req = set() # Keeps track of all opset imports in subgraphs
@@ -432,11 +428,11 @@ def compile_graph(
See the definition for the exact contents of the BuildResult dataclass. Used to build GraphProto/ModelProto
from a Spox Graph.
"""
nodes: Dict[Node, Tuple[onnx.NodeProto, ...]] = {}
nodes: dict[Node, tuple[onnx.NodeProto, ...]] = {}
# A bunch of model metadata we're collecting
opset_req: Set[Tuple[str, int]] = set()
functions: List[_function.Function] = []
initializers: Dict[Var, np.ndarray] = {}
opset_req: set[tuple[str, int]] = set()
functions: list[_function.Function] = []
initializers: dict[Var, np.ndarray] = {}

# Add arguments to our scope
for arg in self.arguments_of[graph]:
11 changes: 6 additions & 5 deletions src/spox/_fields.py
@@ -3,8 +3,9 @@

import dataclasses
import enum
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
from typing import Any, Optional, Union

from ._attributes import Attr
from ._var import Var
@@ -17,7 +18,7 @@ class BaseFields:

@dataclass
class BaseAttributes(BaseFields):
def get_fields(self) -> Dict[str, Union[None, Attr]]:
def get_fields(self) -> dict[str, Union[None, Attr]]:
"""Return a mapping of all fields stored in this object by name."""
return self.__dict__.copy()

@@ -68,7 +69,7 @@ def _get_field_type(cls, field) -> VarFieldKind:
return VarFieldKind.VARIADIC
raise ValueError(f"Bad field type: '{field.type}'.")

def _flatten(self) -> Iterable[Tuple[str, Optional[Var]]]:
def _flatten(self) -> Iterable[tuple[str, Optional[Var]]]:
"""Iterate over the pairs of names and values of fields in this object."""
for key, value in self.__dict__.items():
if value is None or isinstance(value, Var):
@@ -84,11 +85,11 @@ def __len__(self) -> int:
"""Count the number of fields in this object (should be same as declared in the class)."""
return sum(1 for _ in self)

def get_vars(self) -> Dict[str, Var]:
def get_vars(self) -> dict[str, Var]:
"""Return a flat mapping by name of all the Vars in this object."""
return {key: var for key, var in self._flatten() if var is not None}

def get_fields(self) -> Dict[str, Union[None, Var, Sequence[Var]]]:
def get_fields(self) -> dict[str, Union[None, Var, Sequence[Var]]]:
"""Return a mapping of all fields stored in this object by name."""
return self.__dict__.copy()

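
The same pyupgrade pass also moves abstract container types out of typing, since typing.Iterable, Iterator, and Sequence have been deprecated aliases of their collections.abc counterparts since Python 3.9 (PEP 585). A minimal sketch of the resulting import style (hypothetical function, not from the repository):

from collections.abc import Iterable
from typing import Optional

def first(values: Iterable[int]) -> Optional[int]:
    # Same behaviour as the old typing.Iterable spelling, minus the deprecation.
    for value in values:
        return value
    return None
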
11 changes: 6 additions & 5 deletions src/spox/_function.py
@@ -3,8 +3,9 @@

import inspect
import itertools
from collections.abc import Iterable
from dataclasses import dataclass, make_dataclass
from typing import TYPE_CHECKING, Callable, Dict, Iterable, Tuple, TypeVar
from typing import TYPE_CHECKING, Callable, TypeVar

import onnx

@@ -41,8 +42,8 @@ class Function(_InternalNode):
via the ``to_onnx_function`` method.
"""

func_args: Dict[str, Var]
func_attrs: Dict[str, _attributes.Attr]
func_args: dict[str, Var]
func_attrs: dict[str, _attributes.Attr]
func_inputs: BaseInputs
func_outputs: BaseOutputs
func_graph: "_graph.Graph"
@@ -60,7 +61,7 @@ def constructor(self, attrs, inputs):
f"Function {type(self).__name__} does not implement a constructor."
)

def infer_output_types(self) -> Dict[str, Type]:
def infer_output_types(self) -> dict[str, Type]:
from . import _graph

self.func_args = _graph.arguments_dict(
@@ -98,7 +99,7 @@ def update_metadata(self, opset_req, initializers, functions):
functions.extend(self.func_graph._get_build_result().functions)

def to_onnx_function(
self, *, extra_opset_req: Iterable[Tuple[str, int]] = ()
self, *, extra_opset_req: Iterable[tuple[str, int]] = ()
) -> onnx.FunctionProto:
"""
Translate self into an ONNX FunctionProto, based on the ``func_*`` attributes set when this operator
5 changes: 3 additions & 2 deletions src/spox/_future.py
@@ -3,8 +3,9 @@

"""Module containing experimental Spox features that may be standard in the future."""

from collections.abc import Iterable
from contextlib import contextmanager
from typing import Iterable, List, Optional, Union
from typing import Optional, Union

import numpy as np
import numpy.typing as npt
@@ -84,7 +85,7 @@ def _promote(
Apply constant promotion and type promotion to given parameters,
creating constants and/or casting.
"""
targets: List[Union[np.dtype, np.generic, int, float]] = [
targets: list[Union[np.dtype, np.generic, int, float]] = [
x.type.dtype if isinstance(x, Var) and isinstance(x.type, Tensor) else x # type: ignore
for x in args
]