From 39ab67e9ef7219cae9aa69bf95feee0f50ba7cec Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Wed, 16 Oct 2024 20:26:58 +0300 Subject: [PATCH] beam capture and replay in fuzz (#7099) * beam capture and reply in fuzz * clean a bit --- test/external/fuzz_linearizer.py | 30 +++++++++++++++++++++++++----- tinygrad/engine/search.py | 5 +++-- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index fab8ec6f65863..57d74a15cd3fc 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -140,7 +140,7 @@ def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=No return ("PASS", rawbufs, var_vals, ground_truth, run_state) -def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2): +def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2, opts_list=None): SEED = getenv("SEED", 42) random.seed(SEED) np.random.seed(SEED) @@ -162,10 +162,18 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2): print("skipping simple kernel") return failures - for depth in range(getenv("DEPTH", 1 if FUZZ_ALL_ACTIONS else 10)): + test_depth = 1 if opts_list is not None else getenv("DEPTH", 1 if FUZZ_ALL_ACTIONS else 10) + for depth in range(test_depth): next_lins = [] for lin in last_lins: - actions = get_kernel_actions(lin, include_0=False) + if opts_list is None: actions = get_kernel_actions(lin, include_0=False) + else: + actions = {} + for oi,opts in enumerate(opts_list): + lin2 = lin.copy() + for o in opts: lin2.apply_opt(o) + actions[oi] = lin2 + if not actions: continue if depth == 0 and getenv("FUZZ_REQUIRE_TC", 0): tc_acts = {i: k for k in actions.values() if k.applied_opts[0].op == OptOps.TC} @@ -174,7 +182,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2): test_lins = list(actions.values()) if FUZZ_ALL_ACTIONS: print(f"testing {lin.applied_opts=} with {len(actions)} actions") - else: test_lins = [random.choice(test_lins)] + elif opts_list is None: test_lins = [random.choice(test_lins)] for test_lin in test_lins: if not FUZZ_ALL_ACTIONS and test_lin.applied_opts: print(f"applied opts: {test_lin.applied_opts}") @@ -230,12 +238,14 @@ def _is_simple(lin: Kernel) -> bool: parser = argparse.ArgumentParser(description="Run a fuzz testing on one or more kernels", formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--ast", type=str, default=None, help="the ast for the kernel to be optimized") parser.add_argument("--file", type=str, default=None, help="a file containing asts to be optimized, one per line") + parser.add_argument("--beamreplay", type=str, default=None, help="replay asts and opts got from beam with CAPTURE_BEAM") parser.add_argument("--logfile", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line") parser.add_argument("--expected-failures", type=int, default=0, help="the number of expected failed kernels") parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison") parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison") args = parser.parse_args() + opts_list = None if args.ast is not None: print("loaded AST from CLI") ast_strs = [args.ast] @@ -243,6 +253,16 @@ def _is_simple(lin: Kernel) -> bool: print(f"loading ASTs from file '{args.file}'") with open(args.file, 'r') as file: ast_strs = file.readlines() + elif args.beamreplay is not None: + print(f"loading BEAM replay from file '{args.beamreplay}'") + with open(args.beamreplay, 'r') as file: fdata = file.readlines() + ast_strs, opts_list = [x.split(' :: ')[0] for x in fdata], [x.split(' :: ')[1] for x in fdata] + + # dedup ast_strs and opts_list + dct = defaultdict(list) + for i in range(len(ast_strs)): dct[ast_strs[i]].append(eval(opts_list[i])) + ast_strs_items = list(dct.keys()) + opts_list = [dct[c] for c in ast_strs_items] elif args.logfile is not None: print(f"loading ASTs from LOGKERNS file '{args.file}'") with open(args.logfile, 'r') as file: @@ -273,7 +293,7 @@ def _is_simple(lin: Kernel) -> bool: with Timing(f"tested ast {i}: "): tested += 1 - fuzz_failures = fuzz_linearizer(lin, rtol=args.rtol, atol=args.atol) + fuzz_failures = fuzz_linearizer(lin, rtol=args.rtol, atol=args.atol, opts_list=(opts_list[i] if opts_list else None)) if fuzz_failures: failed_ids.append(i) for k, v in fuzz_failures.items(): for f in v: diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index 1d2dea66f25ab..c5061e31061a7 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -117,7 +117,7 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> Dict[int, Kernel]: except KernelOptError: pass return acted_lins -beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG") +beam_pool, BEAM_DEBUG, CAPTURE_BEAM = None, getenv("BEAM_DEBUG"), getenv("CAPTURE_BEAM", "") def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True, disable_cache=getenv("IGNORE_BEAM_CACHE")) -> Kernel: global beam_pool key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix} @@ -154,7 +154,8 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True, # filter out kernels that use 1000x more compute than the smallest least_compute_ops = min(this_compute_ops:=sym_infer(p.op_estimate, var_vals), least_compute_ops) if least_compute_ops*1000 < this_compute_ops: continue - #print(acted_lins[i].colored_shape(), acted_lins[i].applied_opts) # for debugging BEAMs that segfault + if len(CAPTURE_BEAM) > 0: + with open(CAPTURE_BEAM, 'a') as f: f.write(str(acted_lins[i].ast).replace('\n','')+f" :: {acted_lins[i].applied_opts}\n") seen_libs.add(lib) try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches')) except RuntimeError: continue # for runtime issues