Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove upcast_in_mid_reduce_axes [pr] #599

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions tinygrad/codegen/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,6 @@ def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:sel
@property
def shape_len(self) -> int:
  """Return the length of `self.sts[0].shape`, i.e. how many axes that shape has."""
  first_st = self.sts[0]
  return len(first_st.shape)

@property
def upcast_in_mid_reduce_axes(self) -> List[int]:
  """Return the axes in the grouped-reduce range whose size matches between the two shapes.

  The range scanned is `[self.first_reduce, self.first_reduce + self.group_for_reduces)`.
  An axis `j` is kept when `self.full_shape[j] == self.sts[0].shape[j]`.
  The result is in ascending axis order and is empty when `group_for_reduces` is 0.
  """
  start = self.first_reduce
  stop = start + self.group_for_reduces
  return [j for j in range(start, stop) if self.full_shape[j] == self.sts[0].shape[j]]

@property
def global_dims(self) -> int:
  """Return `self.first_reduce` minus `self.local_dims`."""
  reduce_start, n_local = self.first_reduce, self.local_dims
  return reduce_start - n_local

Expand All @@ -165,7 +161,6 @@ def global_dims(self) -> int: return self.first_reduce-self.local_dims
# cyan -- local dims (warp ones first)
# *** self.first_reduce
# green -- reduce-local dims
# white -- reduce-late upcasted dim (self.upcast_in_mid_reduce_axes)
# red -- reduce loops
# *** self.upcasted
# purple -- reduce upcasted
Expand All @@ -175,8 +170,8 @@ def colors(self) -> List[str]:
colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
# after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
colors += ["cyan"] * self.local_dims
# between first_reduce and first_reduce + group_for_reduces, they are either upcast mid reduce (white), or late upcasted (green)
colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + self.group_for_reduces)] # noqa: E501
# between first_reduce and first_reduce + group_for_reduces, they are reduce-local dims (green)
colors += ["green"] * self.group_for_reduces
# between first_reduce + group_for_reduces and upcasted, they are reduce (red)
colors += ["red"] * (self.first_upcast - (self.first_reduce + self.group_for_reduces))
# upcasted dimensions are reduce (magenta) or normal (yellow)
Expand Down Expand Up @@ -446,8 +441,7 @@ def required_optimizations(self) -> Kernel:
if isinstance(self.membufs[0].dtype, ImageDType):
unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
assert unit_stride_axes_mul_4, f"needs a unit stride axis in {self.bufs[0]}"
if all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
if all(x < self.first_upcast for x in unit_stride_axes_mul_4): self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
return self

def hand_coded_optimizations(self) -> Kernel:
Expand Down Expand Up @@ -487,7 +481,7 @@ def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve
unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
if buf.src[0].dtype.__class__ is ImageDType:
#assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4):
if unit_stride_axes_mul_4[0] < self.first_reduce:
self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
else:
Expand Down
Loading