From 1ec6bd512530a4f9d5476ac9ebeeb3cdc4180d9d Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 18 Sep 2024 04:46:41 -0400 Subject: [PATCH] restructure simplify_valid_image_load [run_process_replay] (#6581) * restructure simplify_valid_image_load [run_process_replay] separated parsing valid / idx and simplification * space * type --- tinygrad/codegen/uopgraph.py | 56 +++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 0083655624d35..7525136e014ae 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -2,7 +2,7 @@ from typing import Optional, Tuple, Dict, List, Set, cast, TYPE_CHECKING, Any, DefaultDict, Callable import functools, itertools, heapq, math, operator from collections import defaultdict -from tinygrad.dtype import dtypes, PtrDType, ImageDType +from tinygrad.dtype import dtypes, PtrDType, ImageDType, ConstType from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, UOps, END_FOR_UOP, type_verify, print_uops, identity_element from tinygrad.ops import UPat, PatternMatcher, graph_rewrite from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, partition, all_same @@ -174,41 +174,51 @@ def simplify_valid_image_load(load:UOp, buf:UOp): new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s is not stmt]) else None return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, new_valid)) if new_valid else UOp(UOps.LOAD, load.dtype, (buf, idx)) - # We want to simplify expressions like (X*c+d)%m in the idx, with optional *c and +d. m is the total length of the row. - # If the contraints in valid implies that it "spans" the whole row, and we can rewrite it to X*c+k for some k, and drop the valid. + # first, parse valid into {expr: ((lower_bound, statement), (upper_bound, statement))} + bounds:DefaultDict[UOp, List[Optional[Tuple[ConstType, UOp]]]] = defaultdict(lambda: [None, None]) + for stmt in _get_chain(valid, BinaryOps.AND): + if stmt.op is UOps.ALU and stmt.arg is BinaryOps.CMPLT and stmt.src[1].op is UOps.CONST: + if (s:=stmt.src[0]).op is UOps.ALU and s.arg is BinaryOps.MUL and s.src[1].op is UOps.CONST and s.src[1].arg == -1: + bounds[s.src[0]][0] = (-stmt.src[1].arg+1, stmt) + else: bounds[s][1] = (stmt.src[1].arg-1, stmt) + + for v in bounds.values(): + # some expr has lower bound > upper bound -> valid is an empty set + if v[0] is not None and v[1] is not None and v[0][0] > v[1][0]: + return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, valid.const_like(False))) + + # next, parse idx by the form ((X*c+d)%m, ((X*c+d)//m+e)) + # parse m m = mod.src[1].arg if (mod:=idx.src[0]).op is UOps.ALU and mod.arg is BinaryOps.MOD and mod.src[1].op is UOps.CONST else None if not m or m != buf_dtype.shape[1]: return None + # parse idx.src[0] d = add.src[1].arg if (add:=mod.src[0]).op is UOps.ALU and add.arg is BinaryOps.ADD and add.src[1].op is UOps.CONST else 0 mul = add.src[0] if d else add # + d is optional c = mul.src[1].arg if mul.op is UOps.ALU and mul.arg is BinaryOps.MUL and mul.src[1].op is UOps.CONST else 1 X = mul.src[0] if c != 1 else mul # * c is optional + # parse idx.src[1] + e = add1.src[1].arg if (add1:=idx.src[1]).op is UOps.ALU and add1.arg is BinaryOps.ADD and add1.src[1].op is UOps.CONST else 0 + div = add1.src[0] if e else add1 + m_ = div.src[1].arg if div.op is UOps.ALU and div.arg is BinaryOps.IDIV and div.src[1].op is UOps.CONST else None + if m_ != m or div.src[0] != add: return None - lower, upper = X.vmin, X.vmax + # from valid, find the bound of X drop_stmt = [] + if X in bounds and (b0:=bounds[X][0]) is not None: + lower = b0[0] + drop_stmt.append(b0[1]) + else: lower = X.vmin + if X in bounds and (b1:=bounds[X][1]) is not None: + upper = b1[0] + drop_stmt.append(b1[1]) + else: upper = X.vmax - for stmt in _get_chain(valid, BinaryOps.AND): - if stmt.op is UOps.ALU and stmt.arg is BinaryOps.CMPLT and stmt.src[1].op is UOps.CONST: - if stmt.src[0].key == X.key: # X < c - upper = stmt.src[1].arg-1 - drop_stmt.append(stmt) - elif stmt.src[0].key == (-X).key: # -X < -c -> X > c - lower = -stmt.src[1].arg+1 - drop_stmt.append(stmt) - - # valid is an empty set - if upper < lower: return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, valid.const_like(False))) - + # If the contraints in valid implies that it "spans" the whole row, and we can rewrite it to X*c+k for some k, and drop the valid. new_indx0, new_indx1 = None, None if (L:=(lower * c + d)) // m == (U:=(upper * c + d)) // m: # in the same row if (L % m - c < 0) and (U % m + c >= m): # spans the whole row new_indx0 = graph_rewrite(mul - ((L // m) * m - d), constant_folder) - - # Because (X * c + d) % m spans the whole row, (X * c + d) // m has a fixed value. - # check if idx1 is a div that can be simplified. idx1 = (add // m + e) - e = add1.src[1].arg if (add1:=idx.src[1]).op is UOps.ALU and add1.arg is BinaryOps.ADD and add1.src[1].op is UOps.CONST else 0 - div = add1.src[0] if e else add1 - m_ = div.src[1].arg if div.op is UOps.ALU and div.arg is BinaryOps.IDIV and div.src[1].op is UOps.CONST else None - if m_ == m and div.src[0] == add: new_indx1 = idx.src[1].const_like(L // m + e) + new_indx1 = idx.src[1].const_like(L // m + e) if new_indx0 and new_indx1: new_idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (new_indx0, new_indx1))