From eae682a18a1598474695432bcd11f9347a41b6c3 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 16 May 2024 12:06:36 -0700 Subject: [PATCH] filter fake buffers in search --- test/test_search.py | 16 ++++++++++++++-- tinygrad/engine/search.py | 3 ++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/test/test_search.py b/test/test_search.py index baac49136e791..db63469ceea1f 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -3,12 +3,15 @@ from tinygrad.codegen.kernel import Opt, OptOps from tinygrad.codegen.linearizer import Linearizer from tinygrad.engine.schedule import create_schedule -from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions +from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search from tinygrad.device import Device, Buffer -from tinygrad.ops import LoadOps, BufferOps +from tinygrad.ops import LazyOp, LoadOps, BufferOps, ReduceOps, BinaryOps, MemBuffer, ConstBuffer from tinygrad.tensor import Tensor +from tinygrad.dtype import dtypes from tinygrad.helpers import Context from tinygrad.engine.realize import capturing +from tinygrad.shape.shapetracker import ShapeTracker +from tinygrad.shape.view import View class TestTimeLinearizer(unittest.TestCase): @unittest.skipIf(Device.DEFAULT in {"AMD", "NV"}, "Tries to open HSA/CUDA. #4607") @@ -66,5 +69,14 @@ def test_get_linearizer_actions(self): if Opt(OptOps.GROUPTOP, 0, 0) in actions: assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, amt=3)]) == 0, "did not de-dup GROUPTOP" + def test_filter_global_buffer(self): + # taken from https://github.com/tinygrad/tinygrad/issues/4612 + ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.MAX, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=4, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=5, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=6, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.4285714285714286, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501 + lin = Linearizer(ast) + + bufs = bufs_from_lin(lin) + best_lin = beam_search(lin, bufs, 3) + assert best_lin + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index d8950a16b1966..bd1d328b69d8a 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -42,10 +42,11 @@ def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_ try: car = CompiledRunner(p, precompiled=lib) except AssertionError: return [math.inf] * cnt tms = [] + input_bufs = [rawbufs[i] for i,_ in car.p.globals] for _ in range(cnt): if clear_l2: with Context(DEBUG=0, BEAM=0, CACHECOLLECTING=0): Tensor.ones(1024,1024).contiguous().realize() - tms.append(cast(float, car(rawbufs, var_vals, wait=True))*factor) + tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor) if early_stop is not None and early_stop < tms[-1]: break return tms