Skip to content

Commit

Permalink
Add some proper typing
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Nov 4, 2024
1 parent bfaed79 commit 49e366c
Show file tree
Hide file tree
Showing 18 changed files with 4,374 additions and 3,192 deletions.
87 changes: 84 additions & 3 deletions src/spox/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,54 @@ class VarFieldKind(enum.Enum):
VARIADIC = 2


class BaseVars:
def _unpack_to_any(self) -> Any:
"""Unpack the stored fields into a tuple of appropriate length, typed as Any."""
return tuple(self.__dict__.values())

def _flatten(self) -> Iterable[tuple[str, Optional[Var]]]:
"""Iterate over the pairs of names and values of fields in this object."""
for key, value in self.__dict__.items():
if value is None or isinstance(value, Var):
yield key, value
else:
yield from ((f"{key}_{i}", v) for i, v in enumerate(value))

def get_vars(self) -> dict[str, Var]:
"""Return a flat mapping by name of all the VarInfos in this object."""
return {key: var for key, var in self._flatten() if var is not None}


class BaseVarsMeta(type):
def __new__(cls, name, bases, namespace):
new_cls = super().__new__(cls, name, bases, namespace)

if bases and "__annotations__" in namespace:
annotations: dict[str, Any] = {}
for name, typ in namespace["__annotations__"].items():
if typ == VarInfo:
annotations[name] = Var
elif typ == Optional[VarInfo]:
annotations[name] = Optional[Var]
elif typ == Sequence[VarInfo]:
annotations[name] = Sequence[Var]

vars_cls = dataclass(
type(
"Vars",
(
BaseVars,
object,
),
{"__annotations__": annotations},
)
) # type: ignore

setattr(new_cls, "Vars", vars_cls)

return new_cls


@dataclass
class BaseVarInfos(BaseFields):
def __post_init__(self):
Expand Down Expand Up @@ -110,12 +158,45 @@ def fully_typed(self) -> bool:


@dataclass
class BaseInputs(BaseVarInfos):
pass
class BaseInputs(BaseVarInfos, metaclass=BaseVarsMeta):
@dataclass
class Vars(BaseVars):
pass

def vars(self, prop_values) -> Vars:
vars_structure: dict[str, Union[Var, Sequence[Var]]] = {}

for field in dataclasses.fields(self):
field_type = self._get_field_type(field)
field_value = getattr(self, field.name)

if field_type == VarFieldKind.SINGLE:
vars_structure[field.name] = Var(field_value, prop_values[field.name])

elif (
field_type == VarFieldKind.OPTIONAL
and prop_values.get(field.name, None) is not None
):
vars_structure[field.name] = Var(field_value, prop_values[field.name])

elif field_type == VarFieldKind.VARIADIC:
vars = []

for i, var_info in enumerate(field_value):
var_value = prop_values.get(f"{field.name}_{i}", None)
vars.append(Var(var_info, var_value))

vars_structure[field.name] = vars

return self.Vars(**vars_structure)


@dataclass
class BaseOutputs(BaseVarInfos):
class BaseOutputs(BaseVarInfos, metaclass=BaseVarsMeta):
@dataclass
class Vars(BaseVars):
pass

def _propagate_vars(
self,
prop_values={},
Expand Down
52 changes: 32 additions & 20 deletions src/spox/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,32 @@ def arguments_dict(**kwargs: Optional[Union[Type, np.ndarray]]) -> dict[str, Var
for name, info in kwargs.items():
attr_name = AttrString(value=name, name="dummy")
if isinstance(info, Type):
result[name] = Argument(
Argument.Attributes(
name=attr_name,
type=AttrType(value=info, name="dummy"),
default=None,
),
BaseInputs(),
).get_output_vars()["arg"]
result[name] = (
Argument(
Argument.Attributes(
name=attr_name,
type=AttrType(value=info, name="dummy"),
default=None,
),
BaseInputs(),
)
.get_output_vars()
.arg
)
elif isinstance(info, np.ndarray):
ty = Tensor(info.dtype, info.shape)
result[name] = Argument(
Argument.Attributes(
name=attr_name,
type=AttrType(value=ty, name="dummy"),
default=AttrTensor(value=info, name="dummy"),
),
BaseInputs(),
).get_output_vars()["arg"]
result[name] = (
Argument(
Argument.Attributes(
name=attr_name,
type=AttrType(value=ty, name="dummy"),
default=AttrTensor(value=info, name="dummy"),
),
BaseInputs(),
)
.get_output_vars()
.arg
)
else:
raise TypeError(f"Cannot construct argument from {type(info)}.")
return result
Expand Down Expand Up @@ -110,10 +118,14 @@ def initializer(arr: np.ndarray) -> Var:
-------
Var which is always equal to the respective value provided by `arr`.
"""
return _Initializer(
_Initializer.Attributes(value=AttrTensor(value=arr, name="dummy")),
BaseInputs(),
).get_output_vars()["arg"]
return (
_Initializer(
_Initializer.Attributes(value=AttrTensor(value=arr, name="dummy")),
BaseInputs(),
)
.get_output_vars()
.arg
)


@dataclass(frozen=True, eq=False)
Expand Down
8 changes: 5 additions & 3 deletions src/spox/_internal_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,11 @@ def intros(*args: Var) -> Sequence[Var]:
Sequence[Var]
Vars of the same value as ``args``, but with a shared dependency.
"""
return _Introduce(
None, _Introduce.Inputs(unwrap_vars(args)), out_variadic=len(args)
).get_output_vars()["outputs"]
return (
_Introduce(None, _Introduce.Inputs(unwrap_vars(args)), out_variadic=len(args))
.get_output_vars()
.outputs
)


def intro(*args: Var) -> Var:
Expand Down
6 changes: 4 additions & 2 deletions src/spox/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,10 @@ def inference(self, infer_types: bool = True):
def get_output_vars(self, flatten_variadic=False, **initializers):
# After typing everything, try to get values for outputs
out_values = self.propagate_values(initializers)
return self.outputs._propagate_vars(
out_values, flatten_variadic=flatten_variadic
return type(self.outputs).Vars(
**self.outputs._propagate_vars(
out_values, flatten_variadic=flatten_variadic
)
)

def validate_types(self) -> None:
Expand Down
14 changes: 8 additions & 6 deletions src/spox/_public.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,13 @@ def argument(typ: Type) -> Var:
An unnamed argument variable of given type that may be used as
a model input to build a graph.
"""
return _internal_op.Argument(
_internal_op.Argument.Attributes(type=AttrType(typ, "dummy"), default=None)
).get_output_vars()["arg"]
return (
_internal_op.Argument(
_internal_op.Argument.Attributes(type=AttrType(typ, "dummy"), default=None)
)
.get_output_vars()
.arg
)


@contextlib.contextmanager
Expand Down Expand Up @@ -303,9 +307,7 @@ def inline_inner(*args: Var, **kwargs: Var) -> dict[str, Var]:
model=model,
)

return dict(
zip(out_names, node.get_output_vars(flatten_variadic=True).values())
)
return dict(zip(out_names, node.get_output_vars().get_vars().values()))

return inline_inner

Expand Down
Loading

0 comments on commit 49e366c

Please sign in to comment.