No pyint again (tinygrad#7156)
* Revert "bring back pyint (tinygrad#7150)"

This reverts commit 37e83ca.

* remove truncate in const folding

* truncate_output=False
chenyuxyz authored Oct 19, 2024
1 parent 30989fb commit f511ad9
Showing 9 changed files with 19 additions and 26 deletions.
1 change: 0 additions & 1 deletion test/helpers.py
@@ -29,7 +29,6 @@ def assert_jit_cache_len(fxn, expected_len):
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len

def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
-if dtype == dtypes.pyint and device != "PYTHON": return False
if dtype == dtypes.bfloat16:
# NOTE: this requires bf16 buffer support
return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
6 changes: 3 additions & 3 deletions test/test_dtype.py
@@ -14,15 +14,15 @@
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")

-core_dtypes = list([v for k,v in DTYPES_DICT.items() if k != 'pyint'])
+core_dtypes = list(DTYPES_DICT.values())
if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]

def get_available_cast_dtypes(dtype: DType) -> List[DType]:
if not is_dtype_supported(dtype): return []
# dont cast internal dtypes
-return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_") and k != 'pyint']
+return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")]

def _test_to_np(a:Tensor, np_dtype, target):
if DEBUG >= 2: print(a)
@@ -806,7 +806,7 @@ def test_abs_diff(self, dt):

class TestDtypeUsage(unittest.TestCase):
def test_max_w_alu(self):
-for d in dtype_ints:
+for d in dtypes.ints:
t = Tensor([[1, 2], [3, 4]], dtype=d)
(t*t).max().item()

2 changes: 1 addition & 1 deletion test/test_linearizer_failures.py
@@ -1315,7 +1315,7 @@ def test_failure_54(self):
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=7), Opt(op=OptOps.UPCAST, axis=1, amt=2)]
-helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["HIP", "AMD", "METAL"])
+helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["HIP", "AMD"])

if __name__ == '__main__':
unittest.main()
3 changes: 3 additions & 0 deletions test/test_uops.py
@@ -230,6 +230,9 @@ def test_overflow(self):
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (1, 1)), 2)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-128, 0)), -128)

+# test no truncate
+self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250), truncate_output=False), 500)

class TestConstantFolding(unittest.TestCase):
def test_cast_const(self):
t = Tensor(1, dtype=dtypes.float).cast(dtypes.int)
10 changes: 5 additions & 5 deletions tinygrad/codegen/lowerer.py
@@ -33,7 +33,7 @@ def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse=False) -> List[UOp]:
if reverse: dims = dims[::-1]
limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
-ret = raw_idxs = [UOp(UOps.SPECIAL, dtypes.pyint, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
+ret = raw_idxs = [UOp(UOps.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
if limited != dims:
ret = []
# cast for mypy, get_contraction won't be None
@@ -75,22 +75,22 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
else:
# all loops are RANGES
-idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, False))
+idxs = [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
for i,g in enumerate(full_shape[:first_reduce])]

# reduce loops
-idxs += [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, True))
+idxs += [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]

# upcast loops
for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
assert isinstance(g, int), "needs to be int to upcast/unroll"
-idxs.append(UOp(UOps.EXPAND, dtypes.pyint, (UOp.const(dtypes.pyint.vec(g), tuple(range(g))),), ((i,g),)))
+idxs.append(UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))

# late indexes (group for reduce)
ridxs = idxs[:]
for a in range(first_reduce, first_reduce+group_for_reduces):
-ridxs[a] = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(full_shape[a])), (1000+a, True))
+ridxs[a] = UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))

return IndexContext(idxs, ridxs)

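With this change the lowerer creates its index UOps as dtypes.int (int32) from the start instead of the removed pyint. A minimal sketch of the RANGE UOps built in get_index above, assuming these import paths from a tinygrad checkout at this commit:

from tinygrad.ops import UOp, UOps
from tinygrad.dtype import dtypes

# a loop index over range(0, 16); the arg is (axis index, is-reduce flag) as in get_index above
rng = UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 16)), (0, False))
print(rng.dtype)  # dtypes.int (int32), so no pyint-to-int32 rewrite is needed later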
6 changes: 0 additions & 6 deletions tinygrad/codegen/uopgraph.py
@@ -545,9 +545,6 @@ def find_gate(x:UOp) -> Optional[UOp]:
(UPat(UOps.LOAD, name="load"), simplify_buffer_load),
])

-no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE, UOps.DEFINE_VAR),
-name="x"), lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count), x.src, x.arg) if x.dtype.scalar() == dtypes.pyint else None)])

# *** uop graph ***

linearize_cnt = 0
@@ -559,9 +556,6 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
acc_number = 0
sink = graph_rewrite(sink, sym)

-# rewrite pyint to int32
-sink = graph_rewrite(sink, no_pyint)

# expand
linearize_cnt += 1
if linearize_cnt != (de:=getenv("DEBUG_EXPAND", 0)) and de != -1:
3 changes: 1 addition & 2 deletions tinygrad/dtype.py
@@ -84,7 +84,6 @@ def finfo(dtype:DType) -> Tuple[int, int]: # (exponent, mantissa)
@staticmethod
def fields() -> Dict[str, DType]: return DTYPES_DICT
void: Final[DType] = DType(-1, 0, "void", None, 1)
-pyint: Final[DType] = DType(-1, 8, "pyint", None, 1) # arbitrary precision integer, same itemsize to int64 so min/max works
bool: Final[DType] = DType(0, 1, "bool", '?', 1)
int8: Final[DType] = DType(1, 1, "char", 'b', 1)
uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
@@ -116,7 +115,7 @@ def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, shape=shp, base=dty

floats = (float16, bfloat16, float32, float64)
uints = (uint8, uint16, uint32, uint64)
-sints = (int8, int16, int32, int64, pyint)
+sints = (int8, int16, int32, int64)
ints = uints + sints

if (env_default_float := getenv("DEFAULT_FLOAT", "")):
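After dropping pyint from dtype.py, the signed-integer tuple holds only the four fixed-width types, which is what lets test_dtype.py iterate over dtypes.ints directly. A quick check, assuming a checkout at this commit:

from tinygrad.dtype import dtypes

# sints is now exactly (int8, int16, int32, int64); ints = uints + sints
assert dtypes.int64 in dtypes.sints and len(dtypes.sints) == 4
assert not hasattr(dtypes, "pyint")  # the arbitrary-precision dtype is gone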
10 changes: 4 additions & 6 deletions tinygrad/ops.py
@@ -411,10 +411,11 @@ def wfxn(*args):
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}

-def exec_alu(op:Op, dtype:DType, operands):
+def exec_alu(op:Op, dtype:DType, operands, truncate_output=True):
if dtype.count > 1:
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
-return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
+alu = python_alu[op](*operands)
+return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu

# ***** uop helpers *****

@@ -691,9 +692,6 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
(UPat(UOps.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
(UPat(UOps.SPECIAL, src=()), lambda: True),

-# no pyint allowed here!
-(UPat(UOps.ALU, dtype=dtypes.pyint), lambda: False),

# TODO: confirm the args of both of these are shapetrackers
(UPat(UOps.VIEW, src=()), lambda: True),
(UPat(UOps.VIEW, src=(UPat(),)), lambda: True),
@@ -906,7 +904,7 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
# ** constant folding **
(UPat(UOps.ALU, name="root", src=UPat((UOps.VCONST, UOps.CONST))),
-lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
+lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))),
# ALU min==max -> CONST (slow!)
(UPat(UOps.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# max folding
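exec_alu still wraps results to the target dtype's range by default; the constant-folding rule above now passes truncate_output=False so folded constants keep the exact Python value, matching the "remove truncate in const folding" bullet in the commit message. A rough illustration, assuming these import paths from a checkout at this commit:

from tinygrad.ops import exec_alu, BinaryOps
from tinygrad.dtype import dtypes

print(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250)))                         # 244: wrapped into uint8 range
print(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250), truncate_output=False))  # 500: exact value, as asserted in test_uops.py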
4 changes: 2 additions & 2 deletions tinygrad/shape/view.py
@@ -82,7 +82,7 @@ def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
offs -= here * stride
return result

-def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x
+def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x

@dataclass(frozen=True)
class View:
@@ -93,7 +93,7 @@ class View:
contiguous:bool

def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
-idxs = [UOp.range(dtypes.pyint, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
+idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
iexpr = variable_to_uop(self.offset)
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
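The same swap shows up in shape-tracker index expressions: to_indexed_uops now builds int-typed UOps. A small sketch; View.create and the printed dtypes are assumptions based on the diff above:

from tinygrad.shape.view import View
from tinygrad.dtype import dtypes

v = View.create((4, 4))           # contiguous 4x4 view
idx, valid = v.to_indexed_uops()  # (index expression, validity mask) as UOps
print(idx.dtype, valid.dtype)     # expected: dtypes.int, dtypes.bool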
