Skip to content

Commit

Permalink
Hint that _VarInfo is private
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Nov 20, 2024
1 parent 1aebc1a commit 67d5b3b
Show file tree
Hide file tree
Showing 23 changed files with 871 additions and 871 deletions.
8 changes: 4 additions & 4 deletions src/spox/_adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
from ._schemas import SCHEMAS
from ._scope import Scope
from ._utils import from_array
from ._var import VarInfo
from ._var import _VarInfo


def adapt_node(
node: Node,
proto: onnx.NodeProto,
source_version: int,
target_version: int,
var_names: dict[VarInfo, str],
var_names: dict[_VarInfo, str],
) -> Optional[list[onnx.NodeProto]]:
if source_version == target_version:
return None
Expand Down Expand Up @@ -70,7 +70,7 @@ def adapt_inline(
node: _Inline,
protos: list[onnx.NodeProto],
target_opsets: dict[str, int],
var_names: dict[VarInfo, str],
var_names: dict[_VarInfo, str],
node_name: str,
) -> list[onnx.NodeProto]:
source_version = max({v for d, v in node.opset_req if d in ("", "ai.onnx")})
Expand Down Expand Up @@ -98,7 +98,7 @@ def adapt_best_effort(
node: Node,
protos: list[onnx.NodeProto],
opsets: dict[str, int],
var_names: dict[VarInfo, str],
var_names: dict[_VarInfo, str],
node_names: dict[Node, str],
) -> Optional[list[onnx.NodeProto]]:
if isinstance(node, _Inline):
Expand Down
20 changes: 10 additions & 10 deletions src/spox/_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ._node import Node
from ._scope import Scope
from ._traverse import iterative_dfs
from ._var import Var, VarInfo, unwrap_vars
from ._var import Var, _VarInfo, unwrap_vars

if TYPE_CHECKING:
from ._graph import Graph
Expand Down Expand Up @@ -58,11 +58,11 @@ class BuildResult:

scope: Scope
nodes: dict[Node, tuple[onnx.NodeProto, ...]]
arguments: tuple[VarInfo, ...]
results: tuple[VarInfo, ...]
arguments: tuple[_VarInfo, ...]
results: tuple[_VarInfo, ...]
opset_req: set[tuple[str, int]]
functions: tuple["_function.Function", ...]
initializers: dict[VarInfo, np.ndarray]
initializers: dict[_VarInfo, np.ndarray]


