Skip to content

Commit

Permalink
split validhack into simplify idx and drop valids (tinygrad#6719)
Browse files Browse the repository at this point in the history
* split validhack into simplify idx and drop valids

the idx simplification will also be used for non-image buffers
[run_process_replay]

* shorter
  • Loading branch information
chenyuxyz authored Sep 24, 2024
1 parent cefc3e9 commit 5c240c3
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions tinygrad/codegen/uopgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,8 @@ def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
if valid.op is UOps.ALU and valid.arg is BinaryOps.CMPLT and valid.src[1].op is UOps.CONST: return valid.src[0], True, valid.src[1].arg-1
raise ValueError(f"not able to parse {valid=}")

def simplify_valid_image_load(load:UOp, buf:UOp):
if not isinstance(buf_dtype:=buf.dtype, ImageDType) or len(load.src) < 4: return None
buf, idx, invalid_val, valid = load.src
start_idx = idx
def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
# return None if valid is always False, otherwise the simplified idx (might be the same as input)

# first, parse valid into {expr: (lower_bound, upper_bound)}
bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
Expand All @@ -212,9 +210,8 @@ def simplify_valid_image_load(load:UOp, buf:UOp):

# simplify idx given that valid is True
for uop,v in bounds.items():
# some expr has lower bound > upper bound -> valid is an empty set and we return early
if v[0] is not None and v[1] is not None and v[0] > v[1]:
return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, valid.const_like(False)))
# some expr has lower bound > upper bound -> valid is an empty set and we return None
if v[0] is not None and v[1] is not None and v[0] > v[1]: return None

# every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the idx into the same output, we rewrite idx
candidates = []
Expand All @@ -234,6 +231,12 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
# if every branch in candidate gives the same simplified output, we can rewrite the idx
if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1]))
if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0]))
return idx

def simplify_valid_image_load(load:UOp, buf:UOp):
if not isinstance(buf_dtype:=buf.dtype, ImageDType) or len(load.src) < 4: return None
buf, start_idx, invalid_val, valid = load.src
if (idx:=idx_given_valid(valid, start_idx)) is None: return UOp(UOps.LOAD, load.dtype, (buf, start_idx, invalid_val, valid.const_like(False)))

# can drop valid if idx is out of bound when valid is False
drop_stmt = []
Expand All @@ -259,10 +262,9 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
drop_stmt.append(stmt)
break

if drop_stmt or idx.key != start_idx.key:
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None
return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, new_valid)) if new_valid else UOp(UOps.LOAD, load.dtype, (buf, idx))
return None
if not drop_stmt and idx.key == start_idx.key: return None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None
return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, new_valid)) if new_valid else UOp(UOps.LOAD, load.dtype, (buf, idx))

# ***** transcendental *****

Expand Down

0 comments on commit 5c240c3

Please sign in to comment.