Skip to content

Commit

Permalink
fix typing for test_ops
Browse files Browse the repository at this point in the history
mostly passed `TYPED=1 python3 -m pytest -n=auto test/test_ops.py`.

need to get rid of the dependency on tensorflow_addons because it requires a very old version of typeguard, which blocks adding typeguard to CI.
  • Loading branch information
chenyuxyz committed Sep 15, 2024
1 parent cd90092 commit 65c4b0d
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_negative_dims(self):
with self.assertRaises(ValueError): method((2, -3, 0))

def test_negative_dims_full(self):
with self.assertRaises(ValueError): Tensor.full(-3, 2)
with self.assertRaises(ValueError): Tensor.full((-3,), 2)
with self.assertRaises(ValueError): Tensor.full((2, -3), 4)
with self.assertRaises(ValueError): Tensor.full((2, -3, 0), 4)

Expand Down Expand Up @@ -2110,7 +2110,7 @@ def test_cross_entropy_reductions(self):
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction=r),
lambda x,y: x.cross_entropy(y, reduction=r))
self.helper_test_exception([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction="typo"),
lambda x,y: x.cross_entropy(y, reduction="typo"), expected=ValueError)
lambda x,y: x.cross_entropy(y, reduction="typo"), expected=ValueError)

def test_cross_entropy_smoothing(self):
for ls in (0., 0.3, 0.7, 1.):
Expand Down
6 changes: 3 additions & 3 deletions tinygrad/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class MetaOps(FastEnum):
class MathTrait:
# required to implement
def alu(self:T, arg:Union[UnaryOps, BinaryOps, TernaryOps], *src) -> T: raise NotImplementedError
def const_like(self, b:ConstType|Variable): raise NotImplementedError
def const_like(self, b:ConstType|Variable|Tuple[ConstType]): raise NotImplementedError

# great functions you get!
def ufix(self, x): return self.const_like(x) if not isinstance(x, MathTrait) else x
Expand Down Expand Up @@ -384,7 +384,7 @@ def st_arg(self) -> ShapeTracker:
return ret.arg
def sink(self, *srcs): return UOp(UOps.SINK, dtypes.void, (self,)+srcs)
def swizzle(self, st:ShapeTracker): return UOp(UOps.SWIZZLE, self.dtype, (self,), st)
def const_like(self, b:ConstType|Variable): return type(self).const(self.dtype, b)
def const_like(self, b:ConstType|Variable|Tuple[ConstType]): return type(self).const(self.dtype, b)
def cast(self, dtype:DType): return type(self)(UOps.CAST, dtype, (self,))
def bitcast(self, dtype:DType): return type(self)(UOps.BITCAST, dtype, (self,))
def gep(self, i:Union[Tuple[int, ...], int]):
Expand Down Expand Up @@ -671,7 +671,7 @@ def load(cls, *src:UPat, dtype:Optional[DType]=None): return cls(UOps.LOAD, dtyp
@classmethod
def store(cls, *src:UPat): return cls(UOps.STORE, dtypes.void, src)

def const_like(self, b:ConstType|Variable): return type(self).const(self.dtype, b)
def const_like(self, b:ConstType|Variable|Tuple[ConstType]): return type(self).const(self.dtype, b)
def alu(self, arg, *src:UPat):
asrc = (self,)+src
return type(self)(UOps.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype,
Expand Down
4 changes: 2 additions & 2 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1778,7 +1778,7 @@ def parse_formula(formula: str, *operands: Tensor):
out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse)))
return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))

xs:Tuple[Tensor] = argfix(*raw_xs)
xs:Tuple[Tensor, ...] = argfix(*raw_xs)
inputs_str, output = parse_formula(formula.replace(" ", ""), *xs)
inputs = inputs_str.split(",")
assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
Expand Down Expand Up @@ -1826,7 +1826,7 @@ def _pool(self, k_:Tuple[sint, ...], stride:Union[Tuple[int, ...], int]=1, dilat
xup = xup.shrink(tuple(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_))))
return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))])

def _padding2d(self, padding:Union[int, Tuple[int, ...]], dims:int) -> Sequence[int]:
def _padding2d(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])

# NOTE: these work for more than 2D
Expand Down

0 comments on commit 65c4b0d

Please sign in to comment.