Skip to content

Commit

Permalink
Fix more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Nov 4, 2024
1 parent e29a920 commit bfaed79
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 46 deletions.
28 changes: 16 additions & 12 deletions src/spox/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ._internal_op import _InternalNode
from ._node import Node, OpType
from ._type_system import Type
from ._var import Var, VarInfo
from ._var import Var, VarInfo, unwrap_vars, wrap_vars

if TYPE_CHECKING:
from . import _graph
Expand Down Expand Up @@ -42,7 +42,7 @@ class Function(_InternalNode):
via the ``to_onnx_function`` method.
"""

func_args: dict[str, Var]
func_args: dict[str, VarInfo]
func_attrs: dict[str, _attributes.Attr]
func_inputs: BaseInputs
func_outputs: BaseOutputs
Expand All @@ -64,10 +64,12 @@ def constructor(self, attrs, inputs):
def infer_output_types(self) -> dict[str, Type]:
from . import _graph

self.func_args = _graph.arguments_dict(
func_args_var = _graph.arguments_dict(
**{name: var.type for name, var in self.inputs.get_vars().items()}
)

self.func_args = unwrap_vars(func_args_var)

self.func_attrs = {}
for name, attr in self.attrs.get_fields().items():
if attr is None:
Expand All @@ -78,9 +80,9 @@ def infer_output_types(self) -> dict[str, Type]:

self.func_inputs = self.Inputs(**self.func_args) # type: ignore
self.func_outputs = self.constructor(self.func_attrs, self.func_inputs)
self.func_graph = _graph.results(**self.func_outputs.get_vars()).with_arguments(
*self.func_args.values()
)
self.func_graph = _graph.results(
**self.func_outputs._propagate_vars()
).with_arguments(*func_args_var.values())

return {
name: var.type
Expand Down Expand Up @@ -157,7 +159,7 @@ def to_function(name: str, domain: str = "spox.function", *, _version: int = 0):
The function must be deterministic in the performed operations, as otherwise an error will be raised at build
due to inconsistent function bodies.
``fun`` is assumed to take only VarInfo arguments and return an iterable of them. These will be used to generate the
``fun`` is assumed to take only Var arguments and return an iterable of them. These will be used to generate the
function class signature.
Keep in mind that functions with the same name & domain will be merged together.
Expand All @@ -172,13 +174,13 @@ def inner(fun: ConstructorT) -> ConstructorT:
_num_outputs = None
_cls = None

def get_num_outputs(*args: VarInfo) -> int:
def get_num_outputs(*args: Var) -> int:
nonlocal _num_outputs
if _num_outputs is None:
_num_outputs = sum(1 for _ in fun(*args))
return _num_outputs

def init(*args: VarInfo):
def init(*args: Var):
nonlocal _cls
if _cls is not None:
return _cls
Expand All @@ -188,10 +190,12 @@ def init(*args: VarInfo):
)
return _cls

def alt_fun(*args: VarInfo) -> Iterable[VarInfo]:
def alt_fun(*args: Var) -> Iterable[Var]:
cls = init(*args)
return (
cls(cls.Attributes(), cls.Inputs(*args)).outputs.get_fields().values()
return wrap_vars(
cls(cls.Attributes(), cls.Inputs(*unwrap_vars(args)))
.outputs.get_fields()
.values()
)

return alt_fun # type: ignore
Expand Down
4 changes: 1 addition & 3 deletions src/spox/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,7 @@ def signature(self) -> str:
"""Get a signature of this Node, including its inputs and attributes (but not outputs)."""

def fmt_input(key, var):
return f"{key}: {var.type}" + (
f" = {var._value}" if var._value is not None else ""
)
return f"{key}: {var.type}"

sign = ", ".join(
fmt_input(key, var) for key, var in self.inputs.get_vars().items()
Expand Down
2 changes: 1 addition & 1 deletion src/spox/_public.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _temporary_renames(**kwargs: Var):
pre: dict[Var, Optional[str]] = {}
try:
for name, arg in kwargs.items():
pre[arg._var_info] = arg._var_info._name
pre[arg] = arg._var_info._name
arg._var_info._rename(name)
yield
finally:
Expand Down
39 changes: 34 additions & 5 deletions src/spox/_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,35 @@ def __rxor__(self, other) -> "Var":
T = TypeVar("T")


@overload
def wrap_vars(var_info: VarInfo) -> Var: ...


@overload
def wrap_vars(var_info: Optional[VarInfo]) -> Optional[Var]: ...


@overload
def wrap_vars(var_info: dict[T, VarInfo]) -> dict[T, Var]: ...  # type: ignore[misc]


@overload
def wrap_vars(var_info: Union[Sequence[VarInfo], Iterable[VarInfo]]) -> list[Var]: ...


def wrap_vars(var_info):
    """Wrap ``VarInfo`` objects into public ``Var`` wrappers.

    Accepts a single ``VarInfo``, ``None``, a dict whose values are
    ``VarInfo``, or an iterable of ``VarInfo``; the returned structure
    mirrors the input (``None`` stays ``None``, dicts keep their keys,
    iterables become lists).
    """
    # Guard clauses, ordered from most to least specific: the dict check
    # must precede the generic Iterable check, since dicts iterate too.
    if var_info is None:
        return None
    if isinstance(var_info, VarInfo):
        return Var(var_info)
    if isinstance(var_info, dict):
        return {key: wrap_vars(val) for key, val in var_info.items()}
    if isinstance(var_info, (Sequence, Iterable)):
        return [wrap_vars(item) for item in var_info]
    raise ValueError("Unsupported type for wrap_vars")


@overload
def unwrap_vars(var: Var) -> VarInfo: ...

Expand All @@ -320,7 +349,7 @@ def unwrap_vars(var: Optional[Var]) -> Optional[VarInfo]: ...


@overload
def unwrap_vars(var: dict[T, Var]) -> dict[T, VarInfo]: ...
def unwrap_vars(var: dict[T, Var]) -> dict[T, VarInfo]: ... # type: ignore[misc]


@overload
Expand Down Expand Up @@ -348,6 +377,10 @@ def get_value(var: Var) -> Optional[_value_prop.PropValue]: ...
def get_value(var: Optional[Var]) -> Optional[_value_prop.PropValue]: ...


@overload
def get_value(var: dict[T, Var]) -> dict[T, Optional[_value_prop.PropValue]]: ... # type: ignore[misc]


@overload
def get_value(
var: Union[Sequence[Var], Iterable[Var]],
Expand All @@ -356,10 +389,6 @@ def get_value(
]: ...


@overload
def get_value(var: dict[T, Var]) -> dict[T, Optional[_value_prop.PropValue]]: ...


def get_value(var):
if var is None:
return None
Expand Down
6 changes: 3 additions & 3 deletions src/spox/opset/ai/onnx/ml/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,7 +1217,7 @@ def linear_classifier(
.get_output_vars(
X=get_value(X),
)
._unpack_to_any()
.values()
)


Expand Down Expand Up @@ -1508,7 +1508,7 @@ def svmclassifier(
.get_output_vars(
X=get_value(X),
)
._unpack_to_any()
.values()
)


Expand Down Expand Up @@ -1833,7 +1833,7 @@ def tree_ensemble_classifier(
.get_output_vars(
X=get_value(X),
)
._unpack_to_any()
.values()
)


Expand Down
22 changes: 11 additions & 11 deletions src/spox/opset/ai/onnx/v17.py
Original file line number Diff line number Diff line change
Expand Up @@ -4687,7 +4687,7 @@ def batch_normalization(
input_mean=get_value(input_mean),
input_var=get_value(input_var),
)
._unpack_to_any()
.values()
)


Expand Down Expand Up @@ -6383,7 +6383,7 @@ def dropout(
ratio=get_value(ratio),
training_mode=get_value(training_mode),
)
._unpack_to_any()
.values()
)


Expand Down Expand Up @@ -6464,7 +6464,7 @@ def dynamic_quantize_linear(
.get_output_vars(
x=get_value(x),
)
._unpack_to_any()
.values()
)


Expand Down Expand Up @@ -7113,7 +7113,7 @@ def gru(
sequence_lens=get_value(sequence_lens),
initial_h=get_value(initial_h),
)
._unpack_to_any()
.values()
)


Expand Down Expand Up @@ -8707,7 +8707,7 @@ def lstm(
initial_c=get_value(initial_c),
P=get_value(P),
)
._unpack_to_any()
.values()
)


Expand Down Expand Up @@ -8812,7 +8812,7 @@ def layer_normalization(
Scale=get_value(Scale),
B=get_value(B),
)
._unpack_to_any()
.values()
)


Expand Down Expand Up @@ -9678,7 +9678,7 @@ def max_pool(
.get_output_vars(
X=get_value(X),
)
._unpack_to_any()
.values()
)


Expand Down Expand Up @@ -11616,7 +11616,7 @@ def rnn(
sequence_lens=get_value(sequence_lens),
initial_h=get_value(initial_h),
)
._unpack_to_any()
.values()
)


Expand Down Expand Up @@ -14678,7 +14678,7 @@ def softmax_cross_entropy_loss(
labels=get_value(labels),
weights=get_value(weights),
)
._unpack_to_any()
.values()
)


Expand Down Expand Up @@ -15571,7 +15571,7 @@ def top_k(
X=get_value(X),
K=get_value(K),
)
._unpack_to_any()
.values()
)


Expand Down Expand Up @@ -15869,7 +15869,7 @@ def unique(
.get_output_vars(
X=get_value(X),
)
._unpack_to_any()
.values()
)


Expand Down
2 changes: 1 addition & 1 deletion src/spox/opset/ai/onnx/v20.py
Original file line number Diff line number Diff line change
Expand Up @@ -1550,7 +1550,7 @@ def string_split(
.get_output_vars(
X=get_value(X),
)
._unpack_to_any()
.values()
)


Expand Down
16 changes: 8 additions & 8 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def constructor(self, attrs: dict[str, Attr], inputs: Inputs) -> Outputs:
attrs["shift_outer"], outer_name="shift_outer", name="value_float"
) # type: ignore
)
x = inputs.X
return self.Outputs(op.add(op.mul(a, x), b))
x = Var(inputs.X)
return self.Outputs(op.add(op.mul(a, x), b)._var_info)

def linear_inner(
x: Var, a: Union[float, _Ref[float]], b: Union[float, _Ref[float]]
Expand Down Expand Up @@ -98,10 +98,10 @@ class Outputs(BaseOutputs):
def constructor(self, attrs: dict[str, Attr], inputs: Inputs) -> Outputs:
return self.Outputs(
linear(
inputs.X,
Var(inputs.X),
_Ref(attrs["slope1"], outer_name="slope1", name="slope_outer"),
_Ref(attrs["shift1"], outer_name="shift1", name="shift_outer"),
)
)._var_info
)

def linear_inner(
Expand Down Expand Up @@ -142,7 +142,7 @@ class Outputs(BaseOutputs):
outputs: Outputs

def constructor(self, attrs: dict[str, Attr], inputs: Inputs) -> Outputs:
x = inputs.X
x = Var(inputs.X)
a = op.mul(
linear(
x,
Expand All @@ -165,7 +165,7 @@ def constructor(self, attrs: dict[str, Attr], inputs: Inputs) -> Outputs:
),
)
y = op.add(a, b)
return self.Outputs(y)
return self.Outputs(y._var_info)

def cubic_inner(x: Var, a3: float, a2: float, a1: float, a0: float) -> Var:
return CubicFunction(
Expand All @@ -175,8 +175,8 @@ def cubic_inner(x: Var, a3: float, a2: float, a1: float, a0: float) -> Var:
a1=AttrFloat32(a1, name="a1"),
a0=AttrFloat32(a0, name="a0"),
),
CubicFunction.Inputs(X=x),
).outputs.Y
CubicFunction.Inputs(X=x._var_info),
).get_output_vars()["Y"]

return cubic_inner

Expand Down
3 changes: 2 additions & 1 deletion tests/test_value_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from spox import Var, _type_system
from spox._graph import arguments, results
from spox._shape import Shape
from spox._var import VarInfo
from spox._value_prop import ORTValue, PropValue


Expand All @@ -27,7 +28,7 @@ def value_prop_backend(request):

def dummy_var(typ=None, value=None):
"""Function for creating a ``var`` without an operator but with a type and value."""
return Var(None, typ, value) # type: ignore
return Var(VarInfo(None, typ), value) # type: ignore


def assert_equal_value(var: Var, expected: ORTValue):
Expand Down
2 changes: 1 addition & 1 deletion tools/templates/construct.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,5 @@ endfor %}
){%
if schema.outputs | length <= 1
%}["{{ schema.outputs[0].name }}"]{%
else %}._unpack_to_any(){%
else %}.values(){%
endif %}

0 comments on commit bfaed79

Please sign in to comment.