move image into tensor.py. delete features (tinygrad#4603)
* move image into tensor.py

* change setup.py

* openpilot tests need pythonpath now
geohot authored May 15, 2024
1 parent 36d2ac6 commit 5ba6117
Showing 12 changed files with 106 additions and 113 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/test.yml
@@ -167,17 +167,17 @@ jobs:
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test openpilot model compile and size
       run: |
-        DEBUG=2 ALLOWED_KERNEL_COUNT=208 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
-        #python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
+        PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
+        python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test openpilot model correctness (float32)
-      run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
+      run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test openpilot alt model correctness (float32)
-      run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
+      run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test openpilot fastvits model correctness (float32)
-      run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
+      run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
     - if: ${{ matrix.task == 'onnx' }}
       name: Test ONNX (GPU)
       run: GPU=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
2 changes: 1 addition & 1 deletion examples/hlb_cifar10.py
@@ -12,7 +12,7 @@
 from tinygrad.nn.state import get_state_dict, get_parameters
 from tinygrad.nn import optim
 from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod
-from tinygrad.features.multi import MultiLazyBuffer
+from tinygrad.multi import MultiLazyBuffer

 BS, STEPS = getenv("BS", 512), getenv("STEPS", 1000)
 EVAL_BS = getenv("EVAL_BS", BS)
5 changes: 3 additions & 2 deletions setup.py
@@ -15,7 +15,7 @@
   long_description=long_description,
   long_description_content_type='text/markdown',
   packages = ['tinygrad', 'tinygrad.runtime.autogen', 'tinygrad.codegen', 'tinygrad.nn', 'tinygrad.renderer', 'tinygrad.engine',
-              'tinygrad.runtime', 'tinygrad.runtime.driver', 'tinygrad.runtime.graph', 'tinygrad.shape', 'tinygrad.features'],
+              'tinygrad.runtime', 'tinygrad.runtime.driver', 'tinygrad.runtime.graph', 'tinygrad.shape'],
   classifiers=[
     "Programming Language :: Python :: 3",
     "License :: OSI Approved :: MIT License"
@@ -59,7 +59,8 @@
       "mkdocs-material",
       "mkdocstrings[python]",
       "markdown-callouts",
-      "markdown-exec[ansi]"
+      "markdown-exec[ansi]",
+      "black"
     ],
     'testing_tf': [
       "tensorflow==2.15.1",
2 changes: 1 addition & 1 deletion test/external/external_benchmark_multitensor_allreduce.py
@@ -2,7 +2,7 @@
 from tinygrad import Tensor, Device, GlobalCounters, TinyJit
 from tinygrad.lazy import LazyBuffer
 from tinygrad.ops import ReduceOps
-from tinygrad.features.multi import MultiLazyBuffer, all_reduce
+from tinygrad.multi import MultiLazyBuffer, all_reduce
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
 from tinygrad.helpers import getenv, Context, RING
2 changes: 1 addition & 1 deletion test/test_image_dtype.py
@@ -2,7 +2,7 @@
 import numpy as np
 from tinygrad import Device, dtypes, Tensor, Variable
 from tinygrad.dtype import ImageDType
-from tinygrad.features.image import to_image_idx
+from tinygrad.codegen.linearizer import to_image_idx

 @unittest.skipIf(Device.DEFAULT != "GPU", "only images on GPU")
 class TestImageDType(unittest.TestCase):
2 changes: 1 addition & 1 deletion test/test_multitensor.py
@@ -6,7 +6,7 @@
 from tinygrad.nn.state import get_parameters, get_state_dict
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner
-from tinygrad.features.multi import all_reduce, MultiLazyBuffer
+from tinygrad.multi import all_reduce, MultiLazyBuffer
 from random import randint
 import numpy as np
 from hypothesis import given, strategies as strat, settings
7 changes: 6 additions & 1 deletion tinygrad/codegen/linearizer.py
@@ -9,7 +9,6 @@
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node
 from tinygrad.codegen.kernel import LocalBuffer, Kernel
-from tinygrad.features.image import to_image_idx
 from tinygrad.renderer import Program

 from tinygrad.codegen.uops import UOps, UOp, UOpGraph
@@ -32,6 +31,12 @@ def expand_idxs(nodes:Sequence[Node]) -> Tuple[Union[Variable, NumNode], ...]:
 def iter_idxs(idxs:Tuple[Union[Variable, NumNode], ...]) -> Iterator[Tuple[int,...]]:
   yield from (x[::-1] for x in itertools.product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]]))

+def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
+  idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
+  # TODO: bring back the valid removal logic (correct!)
+  if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
+  return (idx, idy), valid
+
 # expand a Node into List[Node] that enumerates the underlying Variables from min to max
 # expand increments earlier variables faster than later variables (as specified in the argument)
 @functools.lru_cache(maxsize=None)
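The coordinate math in the relocated to_image_idx is easier to see with plain integers: an image row holds base_shape[1] texels of 4 floats each, so a flat float offset splits into a texel column and a texel row. A minimal standalone sketch of that arithmetic (hypothetical helper name, not part of this commit):

def flat_to_image_idx(base_shape, idxy):
  w = base_shape[1]        # texels per image row
  idx = (idxy // 4) % w    # x: which texel within the row
  idy = idxy // (4 * w)    # y: which row
  return idx, idy

assert flat_to_image_idx((8, 16), 0) == (0, 0)    # first float lands in texel (0, 0)
assert flat_to_image_idx((8, 16), 4) == (1, 0)    # floats pack 4 per texel
assert flat_to_image_idx((8, 16), 64) == (0, 1)   # 16 texels * 4 floats wraps to the next row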
93 changes: 0 additions & 93 deletions tinygrad/features/image.py

This file was deleted.

4 changes: 2 additions & 2 deletions tinygrad/lazy.py
@@ -153,8 +153,8 @@ def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuf
     if op is UnaryOps.NEG and self.base.op is UnaryOps.NEG: return self.base.srcs[0]
     if op in BinaryOps: x, y = self, in_srcs[0]
     if op is BinaryOps.ADD:
-      if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
-      if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
+      if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x  # pylint: disable=possibly-used-before-assignment
+      if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y  # pylint: disable=possibly-used-before-assignment
     if op is BinaryOps.SUB and y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
     if op is BinaryOps.MUL:
       if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0, -1):
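For readers skimming the pragma change: these lines sit on tinygrad's constant-folding fast path, where e() short-circuits binary ops whose result is forced by an unrealized, unmasked constant operand. A toy stand-in for the ADD rule (plain values instead of LazyBuffer graphs, purely illustrative):

def fold_add(x, y):
  # x + 0 -> x and 0 + y -> y, mirroring the is_unrealized_unmasked_const checks
  if isinstance(y, (int, float)) and y == 0: return x
  if isinstance(x, (int, float)) and x == 0: return y
  return ("ADD", x, y)

assert fold_add("buf", 0) == "buf"
assert fold_add(0, "buf") == "buf"
assert fold_add("a", "b") == ("ADD", "a", "b")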
2 changes: 1 addition & 1 deletion tinygrad/features/multi.py → tinygrad/multi.py
@@ -85,7 +85,7 @@ def copy_to_device(self, device:str) -> LazyBuffer:
     return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)

   # passthroughs
-  def is_realized(self) -> bool: return all([lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True])
+  def is_realized(self) -> bool: return all(lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True)
   def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
   def const(self, val:ConstType) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real)
   def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
2 changes: 1 addition & 1 deletion tinygrad/nn/state.py
@@ -5,7 +5,7 @@
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters
 from tinygrad.shape.view import strides_for_shape
-from tinygrad.features.multi import MultiLazyBuffer
+from tinygrad.multi import MultiLazyBuffer

 safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dtypes.int16, "U16":dtypes.uint16, "I32":dtypes.int, "U32":dtypes.uint,
                "I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
88 changes: 84 additions & 4 deletions tinygrad/tensor.py
@@ -10,7 +10,7 @@
 from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, argsort, IMAGE, DEBUG, WINO, THREEFRY
 from tinygrad.helpers import getenv
 from tinygrad.lazy import LazyBuffer
-from tinygrad.features.multi import MultiLazyBuffer
+from tinygrad.multi import MultiLazyBuffer
 from tinygrad.ops import LoadOps
 from tinygrad.device import Buffer, BufferOptions
 from tinygrad.device import Device
@@ -1374,14 +1374,94 @@ def nbytes(self) -> int: return self.numel() * self.element_size()
   def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
   def size(self, dim=None) -> Union[sint, Tuple[sint, ...]]: return self.shape if dim is None else self.shape[dim]

+  # *** image Tensor function replacements ***
+
+  def image_dot(self, w:Tensor, acc_dtype=None):
+    # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
+    n1, n2 = len(self.shape), len(w.shape)
+    assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
+    assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"  # noqa: E501
+    bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1]
+    out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, )
+
+    # NOTE: with NHWC we can remove the transposes
+    # bs x groups*cin x H x W
+    cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
+    # groups*cout x cin x H, W
+    cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
+    return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
+
+  def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None):
+    base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
+
+    (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
+    x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
+
+    # hack for non multiples of 4 on cin
+    if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
+      x = x.reshape(bs, groups, cin, iy, ix)   # do this always?
+      added_input_channels = 4 - (cin % 4)
+      w = w.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(w.ndim)))
+      x = x.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(x.ndim)))
+      cin = cin + added_input_channels
+      x = x.reshape(bs, groups*cin, iy, ix)
+
+    # hack for non multiples of 4 on rcout
+    added_output_channels = 0
+    if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
+      added_output_channels = 4 - (rcout % 4)
+      rcout += added_output_channels
+      cout = groups * rcout
+      w = w.pad(tuple((0, added_output_channels) if i == 1 else None for i in range(w.ndim)))
+
+    # packed (note: flipping bs and iy would make the auto-padding work)
+    x = x.permute(0,2,3,1)
+    cin_last = iy == 1 and ix == 1
+    if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
+    elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
+    else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
+
+    # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
+    if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
+    x, w = x.contiguous(), w.contiguous()
+
+    # expand out
+    rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
+    cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
+    x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
+    if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
+    else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
+
+    # padding
+    padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
+    x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
+
+    # prepare input
+    x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation)  # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
+    x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
+
+    # prepare weights
+    w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
+
+    # the conv!
+    ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), acc_dtype=acc_dtype)
+
+    # undo hack for non multiples of 4 on C.rcout
+    if added_output_channels != 0:
+      ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
+      cout = groups * (rcout - added_output_channels)
+
+    # NCHW output
+    ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
+    return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
+
 # register functions to move between devices
 for device in Device._devices: setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))

 if IMAGE:
   # if IMAGE>0 we install these replacement functions in Tensor (hack!)
-  from tinygrad.features.image import image_conv2d, image_dot
-  setattr(Tensor, "conv2d", image_conv2d)
-  setattr(Tensor, "dot", image_dot)
+  setattr(Tensor, "conv2d", Tensor.image_conv2d)
+  setattr(Tensor, "dot", Tensor.image_dot)

 # TODO: eventually remove this
 def custom_random(out:Buffer):
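Two notes on the relocated image code. First, the 1x1-conv trick that image_dot leans on: lay an m-by-k matrix out as a (1, k, m, 1) NCHW feature map and the k-by-n matrix as n filters of shape (k, 1, 1); a 1x1 convolution then dots the k input channels at each of the m spatial positions, i.e. one row of A against one column of B. A minimal numpy sketch checking the equivalence (illustrative shapes, not from this commit):

import numpy as np

m, k, n = 3, 5, 4
A, B = np.random.rand(m, k), np.random.rand(k, n)

x = A.T.reshape(1, k, m, 1)   # NCHW feature map holding A
w = B.T.reshape(n, k, 1, 1)   # n filters of shape (k, 1, 1) holding B
out = np.einsum('bchw,oc->bohw', x, w[:, :, 0, 0])  # a 1x1 convolution

assert np.allclose(out[0, :, :, 0].T, A @ B)  # mxk @ kxn recovered

Second, because the setattr swap above runs at import time when IMAGE is set, ordinary model code picks up the image path with no source changes; running something like IMAGE=2 GPU=1 python examples/openpilot/compile2.py, as the CI jobs above do, is the intended entry point.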
