Skip to content

Commit

Permalink
reduce on wmma (tinygrad#8016)
Browse files Browse the repository at this point in the history
  • Loading branch information
ignaciosica authored Dec 3, 2024
1 parent dacb1ff commit f14dd14
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tinygrad/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def st(self) -> Optional[ShapeTracker]:
assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
# all other ops have a contiguous shapetracker
from tinygrad.shape.shapetracker import ShapeTracker
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op is Ops.REDUCE_AXIS else src_sts[0].shape)
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op in (Ops.REDUCE_AXIS, Ops.WMMA) else src_sts[0].shape)
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
return self.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
Expand Down

0 comments on commit f14dd14

Please sign in to comment.