Skip to content

Commit

Permalink
restructure simplify_valid_image_load [run_process_replay] (tinygrad#…
Browse files Browse the repository at this point in the history
…6581)

* restructure simplify_valid_image_load [run_process_replay]

separated parsing valid / idx and simplification

* space

* type
  • Loading branch information
chenyuxyz authored Sep 18, 2024
1 parent d02bb27 commit 1ec6bd5
Showing 1 changed file with 33 additions and 23 deletions.
56 changes: 33 additions & 23 deletions tinygrad/codegen/uopgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional, Tuple, Dict, List, Set, cast, TYPE_CHECKING, Any, DefaultDict, Callable
import functools, itertools, heapq, math, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, PtrDType, ImageDType
from tinygrad.dtype import dtypes, PtrDType, ImageDType, ConstType
from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, UOps, END_FOR_UOP, type_verify, print_uops, identity_element
from tinygrad.ops import UPat, PatternMatcher, graph_rewrite
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, partition, all_same
Expand Down Expand Up @@ -174,41 +174,51 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s is not stmt]) else None
return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, new_valid)) if new_valid else UOp(UOps.LOAD, load.dtype, (buf, idx))

# We want to simplify expressions like (X*c+d)%m in the idx, with optional *c and +d. m is the total length of the row.
# If the constraints in valid imply that it "spans" the whole row, we can rewrite it to X*c+k for some k and drop the valid.
# first, parse valid into {expr: ((lower_bound, statement), (upper_bound, statement))}
bounds:DefaultDict[UOp, List[Optional[Tuple[ConstType, UOp]]]] = defaultdict(lambda: [None, None])
for stmt in _get_chain(valid, BinaryOps.AND):
if stmt.op is UOps.ALU and stmt.arg is BinaryOps.CMPLT and stmt.src[1].op is UOps.CONST:
if (s:=stmt.src[0]).op is UOps.ALU and s.arg is BinaryOps.MUL and s.src[1].op is UOps.CONST and s.src[1].arg == -1:
bounds[s.src[0]][0] = (-stmt.src[1].arg+1, stmt)
else: bounds[s][1] = (stmt.src[1].arg-1, stmt)

for v in bounds.values():
# some expr has lower bound > upper bound -> valid is an empty set
if v[0] is not None and v[1] is not None and v[0][0] > v[1][0]:
return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, valid.const_like(False)))

# next, parse idx by the form ((X*c+d)%m, ((X*c+d)//m+e))
# parse m
m = mod.src[1].arg if (mod:=idx.src[0]).op is UOps.ALU and mod.arg is BinaryOps.MOD and mod.src[1].op is UOps.CONST else None
if not m or m != buf_dtype.shape[1]: return None
# parse idx.src[0]
d = add.src[1].arg if (add:=mod.src[0]).op is UOps.ALU and add.arg is BinaryOps.ADD and add.src[1].op is UOps.CONST else 0
mul = add.src[0] if d else add # + d is optional
c = mul.src[1].arg if mul.op is UOps.ALU and mul.arg is BinaryOps.MUL and mul.src[1].op is UOps.CONST else 1
X = mul.src[0] if c != 1 else mul # * c is optional
# parse idx.src[1]
e = add1.src[1].arg if (add1:=idx.src[1]).op is UOps.ALU and add1.arg is BinaryOps.ADD and add1.src[1].op is UOps.CONST else 0
div = add1.src[0] if e else add1
m_ = div.src[1].arg if div.op is UOps.ALU and div.arg is BinaryOps.IDIV and div.src[1].op is UOps.CONST else None
if m_ != m or div.src[0] != add: return None

lower, upper = X.vmin, X.vmax
# from valid, find the bound of X
drop_stmt = []
if X in bounds and (b0:=bounds[X][0]) is not None:
lower = b0[0]
drop_stmt.append(b0[1])
else: lower = X.vmin
if X in bounds and (b1:=bounds[X][1]) is not None:
upper = b1[0]
drop_stmt.append(b1[1])
else: upper = X.vmax

for stmt in _get_chain(valid, BinaryOps.AND):
if stmt.op is UOps.ALU and stmt.arg is BinaryOps.CMPLT and stmt.src[1].op is UOps.CONST:
if stmt.src[0].key == X.key: # X < c
upper = stmt.src[1].arg-1
drop_stmt.append(stmt)
elif stmt.src[0].key == (-X).key: # -X < -c -> X > c
lower = -stmt.src[1].arg+1
drop_stmt.append(stmt)

# valid is an empty set
if upper < lower: return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, valid.const_like(False)))

# If the constraints in valid imply that it "spans" the whole row, we can rewrite it to X*c+k for some k and drop the valid.
new_indx0, new_indx1 = None, None
if (L:=(lower * c + d)) // m == (U:=(upper * c + d)) // m: # in the same row
if (L % m - c < 0) and (U % m + c >= m): # spans the whole row
new_indx0 = graph_rewrite(mul - ((L // m) * m - d), constant_folder)

# Because (X * c + d) % m spans the whole row, (X * c + d) // m has a fixed value.
# check if idx1 is a div that can be simplified. idx1 = (add // m + e)
e = add1.src[1].arg if (add1:=idx.src[1]).op is UOps.ALU and add1.arg is BinaryOps.ADD and add1.src[1].op is UOps.CONST else 0
div = add1.src[0] if e else add1
m_ = div.src[1].arg if div.op is UOps.ALU and div.arg is BinaryOps.IDIV and div.src[1].op is UOps.CONST else None
if m_ == m and div.src[0] == add: new_indx1 = idx.src[1].const_like(L // m + e)
new_indx1 = idx.src[1].const_like(L // m + e)

if new_indx0 and new_indx1:
new_idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (new_indx0, new_indx1))
Expand Down

0 comments on commit 1ec6bd5

Please sign in to comment.