Commit

Merge branch 'master' into uint-abs-diff

chenyuxyz authored May 13, 2024
2 parents f4d51ad + 25ec40c commit 6ca5af4

Showing 118 changed files with 3,178 additions and 1,839 deletions.
7 changes: 3 additions & 4 deletions .github/workflows/benchmark.yml
@@ -55,9 +55,8 @@ jobs:
JIT=1 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
- name: Run GPT2 w HALF
run: JIT=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
# TODO: this is flaky
# - name: Run GPT2 w HALF/BEAM
# run: JIT=0 HALF=1 BEAM=2 CACHELEVEL=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
- name: Run GPT2 w HALF/BEAM
run: JIT=1 HALF=1 BEAM=2 CACHELEVEL=0 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
- name: Train MNIST
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=97.3 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
- name: Run 10 CIFAR training steps
@@ -142,7 +141,7 @@ jobs:
- name: Run GPT2 w HALF
run: CUDA=1 JIT=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
- name: Run GPT2 w HALF/BEAM
run: CUDA=1 JIT=1 HALF=1 BEAM=2 CACHELEVEL=0 JIT_BATCH_SIZE=4 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
run: CUDA=1 JIT=1 HALF=1 BEAM=2 CACHELEVEL=0 CAST_BEFORE_VIEW=0 JIT_BATCH_SIZE=4 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
- name: Train MNIST
run: time PYTHONPATH=. CUDA=1 TARGET_EVAL_ACC_PCT=97.3 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
- name: Run 10 CIFAR training steps
39 changes: 6 additions & 33 deletions .github/workflows/test.yml
@@ -115,39 +115,6 @@ jobs:
- name: Repo line count <8000 lines
run: MAX_LINE_COUNT=8000 python sz.py

testcpuimagenet:
name: ImageNet to C Tests
runs-on: ubuntu-latest
timeout-minutes: 20

steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Set up Python 3.8
uses: actions/setup-python@v5
with:
python-version: 3.8
- name: Cache python packages
uses: actions/cache@v4
with:
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.8/site-packages
key: testing-packages-${{ hashFiles('**/setup.py') }}
- name: Cache downloads
uses: actions/cache@v4
with:
path: ~/.cache/tinygrad/downloads/
key: downloads-cache-cpu-${{ env.DOWNLOAD_CACHE_VERSION }}
- name: Install Dependencies
run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
#- name: Run Pytest
# run: python -m pytest -n=auto test/ -k "not (test_efficientnet and models/test_train.py)" --durations=20
- name: Compile EfficientNet to C
run: PYTHONPATH="." CLANG=1 python examples/compile_efficientnet.py > recognize.c
- name: Compile C to native
run: clang -O2 recognize.c -lm -o recognize
- name: Test EfficientNet
run: cat test/models/efficientnet/Chicken.jpg | ./recognize | grep cock

testopencl:
strategy:
fail-fast: false
@@ -487,6 +454,12 @@ jobs:
- name: Run pytest (hip)
if: matrix.backend=='hip'
run: python -m pytest -n=auto test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/test_randomness.py test/imported/test_indexing.py test/external/external_test_hip_compile.py --durations=20
- name: Compile EfficientNet to C and test it
if: matrix.backend=='clang'
run: |
PYTHONPATH="." python examples/compile_efficientnet.py > recognize.c
clang -O2 recognize.c -lm -o recognize
cat test/models/efficientnet/Chicken.jpg | ./recognize | grep cock
#testunicorn:
# name: ARM64 unicorn Test
7 changes: 4 additions & 3 deletions docs-legacy/abstractions2.py
@@ -55,12 +55,13 @@
st_0 = LazyOp(BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,))))

# convert the computation to a "linearized" format (print the format)
lin = Device[DEVICE].get_linearizer(st_0).linearize()
from tinygrad.engine.realize import get_linearizer, CompiledRunner
lin = get_linearizer(Device[DEVICE].renderer, (st_0,)).linearize()
for u in lin.uops: print(u)

# compile a program (and print the source)
fxn = Device[DEVICE].to_program(lin)
print(fxn.prg)
fxn = CompiledRunner(lin.to_program())
print(fxn.p.src)
# NOTE: fxn.clprg is the ClangProgram

# run the program
2 changes: 1 addition & 1 deletion docs-legacy/abstractions3.py
@@ -38,7 +38,7 @@ def model(x): return x.flatten(1).dot(l1.T).relu().dot(l2.T)
# The weight Tensors have been assigned to, but not yet realized. Everything is still lazy at this point
# l1.lazydata and l2.lazydata define a computation graph

from tinygrad.ops import ScheduleItem
from tinygrad.engine.schedule import ScheduleItem
schedule: List[ScheduleItem] = Tensor.schedule(l1, l2)

print(f"The schedule contains {len(schedule)} items.")
12 changes: 5 additions & 7 deletions docs/developer.md
@@ -17,23 +17,21 @@ The `LazyBuffer` graph specifies the compute in terms of low level tinygrad ops.

## Scheduling

The [scheduler](/tinygrad/engine/schedule.py) converts the graph of LazyBuffers into a list of `ScheduleItem`. One `ScheduleItem` is one kernel on the GPU, and the scheduler is responsible for breaking the large compute graph into subgraphs that can fit in a kernel. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on.
The [scheduler](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/schedule.py) converts the graph of LazyBuffers into a list of `ScheduleItem`. One `ScheduleItem` is one kernel on the GPU, and the scheduler is responsible for breaking the large compute graph into subgraphs that can fit in a kernel. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on.

::: tinygrad.ops.ScheduleItem
::: tinygrad.engine.schedule.ScheduleItem
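As a rough sketch of how this looks from user code — only `Tensor.schedule` (shown in `docs-legacy/abstractions3.py` in this commit) and the `ast`/`bufs` fields described above are assumed, everything else is illustrative:

```python
from tinygrad import Tensor

x, w = Tensor.randn(4, 16), Tensor.randn(16, 8)
out = x.dot(w).relu().sum()         # builds a lazy graph; nothing has run yet
schedule = Tensor.schedule(out)     # List[ScheduleItem], roughly one kernel each
for si in schedule:
  print(si.ast, si.bufs)            # ast: the compute to run, bufs: the buffers it runs on
```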

## Lowering

The code in [realize](/tinygrad/engine/realize.py) lowers `ScheduleItem` to `ExecItem` with
The code in [realize](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/realize.py) lowers `ScheduleItem` to `ExecItem` with

::: tinygrad.engine.realize.lower_schedule

There's a ton of complexity hidden behind this, see the `codegen/` directory.

First we lower the AST to UOps, which is a linear list of the compute to be run. This is where the BEAM search happens. The UOps can be changed by `CompilerOptions`.
First we lower the AST to UOps, which is a linear list of the compute to be run. This is where the BEAM search happens.

::: tinygrad.device.CompilerOptions

Then we render the UOps into code, then we compile the code to binary.
Then we render the UOps into code with a `Renderer`, then we compile the code to binary with a `Compiler`.
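Strung together by hand, the schedule-to-binary path looks roughly like the sketch below. The helper names (`get_linearizer`, `CompiledRunner`, `to_program`, `fxn.p.src`) are taken from `docs-legacy/abstractions2.py` in this same commit; treat this as an illustration of what `lower_schedule` automates, not a supported recipe:

```python
from tinygrad import Tensor, Device
from tinygrad.ops import LoadOps
from tinygrad.engine.realize import get_linearizer, CompiledRunner

out = (Tensor.empty(4, 4) + 1).sum()
# keep only the compute kernels, dropping pure load/copy schedule items
sched = [si for si in Tensor.schedule(out) if si.ast[0].op not in LoadOps]
lin = get_linearizer(Device[Device.DEFAULT].renderer, sched[-1].ast).linearize()  # AST -> linear UOps
fxn = CompiledRunner(lin.to_program())  # render the UOps to source, then compile to binary
print(fxn.p.src)                        # the generated kernel source
```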

## Execution

26 changes: 13 additions & 13 deletions docs/quickstart.md
@@ -50,7 +50,7 @@ randn = Tensor.randn(2, 3) # create a tensor of shape (2, 3) filled with random
uniform = Tensor.uniform(2, 3, low=0, high=10) # create a tensor of shape (2, 3) filled with random values from a uniform distribution between 0 and 10
```

There are even more of these factory methods, you can find them in the [tensor.py](/tinygrad/tensor.py) file.
There are even more of these factory methods, you can find them in the [Tensor](tensor.md) file.

All the tensor creation methods can take a `dtype` argument to specify the data type of the tensor.

@@ -75,8 +75,8 @@ print(t6.numpy())
# [-56. -48. -36. -20. 0.]
```

There are a lot more operations that can be performed on tensors, you can find them in the [tensor.py](/tinygrad/tensor.py) file.
Additionally reading through [abstractions2.py](/docs-legacy/abstractions2.py) will help you understand how operations on these tensors make their way down to your hardware.
There are a lot more operations that can be performed on tensors, you can find them in the [Tensor](tensor.md) file.
Additionally reading through [abstractions2.py](https://github.com/tinygrad/tinygrad/blob/master/docs-legacy/abstractions2.py) will help you understand how operations on these tensors make their way down to your hardware.

## Models

@@ -96,7 +96,7 @@ class Linear:
return x.linear(self.weight.transpose(), self.bias)
```

There are more neural network modules already implemented in [nn](/tinygrad/nn/__init__.py), and you can also implement your own.
There are more neural network modules already implemented in [nn](nn.md), and you can also implement your own.

We will be implementing a simple neural network that can classify handwritten digits from the MNIST dataset.
Our classifier will be a simple 2 layer neural network with a Leaky ReLU activation function.
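The actual definition lives in the collapsed part of this diff; one plausible shape for it, reusing the `Linear` module from above (class name, layer sizes, and the `leakyrelu` call are illustrative assumptions):

```python
class TinyNet:
  def __init__(self):
    # hypothetical sizes: 28x28 MNIST images flattened to 784, 10 output classes
    self.l1 = Linear(784, 128)
    self.l2 = Linear(128, 10)

  def __call__(self, x):
    return self.l2(self.l1(x).leakyrelu())
```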
@@ -126,9 +126,9 @@ Finally, we just initialize an instance of our neural network, and we are ready
Now that we have our neural network defined we can start training it.
Training neural networks in tinygrad is super simple.
All we need to do is define our neural network, define our loss function, and then call `.backward()` on the loss function to compute the gradients.
They can then be used to update the parameters of our neural network using one of the many optimizers in [optim.py](/tinygrad/nn/optim.py).
They can then be used to update the parameters of our neural network using one of the many [Optimizers](nn.md#optimizers).
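Pieced together from APIs that appear elsewhere in this commit (`nn.optim.SGD`, `nn.state.get_parameters`, `Tensor.training`, `zero_grad`, `backward`), a single update step looks roughly like this; `opt.step()` and the learning rate are assumptions:

```python
from tinygrad import Tensor, nn

opt = nn.optim.SGD(nn.state.get_parameters(net), lr=3e-4)  # `net` is the model defined above
Tensor.training = True
loss = net(x).sparse_categorical_crossentropy(y)  # hypothetical batch `x` with integer labels `y`
opt.zero_grad()
loss.backward()   # populate gradients for every parameter in the optimizer
opt.step()        # apply them
```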

For our loss function we will be using sparse categorical cross entropy loss. The implementation below is taken from [tensor.py](/tinygrad/tensor.py), it's copied below to highlight an important detail of tinygrad.
For our loss function we will be using sparse categorical cross entropy loss. The implementation below is taken from [tensor.py](https://github.com/tinygrad/tinygrad/blob/master/tinygrad/tensor.py), it's copied below to highlight an important detail of tinygrad.

```python
def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
@@ -156,7 +156,7 @@ There is a simpler way to do this just by using `get_parameters(net)` from `tiny
The parameters are just listed out explicitly here for clarity.

Now that we have our network, loss function, and optimizer defined all we are missing is the data to train on!
There are a couple of dataset loaders in tinygrad located in [/extra/datasets](/extra/datasets).
There are a couple of dataset loaders in tinygrad located in [/extra/datasets](https://github.com/tinygrad/tinygrad/blob/master/extra/datasets).
We will be using the MNIST dataset loader.

```python
@@ -229,10 +229,10 @@ with Timing("Time: "):

## And that's it

Highly recommend you check out the [examples/](/examples) folder for more examples of using tinygrad.
Highly recommend you check out the [examples/](https://github.com/tinygrad/tinygrad/blob/master/examples) folder for more examples of using tinygrad.
Reading the source code of tinygrad is also a great way to learn how it works.
Specifically the tests in [test/](/test) are a great place to see how to use and the semantics of the different operations.
There are also a bunch of models implemented in [models/](/extra/models) that you can use as a reference.
Specifically the tests in [test/](https://github.com/tinygrad/tinygrad/blob/master/test) are a great place to see how to use and the semantics of the different operations.
There are also a bunch of models implemented in [models/](https://github.com/tinygrad/tinygrad/blob/master/extra/models) that you can use as a reference.

Additionally, feel free to ask questions in the `#learn-tinygrad` channel on the [discord](https://discord.gg/beYbxwxVdx). Don't ask to ask, just ask!

@@ -276,7 +276,7 @@ You will find that the evaluation time is much faster than before and that your
### Saving and Loading Models

The standard weight format for tinygrad is [safetensors](https://github.com/huggingface/safetensors). This means that you can load the weights of any model also using safetensors into tinygrad.
There are functions in [state.py](/tinygrad/nn/state.py) to save and load models to and from this format.
There are functions in [state.py](https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/state.py) to save and load models to and from this format.

```python
from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict
@@ -292,14 +292,14 @@ state_dict = safe_load("model.safetensors")
load_state_dict(net, state_dict)
```

Many of the models in the [models/](/models) folder have a `load_from_pretrained` method that will download and load the weights for you. These usually are pytorch weights meaning that you would need pytorch installed to load them.
Many of the models in the [models/](https://github.com/tinygrad/tinygrad/tree/master/extra/models) folder have a `load_from_pretrained` method that will download and load the weights for you. These usually are pytorch weights meaning that you would need pytorch installed to load them.

### Environment Variables

There exist a bunch of environment variables that control the runtime behavior of tinygrad.
Some of the common ones are `DEBUG` and the different backend enablement variables.

You can find a full list and their descriptions in [env_vars.md](/docs-legacy/env_vars.md).
You can find a full list and their descriptions in [env_vars.md](https://github.com/tinygrad/tinygrad/blob/master/docs-legacy/env_vars.md).
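For example, an illustrative sketch (not from the docs) of turning on `DEBUG` from Python rather than the shell — it assumes setting the variable before importing tinygrad is enough for the flag to take effect:

```python
import os
os.environ["DEBUG"] = "2"   # set before importing tinygrad

from tinygrad import Tensor
(Tensor.ones(2, 2) + 1).sum().realize()  # with DEBUG>=2, per-kernel info is printed
```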

### Visualizing the Computation Graph

12 changes: 6 additions & 6 deletions docs/showcase.md
@@ -17,15 +17,15 @@ python3 examples/efficientnet.py webcam

### YOLOv8

Take a look at [yolov8.py](/examples/yolov8.py).
Take a look at [yolov8.py](https://github.com/tinygrad/tinygrad/tree/master/examples/yolov8.py).

![yolov8 by tinygrad](showcase/yolov8_showcase_image.png)
![yolov8 by tinygrad](https://github.com/tinygrad/tinygrad/tree/master/docs/showcase/yolov8_showcase_image.png)

## Audio

### Whisper

Take a look at [whisper.py](/examples/whisper.py). You need pyaudio and torchaudio installed.
Take a look at [whisper.py](https://github.com/tinygrad/tinygrad/tree/master/examples/whisper.py). You need pyaudio and torchaudio installed.

```sh
SMALL=1 python3 examples/whisper.py
@@ -35,17 +35,17 @@ SMALL=1 python3 examples/whisper.py

### Generative Adversarial Networks

Take a look at [mnist_gan.py](/examples/mnist_gan.py).
Take a look at [mnist_gan.py](https://github.com/tinygrad/tinygrad/tree/master/examples/mnist_gan.py).

![mnist gan by tinygrad](showcase/mnist_by_tinygrad.jpg)
![mnist gan by tinygrad](https://github.com/tinygrad/tinygrad/tree/master/docs/showcase/mnist_by_tinygrad.jpg)

### Stable Diffusion

```sh
python3 examples/stable_diffusion.py
```

![a horse sized cat eating a bagel](showcase/stable_diffusion_by_tinygrad.jpg)
![a horse sized cat eating a bagel](https://github.com/tinygrad/tinygrad/tree/master/docs/showcase/stable_diffusion_by_tinygrad.jpg)

*"a horse sized cat eating a bagel"*

7 changes: 4 additions & 3 deletions examples/efficientnet.py
@@ -83,6 +83,7 @@ def infer(model, img):
cv2.destroyAllWindows()
else:
img = Image.open(fetch(url))
with Timing("did inference in "):
out, _ = infer(model, img)
print(np.argmax(out), np.max(out), lbls[np.argmax(out)])
for i in range(getenv("CNT", 1)):
with Timing("did inference in "):
out, _ = infer(model, img)
print(np.argmax(out), np.max(out), lbls[np.argmax(out)])
36 changes: 24 additions & 12 deletions examples/handcode_resnet50_opt.py
@@ -1,6 +1,6 @@
from typing import List
from extra.models.resnet import ResNet50
from tinygrad.tensor import Tensor
from tinygrad import Tensor, nn
from tinygrad.ops import LoadOps, get_lazyop_info
from tinygrad.device import Device, Compiled
from tinygrad.codegen.linearizer import Linearizer
@@ -9,6 +9,7 @@
from tinygrad.shape.symbolic import sym_infer
from tinygrad.dtype import dtypes
from tinygrad.engine.schedule import create_schedule
from tinygrad.features.graph import print_tree

if __name__ == "__main__":
if getenv("HALF"):
@@ -19,15 +20,21 @@

# the device we are optimizing for
device: Compiled = Device[Device.DEFAULT]
if getenv("BACKWARD"):
Tensor.training = True
optim = (nn.optim.LARS if getenv("LARS") else nn.optim.SGD)(nn.state.get_parameters(mdl))
print(f"optimizing for {Device.DEFAULT}")

# first model run to init the weights, they are saved in seen
create_schedule([mdl(Tensor.empty(64, 3, 224, 224)).lazydata], seen)

# run model again to get only what changes, these are the kernels of the model
x = Tensor.empty(64, 3, 224, 224)
out = mdl(x)
sched = create_schedule([out.lazydata], seen)
# run model twice to get only what changes, these are the kernels of the model
for i in range(2):
out = mdl(Tensor.empty(64, 3, 224, 224))
targets = [out.lazydata]
if getenv("BACKWARD"):
optim.zero_grad()
out.sparse_categorical_crossentropy(Tensor.empty(64, dtype=dtypes.int)).backward()
targets += [x.lazydata for x in optim.schedule_step()]
sched = create_schedule(targets, seen)
print(f"schedule length {len(sched)}")
sched = [x for x in sched if x.ast[0].op not in LoadOps]

# focus on one kernel
@@ -37,32 +44,37 @@
total_tm = 0
running_gflops = 0
for i,si in enumerate(sched):
ops = sum(get_lazyop_info(ast).flops for ast in si.ast)

if DEBUG >= 2:
for ast in si.ast: print_tree(ast)

rawbufs = bufs_from_lin(Linearizer(*si.ast))

# "linearize" the op into uops in different ways
lins:List[Linearizer] = []

# always try hand coded opt
lin = Linearizer(*si.ast, opts=device.compiler.compiler_opts)
lin = Linearizer(*si.ast, opts=device.renderer)
lin.hand_coded_optimizations()
lins.append(lin)

# maybe try tensor cores
lin = Linearizer(*si.ast, opts=device.compiler.compiler_opts)
lin = Linearizer(*si.ast, opts=device.renderer)
if lin.apply_tensor_cores():
lins.append(lin)

# try a beam search
if beam:=getenv("BEAM"):
lin = Linearizer(*si.ast, opts=device.compiler.compiler_opts)
lin = Linearizer(*si.ast, opts=device.renderer)
lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
lins.append(lin)

# benchmark the programs
choices = []
for lin in lins:
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
gflops = sym_infer(get_lazyop_info(lin.ast[0]).flops, {k:k.min for k in lin.ast[0].vars()})*1e-9/tm
gflops = sym_infer(ops, {k:k.min for k in lin.ast[0].vars()})*1e-9/tm
choices.append((tm, gflops, lin.linearize()))

# print all kernels
2 changes: 1 addition & 1 deletion examples/hlb_cifar10.py
@@ -49,7 +49,7 @@ def calc_stats(self, x:Tensor):
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
# There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
batch_mean = x.mean(axis=(1,3,4))
y = (x - batch_mean.reshape(shape=[batch_mean.shape[0], 1, -1, 1, 1]))
y = (x - batch_mean.detach().reshape(shape=[batch_mean.shape[0], 1, -1, 1, 1])) # d(var)/d(mean) = 0
batch_var = (y*y).mean(axis=(1,3,4))
batch_invstd = batch_var.add(self.eps).pow(-0.5)
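The added `.detach()` encodes the identity behind the `d(var)/d(mean) = 0` comment: the centered values sum to zero, so the batch variance's gradient with respect to the batch mean vanishes, and stopping the gradient through the mean reproduces the standard batchnorm backward. A standalone sketch of the same pattern (shapes here are hypothetical; the real code above works on 5D tensors):

```python
from tinygrad import Tensor

x = Tensor.randn(8, 16)       # hypothetical activations
m = x.mean(axis=0)
y = x - m.detach()            # stop gradients through the mean, i.e. d(var)/d(mean) = 0
var = (y * y).mean(axis=0)
```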
