update with the latest from tinygrad
geohot committed Nov 24, 2023
1 parent 7659ede commit ac4e379
Showing 6 changed files with 197 additions and 111 deletions.
4 changes: 4 additions & 0 deletions teenygrad/helpers.py
@@ -9,6 +9,7 @@ def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (
def flatten(l:Iterator): return [item for sublist in l for item in sublist]
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def all_int(t: Tuple[Any, ...]) -> Tuple[int, ...]: return all(isinstance(s, int) for s in t)
def round_up(num, amt:int): return (num+amt-1)//amt * amt

@functools.lru_cache(maxsize=None)
def getenv(key, default=0): return type(default)(os.getenv(key, default))
@@ -25,10 +26,13 @@ class DType(NamedTuple):
def __repr__(self): return f"dtypes.{self.name}"

class dtypes:
@staticmethod
def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
@staticmethod
def is_float(x: DType) -> bool: return x == dtypes.float32
float32: Final[DType] = DType(4, 4, "float", np.float32)
int32: Final[DType] = DType(2, 1, "int32", np.int32)
bool: Final[DType] = DType(0, 1, "bool", np.bool_)
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}

ImageDType, IMAGE = None, None # junk to remove
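
For illustration (not part of the commit), a minimal sketch of how the new round_up helper and the dtypes lookup behave, assuming teenygrad.helpers is importable as shown in the diff:

import numpy as np
from teenygrad.helpers import round_up, dtypes

# round_up pads a value to the next multiple of amt
print(round_up(10, 4))                  # 12
print(round_up(12, 4))                  # 12 (already aligned)

# from_np resolves a NumPy dtype name through DTYPES_DICT
print(dtypes.from_np(np.float32))       # dtypes.float
print(dtypes.is_float(dtypes.float32))  # True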
11 changes: 8 additions & 3 deletions teenygrad/lazy.py
@@ -4,21 +4,26 @@
from teenygrad.ops import UnaryOps, BinaryOps, ReduceOps, TernaryOps, LoadOps
import numpy as np

class RawCPUBuffer:
def __init__(self, x): self.x = x
def toCPU(self): return self.x

class LazyBuffer:
device = "CPU"
dtype = dtypes.float32
realized = None

def __init__(self, buf): self._np = buf

@property
def realized(self): return RawCPUBuffer(self._np)
@property
def shape(self): return self._np.shape

def realize(x): return x
def schedule(self, seen=None): return []
def is_unrealized_const(self): return False

@staticmethod
def fromCPU(x): return LazyBuffer(x)
def toCPU(self): return self._np

@staticmethod
def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
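
A rough usage sketch (again not from the commit) of the reworked LazyBuffer: the NumPy array is stored in _np and surfaced through a RawCPUBuffer by the realized property:

import numpy as np
from teenygrad.lazy import LazyBuffer

buf = LazyBuffer.fromCPU(np.arange(4, dtype=np.float32))
print(buf.shape)             # (4,) -- forwarded from the underlying ndarray
print(buf.realized.toCPU())  # [0. 1. 2. 3.] -- RawCPUBuffer hands back the same array
print(buf.schedule())        # [] -- nothing to schedule, the backend is eager NumPy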
46 changes: 23 additions & 23 deletions teenygrad/mlops.py
@@ -17,10 +17,10 @@ def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.con
class Cast(Function):
def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
self.input_dtype, self.bitcast = x.dtype, bitcast
return x.e(UnaryOps.CAST, arg=(dtype, bitcast))
return x.cast(dtype, bitcast)

def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.e(UnaryOps.CAST, arg=(self.input_dtype, self.bitcast))
return grad_output.cast(self.input_dtype, self.bitcast)

# ************* unary ops *************

@@ -84,27 +84,6 @@ def forward(self, x:LazyBuffer) -> LazyBuffer:
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.SUB, self.ret)).e(BinaryOps.MUL, grad_output)

# ************* reduce ops *************

class Sum(Function):
def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
self.input_shape = x.shape
return x.r(ReduceOps.SUM, new_shape)

def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.expand(self.input_shape)

class Max(Function):
def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape)
return self.ret

def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
# 1s in locations where the max was chosen (can be two locations)
max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))

# ************* binary ops *************

class Less(Function):
@@ -157,6 +136,27 @@ def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer],
self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None

# ************* reduce ops *************

class Sum(Function):
def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
self.input_shape = x.shape
return x.r(ReduceOps.SUM, new_shape)

def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.expand(self.input_shape)

class Max(Function):
def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape)
return self.ret

def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
# 1s in locations where the max was chosen (can be two locations)
max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))

# ************* movement ops *************

# NOTE: this is sum in reverse
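
The relocated Max.backward splits the incoming gradient evenly across tied maxima; a standalone NumPy sketch of that arithmetic (an illustration, not teenygrad code) works out like this:

import numpy as np

x = np.array([[1., 3., 3.],
              [2., 2., 0.]])
ret = x.max(axis=1, keepdims=True)          # forward: reduced max, shape (2, 1)
grad_output = np.ones_like(ret)             # stand-in upstream gradient

max_is_1s = 1.0 - (x < ret)                 # 1s where the max was chosen (CMPLT, then 1 - result)
div = max_is_1s.sum(axis=1, keepdims=True)  # how many positions tied for the max
grad_x = max_is_1s / div * grad_output      # each tied position gets an equal share

print(grad_x)
# [[0.  0.5 0.5]
#  [0.5 0.5 0. ]]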
8 changes: 4 additions & 4 deletions teenygrad/nn/optim.py
@@ -10,17 +10,17 @@ def __init__(self, params: List[Tensor], lr: float):
if x.requires_grad is None: x.requires_grad = True

self.params: List[Tensor] = dedup([x for x in params if x.requires_grad])
assert len(self.params) != 0, "optimizer must have at least one param"
self.device = self.params[0].device
self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
self.lr = Tensor([lr], requires_grad=False).contiguous()
self.lr = Tensor([lr], requires_grad=False, device=self.device).contiguous()

def zero_grad(self):
for param in self.params: param.grad = None

def realize(self, extra=None):
# TODO: corealize
# NOTE: in extra is too late for most of the params due to issues with assign
for p in extra + self.params + self.buffers if extra is not None else self.params + self.buffers:
p.realize()
Tensor.corealize(extra + self.params + self.buffers if extra is not None else self.params + self.buffers)

class SGD(Optimizer):
def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, nesterov=False):
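
For orientation, a rough end-to-end sketch (not from the commit; it assumes teenygrad's Tensor and SGD mirror tinygrad's API at the time, including SGD.step):

from teenygrad.tensor import Tensor
from teenygrad.nn.optim import SGD

w = Tensor([[1.0, 2.0], [3.0, 4.0]])  # requires_grad is None, so the optimizer flips it to True
opt = SGD([w], lr=0.01)

loss = (w * w).sum()
loss.backward()
opt.step()        # after this commit, realization goes through a single Tensor.corealize call
opt.zero_grad()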
1 change: 1 addition & 0 deletions teenygrad/realize.py
@@ -0,0 +1 @@
def run_schedule(schedule, disable_logging=False): pass