import numpy as np (#166)
cbourjau authored Jul 28, 2024
1 parent 07a0e2f commit 8f6053d
Showing 23 changed files with 291 additions and 318 deletions.
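
The change itself is mechanical: each module switches from importing NumPy under its full name to the conventional `np` alias, and every `numpy.` reference is rewritten to `np.`. A minimal before/after sketch of the pattern (the helper function below is illustrative only, not taken from the diff):

```python
# Before: NumPy imported under its full module name.
import numpy

def is_ndarray(value) -> bool:
    return isinstance(value, numpy.ndarray)

# After: the conventional alias, as applied throughout this commit.
import numpy as np

def is_ndarray_aliased(value) -> bool:
    return isinstance(value, np.ndarray)
```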
4 changes: 2 additions & 2 deletions src/spox/_adapt.py
@@ -1,7 +1,7 @@
import warnings
from typing import Dict, List, Optional

import numpy
import numpy as np
import onnx
import onnx.version_converter

@@ -43,7 +43,7 @@ def adapt_node(
initializers = [
from_array(var._value, name)
for name, var in node.inputs.get_vars().items()
if isinstance(var._value, numpy.ndarray)
if isinstance(var._value, np.ndarray)
]
except ValueError:
return None
6 changes: 3 additions & 3 deletions src/spox/_build.py
@@ -13,7 +13,7 @@
TypeVar,
)

import numpy
import numpy as np
import onnx

from . import _function
@@ -63,7 +63,7 @@ class BuildResult:
results: Tuple[Var, ...]
opset_req: Set[Tuple[str, int]]
functions: Tuple["_function.Function", ...]
initializers: Dict[Var, numpy.ndarray]
initializers: Dict[Var, np.ndarray]


class Builder:
@@ -433,7 +433,7 @@ def compile_graph(
# A bunch of model metadata we're collecting
opset_req: Set[Tuple[str, int]] = set()
functions: List[_function.Function] = []
initializers: Dict[Var, numpy.ndarray] = {}
initializers: Dict[Var, np.ndarray] = {}

# Add arguments to our scope
for arg in self.arguments_of[graph]:
14 changes: 7 additions & 7 deletions src/spox/_graph.py
@@ -5,7 +5,7 @@
from dataclasses import dataclass, replace
from typing import Callable, Dict, Iterable, List, Literal, Optional, Set, Tuple, Union

import numpy
import numpy as np
import onnx
import onnx.shape_inference

@@ -21,7 +21,7 @@
from ._var import Var


def arguments_dict(**kwargs: Optional[Union[Type, numpy.ndarray]]) -> Dict[str, Var]:
def arguments_dict(**kwargs: Optional[Union[Type, np.ndarray]]) -> Dict[str, Var]:
"""
Parameters
----------
@@ -47,7 +47,7 @@ def arguments_dict(**kwargs: Optional[Union[Type, numpy,
),
BaseInputs(),
).outputs.arg
elif isinstance(info, numpy.ndarray):
elif isinstance(info, np.ndarray):
ty = Tensor(info.dtype, info.shape)
result[name] = Argument(
Argument.Attributes(
@@ -62,13 +62,13 @@ def arguments_dict(**kwargs: Optional[Union[Type, numpy,
return result


def arguments(**kwargs: Optional[Union[Type, numpy.ndarray]]) -> Tuple[Var, ...]:
def arguments(**kwargs: Optional[Union[Type, np.ndarray]]) -> Tuple[Var, ...]:
"""This function is a shorthand for a respective call to ``arguments_dict``, unpacking the Vars from the dict."""
return tuple(arguments_dict(**kwargs).values())


def enum_arguments(
*infos: Union[Type, numpy.ndarray], prefix: str = "in"
*infos: Union[Type, np.ndarray], prefix: str = "in"
) -> Tuple[Var, ...]:
"""
Convenience function for creating an enumeration of arguments, prefixed with ``prefix``.
@@ -91,7 +91,7 @@ def enum_arguments(
return arguments(**{f"{prefix}{i}": info for i, info in enumerate(infos)})


def initializer(arr: numpy.ndarray) -> Var:
def initializer(arr: np.ndarray) -> Var:
"""
Create a single initializer (frozen argument) with a given array value.
@@ -260,7 +260,7 @@ def _get_opset_req(self) -> Set[Tuple[str, int]]:
self._extra_opset_req if self._extra_opset_req is not None else set()
)

def _get_initializers_by_name(self) -> Dict[str, numpy.ndarray]:
def _get_initializers_by_name(self) -> Dict[str, np.ndarray]:
"""Internal function for accessing the initializers by name in the build."""
return {
self._get_build_result().scope.var[var]: init
4 changes: 2 additions & 2 deletions src/spox/_standard.py
@@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, Callable, Dict, Tuple

import numpy
import numpy as np
import onnx
import onnx.reference
import onnx.shape_inference
@@ -99,7 +99,7 @@ def out_value_info(curr_key, curr_var):
initializers = [
from_array(var._value.value, key)
for key, var in self.inputs.get_vars().items()
if var._value and isinstance(var._value.value, numpy.ndarray)
if var._value and isinstance(var._value.value, np.ndarray)
]
# Graph and model
graph = onnx.helper.make_graph(
22 changes: 11 additions & 11 deletions src/spox/_value_prop.py
@@ -4,7 +4,7 @@
from dataclasses import dataclass
from typing import Dict, List, Union

import numpy
import numpy as np
import onnx
import onnx.reference

@@ -20,9 +20,9 @@
- PropValue -> Optional, Some (has value)
- None -> Optional, Nothing (no value)
"""
PropValueType = Union[numpy.ndarray, List["PropValue"], "PropValue", None]
ORTValue = Union[numpy.ndarray, list, None]
RefValue = Union[numpy.ndarray, list, float, None]
PropValueType = Union[np.ndarray, List["PropValue"], "PropValue", None]
ORTValue = Union[np.ndarray, list, None]
RefValue = Union[np.ndarray, list, float, None]

VALUE_PROP_STRICT_CHECK: bool = False

@@ -56,12 +56,12 @@ def __post_init__(self):
# platform-dependent dtype - such as ulonglong.
# Though very similar, it does not compare equal to the usual sized dtype.
# (for example ulonglong is not uint64)
if isinstance(self.value, numpy.ndarray) and numpy.issubdtype(
self.value.dtype, numpy.number
if isinstance(self.value, np.ndarray) and np.issubdtype(
self.value.dtype, np.number
):
# We normalize by reconstructing the dtype through its name
object.__setattr__(
self, "value", self.value.astype(numpy.dtype(self.value.dtype.name))
self, "value", self.value.astype(np.dtype(self.value.dtype.name))
)

if VALUE_PROP_STRICT_CHECK and not self.check():
@@ -76,7 +76,7 @@ def __str__(self):
def check(self) -> bool:
if isinstance(self.type, Tensor):
return (
isinstance(self.value, numpy.ndarray)
isinstance(self.value, np.ndarray)
and self.value.dtype.type is self.type.dtype.type
and Shape.from_simple(self.value.shape) <= self.type._shape
)
@@ -110,7 +110,7 @@ def from_ref_value(cls, typ: Type, value: RefValue) -> "PropValue":
elem_type = typ.unwrap_sequence().elem_type
return cls(typ, [cls.from_ref_value(elem_type, elem) for elem in value])
else: # otherwise must have Tensor (sometimes this is just a scalar)
return cls(typ, numpy.array(value))
return cls(typ, np.array(value))
# No fail branch because representations of Tensor are inconsistent

@classmethod
Expand All @@ -122,9 +122,9 @@ def from_ort_value(cls, typ: Type, value: ORTValue) -> "PropValue":
elif isinstance(value, list): # Sequence
elem_type = typ.unwrap_sequence().elem_type
return cls(typ, [cls.from_ort_value(elem_type, elem) for elem in value])
elif isinstance(value, numpy.ndarray): # Tensor
elif isinstance(value, np.ndarray): # Tensor
# Normalise the dtype in case we got an alias (like longlong)
if value.dtype == numpy.dtype(object):
if value.dtype == np.dtype(object):
value = value.astype(str)
return cls(typ, value)
raise TypeError(f"No handler for ORT value: {value}")
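
For context on the `__post_init__` hunk in `_value_prop.py` above: the dtype normalization it carries over (unchanged apart from the alias) exists because NumPy keeps platform-dependent C types such as `ulonglong` distinct from the sized aliases. A hedged sketch of the effect, assuming a 64-bit platform where `uint64` maps to a different C type than `ulonglong`:

```python
import numpy as np

value = np.array([1, 2, 3], dtype=np.ulonglong)

# The scalar type can differ from the sized alias even at the same width
# (platform-dependent; on some platforms the two coincide).
print(value.dtype.type is np.uint64)

# Rebuilding the dtype from its canonical name, as PropValue.__post_init__
# does, lands on the sized alias.
normalized = value.astype(np.dtype(value.dtype.name))
print(normalized.dtype.type is np.uint64)  # True
```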
10 changes: 5 additions & 5 deletions src/spox/_var.py
@@ -1,7 +1,7 @@
import typing
from typing import Any, Callable, ClassVar, Optional, TypeVar, Union

import numpy
import numpy as np

from . import _type_system, _value_prop

@@ -195,11 +195,11 @@ def __rxor__(self, other) -> "Var":


def result_type(
*types: Union[Var, numpy.generic, int, float],
) -> typing.Type[numpy.generic]:
*types: Union[Var, np.generic, int, float],
) -> typing.Type[np.generic]:
"""Promote type for all given element types/values using ``np.result_type``."""
return numpy.dtype(
numpy.result_type(
return np.dtype(
np.result_type(
*(
typ.unwrap_tensor().dtype if isinstance(typ, Var) else typ
for typ in types
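
As a side note, the `result_type` helper in the `_var.py` hunk above simply defers to NumPy's promotion rules via `np.result_type`; a small illustration of those rules on plain dtypes (standard NumPy behaviour, not taken from the diff):

```python
import numpy as np

# dtype-to-dtype promotion, which the result_type helper relies on once
# element types are extracted from the participating Vars.
print(np.result_type(np.int32, np.float32))  # float64
print(np.result_type(np.uint8, np.int8))     # int16
```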
10 changes: 5 additions & 5 deletions tests/conftest.py
@@ -1,7 +1,7 @@
import sys
from typing import Dict, Optional

import numpy
import numpy as np
import onnxruntime
import pytest

@@ -63,12 +63,12 @@ def assert_close(given, expected, rtol=1e-7):
else:
if isinstance(given, list):
for subarray in given:
numpy.testing.assert_allclose(
given, numpy.array(expected, dtype=subarray.dtype), rtol=rtol
np.testing.assert_allclose(
given, np.array(expected, dtype=subarray.dtype), rtol=rtol
)
else:
numpy.testing.assert_allclose(
given, numpy.array(expected, dtype=given.dtype), rtol=rtol
np.testing.assert_allclose(
given, np.array(expected, dtype=given.dtype), rtol=rtol
)


24 changes: 12 additions & 12 deletions tests/full/conftest.py
@@ -1,6 +1,6 @@
from typing import Tuple

import numpy
import numpy as np
import pytest

import spox.opset.ai.onnx.v17 as op
@@ -16,17 +16,17 @@ class Extras:

def false(self):
if self._false is None:
self._false = op.const(numpy.array(False))
self._false = op.const(np.array(False))
return self._false

def true(self):
if self._true is None:
self._true = op.const(numpy.array(True))
self._true = op.const(np.array(True))
return self._true

def empty_i64(self):
if self._empty_i64 is None:
self._empty_i64 = op.const(numpy.array([], dtype=numpy.int64))
self._empty_i64 = op.const(np.array([], dtype=np.int64))
return self._empty_i64

@staticmethod
@@ -52,10 +52,10 @@ def pop(var: Var) -> Var:

@staticmethod
def at(t: Var, j: Var) -> Var:
j = op.reshape(j, op.const(numpy.array([1], dtype=numpy.int64)))
j = op.reshape(j, op.const(np.array([1], dtype=np.int64)))
return op.reshape(
op.slice(t, j, op.add(j, op.const(1))),
op.const(numpy.array([], dtype=numpy.int64)),
op.const(np.array([], dtype=np.int64)),
)

@staticmethod
@@ -89,9 +89,9 @@ def bracket_matcher_step(
op.reshape(op.size(xs), op.const([1])),
None,
[
op.sequence_empty(dtype=numpy.int64),
op.sequence_empty(dtype=numpy.int64),
op.const(numpy.array(True)),
op.sequence_empty(dtype=np.int64),
op.sequence_empty(dtype=np.int64),
op.const(np.array(True)),
],
body=bracket_matcher_step,
)
@@ -104,13 +104,13 @@ def onehot(ext, n: Var, i: Var) -> Var:
return op.pad(op.const([1]), ext.scalars(i, op.sub(op.sub(n, i), op.const(1))))

def set_to(ext, t: Var, j: Var, x: Var) -> Var:
return op.where(op.cast(ext.onehot(op.size(t), j), to=numpy.bool_), x, t)
return op.where(op.cast(ext.onehot(op.size(t), j), to=np.bool_), x, t)

@staticmethod
def is_token(var: Var, token: str) -> Var:
return op.equal(
op.cast(var, to=numpy.int32),
op.cast(op.const(numpy.uint8(ord(token))), to=numpy.int32),
op.cast(var, to=np.int32),
op.cast(op.const(np.uint8(ord(token))), to=np.int32),
)

def flat_concat(ext, s: Var) -> Var:
3 changes: 1 addition & 2 deletions tests/future/test_var_operators.py
@@ -1,6 +1,5 @@
import operator

import numpy
import numpy as np
import pytest

@@ -125,4 +124,4 @@ def test_var_operator_promotion_like_numpy(bin_op, lhs, rhs):
assert (
spox_value.dtype == numpy_value.dtype
), f"{lhs!r}: {type(lhs)} | {rhs!r}: {type(rhs)}"
assert numpy.isclose(spox_value, numpy_value).all()
assert np.isclose(spox_value, numpy_value).all()
(Diffs for the remaining 14 changed files are not shown.)