class Builder:
Expand Down Expand Up @@ -164,12 +164,12 @@ def lca(self, a: "Graph", b: "Graph") -> "Graph":
graphs: set["Graph"]
graph_topo: list["Graph"]
# Arguments, results
arguments_of: dict["Graph", list[VarInfo]]
results_of: dict["Graph", list[VarInfo]]
arguments_of: dict["Graph", list[_VarInfo]]
results_of: dict["Graph", list[_VarInfo]]
source_of: dict["Graph", Node]
# Arguments found by traversal
all_arguments_in: dict["Graph", set[VarInfo]]
claimed_arguments_in: dict["Graph", set[VarInfo]]
all_arguments_in: dict["Graph", set[_VarInfo]]
claimed_arguments_in: dict["Graph", set[_VarInfo]]
# Scopes
scope_tree: ScopeTree
scope_own: dict["Graph", list[Node]]
Expand Down Expand Up @@ -218,7 +218,7 @@ def get_intro_results(
var._rename(key)
return vars

def discover(self, graph: "Graph") -> tuple[set[VarInfo], set[VarInfo]]:
def discover(self, graph: "Graph") -> tuple[set[_VarInfo], set[_VarInfo]]:
"""
Run the discovery step of the build process. Resolves arguments and results for the involved graphs.
Finds the topological ordering between (sub)graphs and sets their owners (nodes of which they are attributes).
Expand Down Expand Up @@ -432,7 +432,7 @@ def compile_graph(
# A bunch of model metadata we're collecting
opset_req: set[tuple[str, int]] = set()
functions: list[_function.Function] = []
initializers: dict[VarInfo, np.ndarray] = {}
initializers: dict[_VarInfo, np.ndarray] = {}

# Add arguments to our scope
for arg in self.arguments_of[graph]:
Expand Down
4 changes: 2 additions & 2 deletions src/spox/_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
from contextlib import contextmanager

from spox._var import VarInfo
from spox._var import _VarInfo

# If `STORE_TRACEBACK` is `True` any node created will store a traceback for its point of creation.
STORE_TRACEBACK = False
Expand Down Expand Up @@ -36,7 +36,7 @@ def show_construction_tracebacks(debug_index):
if -1 in found:
del found[-1]
for name, obj in reversed(found.values()):
if isinstance(obj, VarInfo):
if isinstance(obj, _VarInfo):
if not obj:
continue
node = obj._op
Expand Down
26 changes: 13 additions & 13 deletions src/spox/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ._attributes import Attr
from ._exceptions import InferenceWarning
from ._value_prop import PropDict, PropValue
from ._var import Var, VarInfo
from ._var import Var, _VarInfo


@dataclass
Expand Down Expand Up @@ -87,10 +87,10 @@ def __post_init__(self):
value = getattr(self, field.name)
field_type = self._get_field_type(field)
if field_type == VarFieldKind.SINGLE:
if not isinstance(value, VarInfo):
if not isinstance(value, _VarInfo):
raise TypeError(f"Field expected VarInfo, got: {type(value)}.")
elif field_type == VarFieldKind.OPTIONAL:
if value is not None and not isinstance(value, VarInfo):
if value is not None and not isinstance(value, _VarInfo):
raise TypeError(
f"Optional must be VarInfo or None, got: {type(value)}."
)
Expand All @@ -101,43 +101,43 @@ def __post_init__(self):
)
# Cast to tuple to avoid accidental mutation
setattr(self, field.name, tuple(value))
if bad := {type(var) for var in value} - {VarInfo}:
if bad := {type(var) for var in value} - {_VarInfo}:
raise TypeError(
f"Variadic field must only consist of VarInfos, got: {bad}."
)

@classmethod
def _get_field_type(cls, field) -> VarFieldKind:
"""Access the kind of the field (single, optional, variadic) based on its type annotation."""
if field.type == VarInfo:
if field.type == _VarInfo:
return VarFieldKind.SINGLE
elif field.type == Optional[VarInfo]:
elif field.type == Optional[_VarInfo]:
return VarFieldKind.OPTIONAL
elif field.type == Sequence[VarInfo]:
elif field.type == Sequence[_VarInfo]:
return VarFieldKind.VARIADIC
raise ValueError(f"Bad field type: '{field.type}'.")

def _flatten(self) -> Iterable[tuple[str, Optional[VarInfo]]]:
def _flatten(self) -> Iterable[tuple[str, Optional[_VarInfo]]]:
"""Iterate over the pairs of names and values of fields in this object."""
for key, value in self.__dict__.items():
if value is None or isinstance(value, VarInfo):
if value is None or isinstance(value, _VarInfo):
yield key, value
else:
yield from ((f"{key}_{i}", v) for i, v in enumerate(value))

def __iter__(self) -> Iterator[Optional[VarInfo]]:
def __iter__(self) -> Iterator[Optional[_VarInfo]]:
"""Iterate over the values of fields in this object."""
yield from (v for _, v in self._flatten())

def __len__(self) -> int:
"""Count the number of fields in this object (should be same as declared in the class)."""
return sum(1 for _ in self)

def get_var_infos(self) -> dict[str, VarInfo]:
def get_var_infos(self) -> dict[str, _VarInfo]:
"""Return a flat mapping by name of all the VarInfos in this object."""
return {key: var for key, var in self._flatten() if var is not None}

def get_fields(self) -> dict[str, Union[None, VarInfo, Sequence[VarInfo]]]:
def get_fields(self) -> dict[str, Union[None, _VarInfo, Sequence[_VarInfo]]]:
"""Return a mapping of all fields stored in this object by name."""
return self.__dict__.copy()

Expand Down Expand Up @@ -218,7 +218,7 @@ def _create_var(key, var_info):
ret_dict = {}

for key, var_info in self.__dict__.items():
if var_info is None or isinstance(var_info, VarInfo):
if var_info is None or isinstance(var_info, _VarInfo):
ret_dict[key] = _create_var(key, var_info)
else:
ret_dict[key] = [
Expand Down
10 changes: 5 additions & 5 deletions src/spox/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
from ._internal_op import _InternalNode
from ._node import Node, OpType
from ._type_system import Type
from ._var import Var, VarInfo, unwrap_vars
from ._var import Var, _VarInfo, unwrap_vars

if TYPE_CHECKING:
from . import _graph

DEFAULT_FUNCTION_DOMAIN = "spox.default"

ConstructorT = TypeVar("ConstructorT", bound=Callable[..., Iterable[VarInfo]])
ConstructorT = TypeVar("ConstructorT", bound=Callable[..., Iterable[_VarInfo]])


class Function(_InternalNode):
Expand All @@ -42,7 +42,7 @@ class Function(_InternalNode):
via the ``to_onnx_function`` method.
"""

func_args: dict[str, VarInfo]
func_args: dict[str, _VarInfo]
func_attrs: dict[str, _attributes.Attr]
func_inputs: BaseInputs
func_outputs: BaseOutputs
Expand Down Expand Up @@ -130,12 +130,12 @@ def to_onnx_function(
def _make_function_cls(fun, num_inputs, num_outputs, domain, version, name):
_FuncInputs = make_dataclass(
"_FuncInputs",
((f"in{i}", VarInfo) for i in range(num_inputs)),
((f"in{i}", _VarInfo) for i in range(num_inputs)),
bases=(BaseInputs,),
)
_FuncOutputs = make_dataclass(
"_FuncOutputs",
((f"out{i}", VarInfo) for i in range(num_outputs)),
((f"out{i}", _VarInfo) for i in range(num_outputs)),
bases=(BaseOutputs,),
)

Expand Down
6 changes: 3 additions & 3 deletions src/spox/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ._schemas import max_opset_policy
from ._type_system import Tensor, Type
from ._utils import from_array
from ._var import Var, VarInfo
from ._var import Var, _VarInfo


def arguments_dict(**kwargs: Optional[Union[Type, np.ndarray]]) -> dict[str, Var]:
Expand Down Expand Up @@ -222,7 +222,7 @@ def requested_results(self) -> dict[str, Var]:
"""Results (named) requested by this Graph (for building)."""
return self._results

def get_arguments(self) -> dict[str, VarInfo]:
def get_arguments(self) -> dict[str, _VarInfo]:
"""
Get the effective named arguments (after build) of this Graph.
Expand All @@ -233,7 +233,7 @@ def get_arguments(self) -> dict[str, VarInfo]:
for var in self._get_build_result().arguments
}

def get_results(self) -> dict[str, VarInfo]:
def get_results(self) -> dict[str, _VarInfo]:
"""
Get the effective named results (after build) of this Graph.
Expand Down
6 changes: 3 additions & 3 deletions src/spox/_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from spox._node import OpType
from spox._scope import Scope
from spox._type_system import Type
from spox._var import VarInfo
from spox._var import _VarInfo

from . import _value_prop

Expand Down Expand Up @@ -86,11 +86,11 @@ class Attributes(BaseAttributes):

@dataclass
class Inputs(BaseInputs):
inputs: Sequence[VarInfo]
inputs: Sequence[_VarInfo]

@dataclass
class Outputs(BaseOutputs):
outputs: Sequence[VarInfo]
outputs: Sequence[_VarInfo]

op_type = OpType("Inline", "spox.internal", 0)

Expand Down
10 changes: 5 additions & 5 deletions src/spox/_internal_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ._shape import SimpleShape
from ._type_system import Tensor, Type
from ._value_prop import PropDict, PropValueType
from ._var import Var, VarInfo, unwrap_vars
from ._var import Var, _VarInfo, unwrap_vars

# This is a default used for internal operators that
# require the default domain. The most common of these
Expand Down Expand Up @@ -78,7 +78,7 @@ class Inputs(BaseInputs):

@dataclass
class Outputs(BaseOutputs):
arg: VarInfo
arg: _VarInfo

attrs: Attributes
inputs: Inputs
Expand Down Expand Up @@ -115,7 +115,7 @@ class Attributes(BaseAttributes):

@dataclass
class Outputs(BaseOutputs):
arg: VarInfo
arg: _VarInfo

attrs: Attributes
inputs: BaseInputs
Expand Down Expand Up @@ -149,11 +149,11 @@ class Attributes(BaseAttributes):

@dataclass
class Inputs(BaseInputs):
inputs: Sequence[VarInfo]
inputs: Sequence[_VarInfo]

@dataclass
class Outputs(BaseOutputs):
outputs: Sequence[VarInfo]
outputs: Sequence[_VarInfo]

op_type = OpType("Introduce", "spox.internal", 0)

Expand Down
14 changes: 7 additions & 7 deletions src/spox/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ._fields import BaseAttributes, BaseInputs, BaseOutputs, VarFieldKind
from ._type_system import Type
from ._value_prop import PropDict
from ._var import VarInfo
from ._var import _VarInfo

if typing.TYPE_CHECKING:
from ._graph import Graph
Expand Down Expand Up @@ -308,28 +308,28 @@ def _init_output_vars(self) -> BaseOutputs:
(variadic,) = variadics
else:
variadic = None
outputs: dict[str, Union[VarInfo, Sequence[VarInfo]]] = {
field.name: VarInfo(self, None)
outputs: dict[str, Union[_VarInfo, Sequence[_VarInfo]]] = {
field.name: _VarInfo(self, None)
for field in dataclasses.fields(self.Outputs)
if field.name != variadic
}
if variadic is not None:
assert self.out_variadic is not None
outputs[variadic] = [VarInfo(self, None) for _ in range(self.out_variadic)]
outputs[variadic] = [_VarInfo(self, None) for _ in range(self.out_variadic)]
return self.Outputs(**outputs) # type: ignore

@property
def dependencies(self) -> Iterable[VarInfo]:
def dependencies(self) -> Iterable[_VarInfo]:
"""List of input VarInfos into this Node."""
return (var for var in self.inputs.get_var_infos().values())

@property
def dependents(self) -> Iterable[VarInfo]:
def dependents(self) -> Iterable[_VarInfo]:
"""List of output VarInfos from this Node."""
return (var for var in self.outputs.get_var_infos().values())

@property
def incident(self) -> Iterable[VarInfo]:
def incident(self) -> Iterable[_VarInfo]:
"""List of both input and output VarInfos for this Node."""
return itertools.chain(self.dependencies, self.dependents)

Expand Down
Loading

0 comments on commit 67d5b3b

Please sign in to comment.