Skip to content

Commit

Permalink
Hacky fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Nov 6, 2024
1 parent 340d4c2 commit 3d77a87
Show file tree
Hide file tree
Showing 13 changed files with 306 additions and 317 deletions.
7 changes: 3 additions & 4 deletions src/spox/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
from ._debug import STORE_TRACEBACK
from ._exceptions import InferenceWarning
from ._fields import BaseAttributes, BaseInputs, BaseOutputs, VarFieldKind
from ._type_system import Type
from ._value_prop import PropValueType
from ._type_system import PropDict, Type
from ._var import VarInfo

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -95,7 +94,7 @@ def __init__(
out_variadic: Optional[int] = None,
infer_types: bool = True,
validate: bool = True,
input_prop_values={},
input_prop_values: PropDict = {},
**kwargs,
):
"""
Expand Down Expand Up @@ -207,7 +206,7 @@ def pre_init(self, **kwargs):
def post_init(self, **kwargs):
"""Post-initialization hook. Called at the end of ``__init__`` after other default fields are set."""

def propagate_values(self, input_prop_values) -> dict[str, PropValueType]:
def propagate_values(self, input_prop_values: PropDict) -> PropDict:
"""
Propagate values from inputs, and, if possible, compute values for outputs as well.
This method is used to implement ONNX partial data propagation - for example so that
Expand Down
16 changes: 9 additions & 7 deletions src/spox/_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ._schemas import SCHEMAS
from ._scope import Scope
from ._shape import SimpleShape
from ._type_system import Optional, Sequence, Tensor, Type
from ._type_system import Optional, PropDict, Sequence, Tensor, Type
from ._utils import from_array
from ._value_prop import PropValue, PropValueType

