Skip to content

Commit

Permalink
use argfix in smax/smin and remove if [pr]
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyuxyz committed Dec 4, 2024
1 parent 4e51833 commit 1470649
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions tinygrad/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass, field
from collections import defaultdict
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix
if TYPE_CHECKING:
from tinygrad.shape.shapetracker import ShapeTracker

Expand Down Expand Up @@ -180,11 +180,10 @@ def resolve(x, default:bool=True):

# smax/smin are replacements for max/min that preserve symbolic
def _suop(lst, uop_fxn, python_fxn):
max_uop, max_num = partition(lst, lambda x: isinstance(x, UOp))
if len(max_uop): return functools.reduce(uop_fxn, (max_uop + [python_fxn(max_num)]) if len(max_num) else max_uop).ssimplify()
return python_fxn(max_num)
def smax(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.maximum, max)
def smin(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.minimum, min)
uops, nums = partition(lst, lambda x: isinstance(x, UOp))
return ssimplify(functools.reduce(uop_fxn, uops + ([python_fxn(nums)] if nums else [])))
def smax(*lst): return _suop(argfix(*lst), UOp.maximum, max)
def smin(*lst): return _suop(argfix(*lst), UOp.minimum, min)

def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
Expand Down

0 comments on commit 1470649

Please sign in to comment.