Commit 27c562e: Fix passing
neNasko1 committed Nov 5, 2024
1 parent f5af5d9 commit 27c562e
Showing 13 changed files with 22,964 additions and 17,779 deletions.
19 changes: 11 additions & 8 deletions src/spox/_fields.py
@@ -6,8 +6,7 @@
 import warnings
 from collections.abc import Iterable, Iterator, Sequence
 from dataclasses import dataclass
-from typing import Any, Optional, Union
-from typing_extensions import Self
+from typing import Any, Generic, Optional, TypeVar, Union
 
 from ._attributes import Attr
 from ._exceptions import InferenceWarning
@@ -158,13 +157,16 @@ def fully_typed(self) -> bool:
         )
 
 
+TypeBaseVars = TypeVar("TypeBaseVars", bound=BaseVars)
+
+
 @dataclass
-class BaseInputs(BaseVarInfos, metaclass=BaseVarsMeta):
+class BaseInputs(BaseVarInfos, Generic[TypeBaseVars], metaclass=BaseVarsMeta):
     @dataclass
     class Vars(BaseVars):
         pass
 
-    def vars(self, prop_values) -> Vars:
+    def vars(self, prop_values) -> TypeBaseVars:
         vars_structure: dict[str, Union[Var, Sequence[Var]]] = {}
 
         for field in dataclasses.fields(self):
@@ -189,10 +191,11 @@ def vars(self, prop_values) -> Vars:
 
                 vars_structure[field.name] = vars
 
-        return self.Vars(**vars_structure)
+        return self.__class__.Vars(**vars_structure)  # type: ignore
 
 
 @dataclass
-class BaseOutputs(BaseVarInfos, metaclass=BaseVarsMeta):
+class BaseOutputs(BaseVarInfos, Generic[TypeBaseVars], metaclass=BaseVarsMeta):
     @dataclass
     class Vars(BaseVars):
         pass
@@ -201,7 +204,7 @@ def _propagate_vars(
         self,
         prop_values={},
         flatten_variadic=False,
-    ):
+    ) -> TypeBaseVars:
         def _create_var(key, var_info):
             ret = Var(var_info, None)
 
@@ -234,4 +237,4 @@ def _create_var(key, var_info):
                     _create_var(f"{key}_{i}", v) for i, v in enumerate(var_info)
                 ]
 
-        return ret_dict
+        return self.__class__.Vars(**ret_dict)  # type: ignore
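
The change above makes BaseInputs and BaseOutputs generic over their Vars container, so vars() and _propagate_vars() can advertise the concrete Vars subclass instead of the shared base type. A minimal, self-contained sketch of that typing pattern follows; the names (Outputs, ArgVars, ArgOutputs, propagate_vars) are hypothetical stand-ins rather than the real spox classes, and the commit itself parameterizes with string forward references such as BaseOutputs["Argument.Outputs.Vars"] (see _internal_op.py below).

from dataclasses import dataclass
from typing import Generic, TypeVar


@dataclass
class BaseVars:
    """Stand-in for spox's BaseVars: a container of resolved variables."""


TypeBaseVars = TypeVar("TypeBaseVars", bound=BaseVars)


class Outputs(Generic[TypeBaseVars]):
    """Stand-in for BaseOutputs: each subclass names its own Vars container."""

    Vars: type = BaseVars

    def propagate_vars(self) -> TypeBaseVars:
        # Build whichever Vars container the concrete subclass declared, so
        # callers see the precise type rather than the generic base class.
        return self.__class__.Vars()  # type: ignore


@dataclass
class ArgVars(BaseVars):
    pass


class ArgOutputs(Outputs[ArgVars]):
    Vars = ArgVars


result: ArgVars = ArgOutputs().propagate_vars()
print(type(result).__name__)  # prints "ArgVars"

The type: ignore mirrors the one in the diff: the class-attribute lookup is not precise enough for the checker, but each subclass is expected to keep Vars consistent with its type parameter.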
6 changes: 4 additions & 2 deletions src/spox/_function.py
@@ -81,7 +81,7 @@ def infer_output_types(self, initializers={}) -> dict[str, Type]:
         self.func_inputs = self.Inputs(**self.func_args)  # type: ignore
         self.func_outputs = self.constructor(self.func_attrs, self.func_inputs)
         self.func_graph = _graph.results(
-            **self.func_outputs._propagate_vars()
+            **self.func_outputs._propagate_vars(initializers).get_vars()
         ).with_arguments(*func_args_var.values())
 
         return {
@@ -147,7 +147,9 @@ class Attributes(BaseAttributes):
         op_type = OpType(name, domain, version)
 
         def constructor(self, attrs, inputs):
-            return self.Outputs(*unwrap_vars(fun(*wrap_vars(inputs.get_fields().values()))))
+            return self.Outputs(
+                *unwrap_vars(fun(*wrap_vars(inputs.get_fields().values())))
+            )
 
     return _Func

4 changes: 2 additions & 2 deletions src/spox/_internal_op.py
@@ -73,11 +73,11 @@ class Attributes(BaseAttributes):
         default: Optional[AttrTensor] = None
 
     @dataclass
-    class Inputs(BaseInputs):
+    class Inputs(BaseInputs["Argument.Outputs.Vars"]):
         pass
 
     @dataclass
-    class Outputs(BaseOutputs):
+    class Outputs(BaseOutputs["Argument.Outputs.Vars"]):
         arg: VarInfo
 
     attrs: Attributes
10 changes: 5 additions & 5 deletions src/spox/_node.py
@@ -226,7 +226,9 @@ def infer_output_types(self, initializers) -> dict[str, Type]:
     def inference(self, infer_types: bool = True, initializers={}):
         # Type inference routine - call infer_output_types if required
        # and check if it provides the expected outputs.
-        out_types = self.infer_output_types(initializers=initializers) if infer_types else {}
+        out_types = (
+            self.infer_output_types(initializers=initializers) if infer_types else {}
+        )
 
         for key, var in self.outputs.get_vars().items():
             if var.type is None:  # If no existing type from init_output_vars
@@ -236,10 +238,8 @@ def inference(self, infer_types: bool = True, initializers={}):
     def get_output_vars(self, flatten_variadic=False, **initializers):
         # After typing everything, try to get values for outputs
         out_values = self.propagate_values(initializers)
-        return type(self.outputs).Vars(
-            **self.outputs._propagate_vars(
-                out_values, flatten_variadic=flatten_variadic
-            )
-        )
+        return self.outputs._propagate_vars(
+            out_values, flatten_variadic=flatten_variadic
+        )
 
     def validate_types(self) -> None:
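
For context, the inference routine touched above follows a simple contract: run infer_output_types when requested, keep any type an output already carries, fill in the rest from the inferred mapping, and surface anything that is still untyped. A schematic, hypothetical sketch of that control flow (plain Python, not spox's actual Node or Type classes):

import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class OutVar:
    name: str
    type: Optional[str] = None  # stand-in for a spox Type


def run_inference(outputs: list[OutVar], inferred: dict[str, str]) -> None:
    # Outputs that already have a type keep it; untyped ones take the
    # inferred type; anything left untyped is reported instead of being
    # silently passed through.
    for var in outputs:
        if var.type is None:
            var.type = inferred.get(var.name)
        if var.type is None:
            warnings.warn(f"Output '{var.name}' has no inferred type.")


outs = [OutVar("Y"), OutVar("Z", type="tensor(float)")]
run_inference(outs, {"Y": "tensor(int64)"})
print([(v.name, v.type) for v in outs])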
34 changes: 27 additions & 7 deletions src/spox/_standard.py
@@ -3,14 +3,13 @@
 
 """Module implementing a base for standard ONNX operators, which use the functionality of ONNX node-level inference."""
 
+from collections.abc import Iterable
 from typing import TYPE_CHECKING, Callable
 
 import onnx
-from onnx.numpy_helper import from_array
 import onnx.reference
 import onnx.shape_inference
 from onnx.defs import OpSchema
-import numpy as np
 
 from . import _value_prop
 from ._exceptions import InferenceError
@@ -19,8 +18,8 @@
 from ._scope import Scope
 from ._shape import SimpleShape
 from ._type_system import Optional, Sequence, Tensor, Type
-from ._value_prop import PropValueType
+from ._utils import from_array
+from ._value_prop import PropValue, PropValueType
 
 if TYPE_CHECKING:
     from ._graph import Graph
@@ -51,7 +50,11 @@ def min_output(self) -> int:
         return self.schema.min_output
 
     def to_singleton_onnx_model(
-        self, *, dummy_outputs: bool = True, with_dummy_subgraphs: bool = True, prop_values={}
+        self,
+        *,
+        dummy_outputs: bool = True,
+        with_dummy_subgraphs: bool = True,
+        prop_values={},
     ) -> tuple[onnx.ModelProto, Scope]:
         """
         Build a singleton model consisting of just this StandardNode. Used for type inference.
@@ -100,11 +103,26 @@ def out_value_info(curr_key, curr_var):
         ]
         # Initializers, passed in to allow partial data propagation
         # - used so that operators like Reshape are aware of constant shapes
-        initializers = [
+        # TODO: fix this
+        initializers_from_array = [
             from_array(prop.value, name)  # type: ignore
             for name, prop in prop_values.items()
-            if prop is not None and isinstance(prop.value, np.ndarray)
+            if isinstance(prop, PropValue)
+            and prop.value is not None
+            and not isinstance(prop.type, Sequence)
         ]
+
+        initializers_from_sequence = [
+            from_array(prop.value, f"{name}_{i}")  # type: ignore
+            for name, prop_list in prop_values.items()
+            if isinstance(prop_list, list)
+            for i, prop in enumerate(prop_list)
+            if prop is not None and not isinstance(prop.value, Iterable)
+        ]
+
+        initializers = initializers_from_array
+        initializers.extend(initializers_from_sequence)
+
         # Graph and model
         graph = onnx.helper.make_graph(
             [node_proto],
@@ -168,7 +186,9 @@ def propagate_values_onnx(self, initializers) -> dict[str, PropValueType]:
         if next(iter(self.subgraphs), None) is not None:
             # Cannot do propagation with subgraphs implicitly for performance - should be reimplemented
             return {}
-        model, scope = self.to_singleton_onnx_model(with_dummy_subgraphs=False, prop_values=initializers)
+        model, scope = self.to_singleton_onnx_model(
+            with_dummy_subgraphs=False, prop_values=initializers
+        )
         wrap_feed, run, unwrap_feed = _value_prop.get_backend_calls()
         input_feed = {
             scope.var[var_info]: wrap_feed(initializers[name])
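
The initializer handling above exists so that value propagation can feed constant inputs into the singleton model: ONNX shape inference can only resolve data-dependent operators such as Reshape when the relevant input is available as a constant. A small, self-contained example of that effect, using plain onnx rather than spox (the tensor names and shapes are illustrative):

import numpy as np
import onnx.shape_inference
from onnx import TensorProto, helper
from onnx.numpy_helper import from_array

# A single Reshape node: the data comes in as a graph input, while the target
# shape is supplied as an initializer so shape inference can see its value.
node = helper.make_node("Reshape", ["data", "shape"], ["reshaped"])
graph = helper.make_graph(
    [node],
    "reshape_example",
    inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, [2, 6])],
    outputs=[helper.make_tensor_value_info("reshaped", TensorProto.FLOAT, None)],
    initializer=[from_array(np.array([3, 4], dtype=np.int64), name="shape")],
)
model = helper.make_model(graph)
inferred = onnx.shape_inference.infer_shapes(model)
print(inferred.graph.output[0])  # shape inference should now report a (3, 4) output

Dropping the initializer leaves the output shape unknown, which is the situation the prop_values plumbing in to_singleton_onnx_model is meant to avoid.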
(Diffs for the remaining 8 changed files are not shown here.)