Expand Down Expand Up @@ -54,7 +54,7 @@ def to_singleton_onnx_model(
*,
dummy_outputs: bool = True,
with_dummy_subgraphs: bool = True,
prop_values={},
input_prop_values: PropDict = {},
) -> tuple[onnx.ModelProto, Scope]:
"""
Build a singleton model consisting of just this StandardNode. Used for type inference.
Expand Down Expand Up @@ -107,15 +107,15 @@ def out_value_info(curr_key, curr_var):
# TODO: fix this
initializers_from_array = [
from_array(prop.value, name) # type: ignore
for name, prop in prop_values.items()
for name, prop in input_prop_values.items()
if isinstance(prop, PropValue)
and prop.value is not None
and not isinstance(prop.type, Sequence)
]

initializers_from_sequence = [
from_array(prop.value, f"{name}_{i}") # type: ignore
for name, prop_list in prop_values.items()
for name, prop_list in input_prop_values.items()
if isinstance(prop_list, list)
for i, prop in enumerate(prop_list)
if prop is not None and not isinstance(prop.value, Iterable)
Expand Down Expand Up @@ -149,7 +149,7 @@ def infer_output_types_onnx(self, input_prop_values={}) -> dict[str, Type]:
if any(var.type is None for var in self.inputs.get_var_infos().values()):
return {}

model, _ = self.to_singleton_onnx_model(prop_values=input_prop_values)
model, _ = self.to_singleton_onnx_model(input_prop_values=input_prop_values)

# Attempt to do shape inference - if an error is caught, we extend the traceback a bit
try:
Expand All @@ -173,7 +173,9 @@ def infer_output_types_onnx(self, input_prop_values={}) -> dict[str, Type]:
for key, type_ in results.items()
}

def propagate_values_onnx(self, input_prop_values) -> dict[str, PropValueType]:
def propagate_values_onnx(
self, input_prop_values: PropDict
) -> dict[str, PropValueType]:
"""Perform value propagation by evaluating singleton model.
The backend used for the propagation can be configured with the `spox._standard.ValuePropBackend` variable.
Expand All @@ -188,7 +190,7 @@ def propagate_values_onnx(self, input_prop_values) -> dict[str, PropValueType]:
# Cannot do propagation with subgraphs implicitly for performance - should be reimplemented
return {}
model, scope = self.to_singleton_onnx_model(
with_dummy_subgraphs=False, prop_values=input_prop_values
with_dummy_subgraphs=False, input_prop_values=input_prop_values
)
wrap_feed, run, unwrap_feed = _value_prop.get_backend_calls()
input_feed = {
Expand Down
5 changes: 4 additions & 1 deletion src/spox/_type_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from dataclasses import dataclass
from typing import TypeVar
from typing import Any, TypeVar

import numpy as np
import numpy.typing as npt
Expand All @@ -14,6 +14,9 @@
T = TypeVar("T")
S = TypeVar("S")

# TODO: Fix typing
PropDict = dict[str, Any]


@dataclass(frozen=True)
class Type:
Expand Down
42 changes: 20 additions & 22 deletions src/spox/opset/ai/onnx/ml/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
# ruff: noqa: E741 -- Allow ambiguous variable name
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import (
Optional,
)
from typing import Optional

import numpy as np

Expand All @@ -22,7 +20,7 @@
from spox._fields import BaseAttributes, BaseInputs, BaseOutputs
from spox._node import OpType
from spox._standard import InferenceError, StandardNode
from spox._type_system import Tensor, Type
from spox._type_system import PropDict, Tensor, Type
from spox._var import Var, VarInfo, get_value, unwrap_vars


Expand Down Expand Up @@ -662,7 +660,7 @@ def array_feature_extractor(
Type constraints:
- T: `tensor(double)`, `tensor(float)`, `tensor(int32)`, `tensor(int64)`, `tensor(string)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
"Y": get_value(Y),
}
Expand Down Expand Up @@ -711,7 +709,7 @@ def binarizer(
Type constraints:
- T: `tensor(double)`, `tensor(float)`, `tensor(int32)`, `tensor(int64)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down Expand Up @@ -776,7 +774,7 @@ def cast_map(
- T1: `map(int64,tensor(float))`, `map(int64,tensor(string))`
- T2: `tensor(float)`, `tensor(int64)`, `tensor(string)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down Expand Up @@ -851,7 +849,7 @@ def category_mapper(
- T1: `tensor(int64)`, `tensor(string)`
- T2: `tensor(int64)`, `tensor(string)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down Expand Up @@ -922,7 +920,7 @@ def dict_vectorizer(
- T1: `map(int64,tensor(double))`, `map(int64,tensor(float))`, `map(int64,tensor(string))`, `map(string,tensor(double))`, `map(string,tensor(float))`, `map(string,tensor(int64))`
- T2: `tensor(double)`, `tensor(float)`, `tensor(int64)`, `tensor(string)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down Expand Up @@ -979,7 +977,7 @@ def feature_vectorizer(
Type constraints:
- T1: `tensor(double)`, `tensor(float)`, `tensor(int32)`, `tensor(int64)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down Expand Up @@ -1054,7 +1052,7 @@ def imputer(
Type constraints:
- T: `tensor(double)`, `tensor(float)`, `tensor(int32)`, `tensor(int64)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down Expand Up @@ -1163,7 +1161,7 @@ def label_encoder(
- T1: `tensor(float)`, `tensor(int64)`, `tensor(string)`
- T2: `tensor(float)`, `tensor(int64)`, `tensor(string)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down Expand Up @@ -1246,7 +1244,7 @@ def linear_classifier(
- T1: `tensor(double)`, `tensor(float)`, `tensor(int32)`, `tensor(int64)`
- T2: `tensor(int64)`, `tensor(string)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down Expand Up @@ -1322,7 +1320,7 @@ def linear_regressor(
Type constraints:
- T: `tensor(double)`, `tensor(float)`, `tensor(int32)`, `tensor(int64)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down Expand Up @@ -1379,7 +1377,7 @@ def normalizer(
Type constraints:
- T: `tensor(double)`, `tensor(float)`, `tensor(int32)`, `tensor(int64)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down Expand Up @@ -1446,7 +1444,7 @@ def one_hot_encoder(
Type constraints:
- T: `tensor(double)`, `tensor(float)`, `tensor(int32)`, `tensor(int64)`, `tensor(string)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down Expand Up @@ -1548,7 +1546,7 @@ def svmclassifier(
- T1: `tensor(double)`, `tensor(float)`, `tensor(int32)`, `tensor(int64)`
- T2: `tensor(int64)`, `tensor(string)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down Expand Up @@ -1645,7 +1643,7 @@ def svmregressor(
Type constraints:
- T: `tensor(double)`, `tensor(float)`, `tensor(int32)`, `tensor(int64)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down Expand Up @@ -1711,7 +1709,7 @@ def scaler(
Type constraints:
- T: `tensor(double)`, `tensor(float)`, `tensor(int32)`, `tensor(int64)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down Expand Up @@ -1864,7 +1862,7 @@ def tree_ensemble_classifier(
- T1: `tensor(double)`, `tensor(float)`, `tensor(int32)`, `tensor(int64)`
- T2: `tensor(int64)`, `tensor(string)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down Expand Up @@ -2057,7 +2055,7 @@ def tree_ensemble_regressor(
Type constraints:
- T: `tensor(double)`, `tensor(float)`, `tensor(int32)`, `tensor(int64)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down Expand Up @@ -2158,7 +2156,7 @@ def zip_map(
Type constraints:
- T: `seq(map(int64,tensor(float)))`, `seq(map(string,tensor(float)))`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down
7 changes: 3 additions & 4 deletions src/spox/opset/ai/onnx/ml/v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
# ruff: noqa: E741 -- Allow ambiguous variable name
from collections.abc import Iterable
from dataclasses import dataclass
from typing import (
Optional,
)
from typing import Optional

import numpy as np

Expand All @@ -22,6 +20,7 @@
from spox._fields import BaseAttributes, BaseInputs, BaseOutputs
from spox._node import OpType
from spox._standard import StandardNode
from spox._type_system import PropDict
from spox._var import Var, VarInfo, get_value, unwrap_vars
from spox.opset.ai.onnx.ml.v3 import (
_ArrayFeatureExtractor,
Expand Down Expand Up @@ -191,7 +190,7 @@ def label_encoder(
- T1: `tensor(double)`, `tensor(float)`, `tensor(int16)`, `tensor(int32)`, `tensor(int64)`, `tensor(string)`
- T2: `tensor(double)`, `tensor(float)`, `tensor(int16)`, `tensor(int32)`, `tensor(int64)`, `tensor(string)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down
7 changes: 3 additions & 4 deletions src/spox/opset/ai/onnx/ml/v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
# ruff: noqa: E741 -- Allow ambiguous variable name
from collections.abc import Iterable
from dataclasses import dataclass
from typing import (
Optional,
)
from typing import Optional

import numpy as np

Expand All @@ -18,6 +16,7 @@
from spox._fields import BaseAttributes, BaseInputs, BaseOutputs
from spox._node import OpType
from spox._standard import StandardNode
from spox._type_system import PropDict
from spox._var import Var, VarInfo, get_value, unwrap_vars
from spox.opset.ai.onnx.ml.v4 import (
_ArrayFeatureExtractor,
Expand Down Expand Up @@ -224,7 +223,7 @@ def tree_ensemble(
Type constraints:
- T: `tensor(double)`, `tensor(float)`, `tensor(float16)`
"""
input_prop_values = {
input_prop_values: PropDict = {
"X": get_value(X),
}
return (
Expand Down
Loading

0 comments on commit 3d77a87

Please sign in to comment.