Skip to content

Commit

Permalink
don't raise ValueError in uop_given_valid [pr] (tinygrad#7163)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyuxyz authored Oct 20, 2024
1 parent 98de582 commit a9ab7db
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions tinygrad/codegen/uopgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
# first, parse valid into {expr: (lower_bound, upper_bound)}
bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
for stmt in split_uop(valid, BinaryOps.AND):
expr, is_upper, c = parse_valid(stmt)
try: expr, is_upper, c = parse_valid(stmt)
except ValueError: return uop # give up if we cannot parse the valid
bounds[expr][int(is_upper)] = c

# simplify uop given that valid is True
Expand Down Expand Up @@ -135,19 +136,15 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
def simplify_valid(valid:UOp) -> Optional[UOp]:
ret:List[UOp] = []
something_changed = False
try:
for stmt in split_uop(valid, BinaryOps.AND):
ret.append(stmt if not ret else uop_given_valid(functools.reduce(operator.and_, ret), stmt))
if ret[-1] is not stmt: something_changed = True
return functools.reduce(operator.and_, ret) if something_changed else None
except ValueError: return None
for stmt in split_uop(valid, BinaryOps.AND):
ret.append(stmt if not ret else uop_given_valid(functools.reduce(operator.and_, ret), stmt))
if ret[-1] is not stmt: something_changed = True
return functools.reduce(operator.and_, ret) if something_changed else None

def simplify_buffer_load(load:UOp) -> Optional[UOp]:
if not isinstance(load.src[0].dtype, PtrDType) or len(load.src) != 4: return None
buf, start_idx, invalid_val, valid = load.src
try:
if (idx:=uop_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
except ValueError: return None
if (idx:=uop_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
return None if idx is start_idx else load.replace(src=((buf, idx, invalid_val, valid)))

def simplify_image_load(load:UOp) -> Optional[UOp]:
Expand Down

0 comments on commit a9ab7db

Please sign in to comment.