diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 2f8cbe8484245..beaf3b4416fde 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -275,7 +275,7 @@ def st(self) -> Optional[ShapeTracker]: assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}" # all other ops have a contiguous shapetracker from tinygrad.shape.shapetracker import ShapeTracker - return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op is Ops.REDUCE_AXIS else src_sts[0].shape) + return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op in (Ops.REDUCE_AXIS, Ops.WMMA) else src_sts[0].shape) @functools.cached_property def full_shape(self) -> Tuple[sint, ...]: return self.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))