Demonstrate how to add JIT using MLIR to micrograd #62

Open: wants to merge 25 commits into master

Commits (25)
078afbc ignore venv (fzakaria, Jan 12, 2024)
23c7fac Added more to gitignore (fzakaria, Jan 20, 2024)
edcde3d Minor fixes to engine.py for passing parent information for __pow__ (fzakaria, Jan 20, 2024)
f1cab09 Add a visitor (fzakaria, Jan 20, 2024)
b4cbbcd Add a test visitor (fzakaria, Jan 20, 2024)
ff2593b Add test_mlir_execution.py (alexander-shaposhnikov, Feb 28, 2024)
621328a Simplify the pipeline (alexander-shaposhnikov, Feb 28, 2024)
25138a0 Add jit.py, test_jit.py (alexander-shaposhnikov, Feb 29, 2024)
67516d4 Merge pull request #1 from alexander-shaposhnikov/add_mlir_exec (fzakaria, Feb 29, 2024)
574f731 Merge pull request #2 from alexander-shaposhnikov/add_compiler (fzakaria, Feb 29, 2024)
fc6c096 Add direnv integration (fzakaria, Feb 29, 2024)
b27f24d Add requirements.txt file (fzakaria, Feb 29, 2024)
518e151 Added numpy to requirements (fzakaria, Feb 29, 2024)
440f14a Add __init__.py to test directory (fzakaria, Feb 29, 2024)
0bfcaf3 Fixup some tests to test against non JIT (fzakaria, Feb 29, 2024)
2213aea Changed README for how to invoke pytest (fzakaria, Feb 29, 2024)
84cb555 Added a JIT callable to print mlir (fzakaria, Feb 29, 2024)
87489d7 Refinements (fzakaria, Feb 29, 2024)
0082cc0 Fix it so that JIT works for multiple out NN (fzakaria, Mar 1, 2024)
c64757b Remove extra argument to execution_engine (fzakaria, Mar 1, 2024)
284c2f1 Add documentation (fzakaria, Mar 1, 2024)
e79081b Added a new cell to the demo (fzakaria, Mar 1, 2024)
0b5ac71 Add a toy benchmark (alexander-shaposhnikov, Mar 2, 2024)
e8b5d5a Merge pull request #4 from alexander-shaposhnikov/add_toy_benchmark (fzakaria, Mar 2, 2024)
d3e5e4d Format test_jit.py (fzakaria, Mar 2, 2024)
1 change: 1 addition & 0 deletions .envrc
@@ -0,0 +1 @@
layout python3
3 changes: 3 additions & 0 deletions .gitignore
@@ -1 +1,4 @@
.ipynb_checkpoints/
**/__pycache__/
venv/
.direnv/
55 changes: 54 additions & 1 deletion README.md
@@ -61,7 +61,60 @@ dot = draw_dot(y)
To run the unit tests you will have to install [PyTorch](https://pytorch.org/), which the tests use as a reference for verifying the correctness of the calculated gradients. Then simply:

```bash
-python -m pytest
+pytest
```
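
The test suite also includes a toy benchmark, `test_mlp_performance` in `test/test_jit.py`, that times interpreted inference against the JIT. pytest captures stdout by default, so pass `-s` to see its timing printout:

```bash
pytest -s -k test_mlp_performance
```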

### Dependencies

There is a `requirements.txt` with the necessary dependencies.

```bash
pip install -r requirements.txt
```
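
Note that the pinned `mlir-python-bindings` wheel is served from a custom wheel index via the `--find-links` line in `requirements.txt`. If you install that package by hand, pass the same index (pip's `-f` is shorthand for `--find-links`):

```bash
pip install -f https://makslevental.github.io/wheels mlir-python-bindings==19.0.0.2024022901+vulkan.0fe4b9da
```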

### Just in Time Compilation

This repository also contains a JIT compiler for the micrograd engine built on [MLIR](https://mlir.llvm.org/): the computation graph is translated to MLIR, lowered to LLVM IR, and executed on the CPU via an LLVM JIT engine.

```python
def test_value():
    a = Value(4.0)
    b = Value(2.0)
    c = a + b  # 6.
    d = a + c  # 10.
    jd = jit(d)
    assert math.isclose(d.data, jd(), abs_tol=1e-04)

def test_mlp():
    random.seed(10)
    nn = MLP(nin=2, nouts=[1])
    jnn = jit(nn)
    args = [-30., -20.]
    assert math.isclose(nn(args).data, jnn(args), abs_tol=1e-04)
```

You can also print the returned JIT object to see the corresponding MLIR module.
```python
>>> from micrograd.engine import Value
>>> from micrograd.jit import jit
>>> a = Value(4.0)
>>> b = Value(2.0)
>>> c = a + b
>>> jit_c = jit(c)
>>> print(jit_c)
module {
  llvm.func @main() -> f32 attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(4.000000e+00 : f32) : f32
    %1 = llvm.mlir.constant(2.000000e+00 : f32) : f32
    %2 = llvm.mlir.constant(6.000000e+00 : f32) : f32
    llvm.return %2 : f32
  }
  llvm.func @_mlir_ciface_main() -> f32 attributes {llvm.emit_c_interface} {
    %0 = llvm.call @main() : () -> f32
    llvm.return %0 : f32
  }
}
```
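
Under the hood, `jit` lowers the compiled module through MLIR's `convert-to-llvm` pass pipeline and runs it with MLIR's `ExecutionEngine` (see `micrograd/jit.py`). A minimal sketch of those steps, assuming `module` holds the arith-level `ir.Module` built by the compiler:

```python
# A sketch of what jit() does internally (see micrograd/jit.py); `module` is
# assumed to be the arith-level ir.Module built by the compiler.
from ctypes import byref, c_float

from mlir.execution_engine import ExecutionEngine
from mlir.passmanager import PassManager

# Lower arith/func/math ops to the llvm dialect in place.
pm = PassManager.parse("builtin.module(convert-to-llvm)", context=module.context)
pm.run(module.operation)

# JIT-compile the lowered module and call @main; the scalar result comes back
# through a ctypes out-pointer.
engine = ExecutionEngine(module)
res = c_float(-1.0)
engine.invoke("main", byref(res))
print(res.value)  # 6.0
```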

### License
3,411 changes: 3,323 additions & 88 deletions demo.ipynb

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions micrograd/engine.py
@@ -1,4 +1,3 @@
-
class Value:
    """ stores a single scalar value and its gradient """

@@ -34,7 +33,7 @@ def _backward():

    def __pow__(self, other):
        assert isinstance(other, (int, float)), "only supporting int/float powers for now"
-        out = Value(self.data**other, (self,), f'**{other}')
+        out = Value(self.data**other, (self, Value(other)), f'**{other}')

        def _backward():
            self.grad += (other * self.data**(other-1)) * out.grad
@@ -91,4 +90,4 @@ def __rtruediv__(self, other): # other / self
        return other * self**-1

    def __repr__(self):
-        return f"Value(data={self.data}, grad={self.grad})"
\ No newline at end of file
+        return f"Value(data={self.data}, grad={self.grad})"
170 changes: 170 additions & 0 deletions micrograd/jit.py
@@ -0,0 +1,170 @@
"""This is a small JIT compiler for micrograd computation graphs using MLIR.

The MLIR is lowered to LLVM IR and then executed using an LLVM JIT engine.
The comments in the file are meant to be liberal as this is a demonstration
and learning project.
"""

from micrograd.engine import Value
from micrograd.nn import Neuron, Layer, MLP
import mlir.dialects.arith as arith
import mlir.dialects.math as math
import mlir.dialects.func as func
from mlir.ir import Context, Location, InsertionPoint, Module
from mlir.execution_engine import ExecutionEngine
from mlir.passmanager import PassManager
from mlir import ir
from typing import Union, Optional
from ctypes import c_float, byref, pointer


class Compiler:
    """Compiles a micrograd computation Value graph to the MLIR arithmetic dialect."""

    def __init__(self, compiled_values=None):
        # Avoid a mutable default argument; each Compiler gets its own mapping.
        self.compiled_values = compiled_values if compiled_values is not None else {}

    def walk(self, value: Value) -> ir.Value:
        """Walk the Value graph and convert it to an isomorphic MLIR arithmetic dialect graph."""
        if value in self.compiled_values:
            return self.compiled_values[value]
        match value._op:
            case "":
                # A leaf (no op) is materialized as a floating-point constant.
                return arith.constant(value=float(value.data), result=ir.F32Type.get())
            case "*":
                lhs, rhs = value._prev
                return arith.mulf(self.walk(lhs), self.walk(rhs))
            case "+":
                lhs, rhs = value._prev
                return arith.addf(self.walk(lhs), self.walk(rhs))
            case "ReLU":
                (item,) = value._prev
                return arith.maximumf(self.walk(Value(0.0)), self.walk(item))
            case _ if "**" in value._op:
                # The op string is f"**{exponent}", e.g. "**2".
                base, exp = value._prev
                return math.powf(self.walk(base), self.walk(exp))
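

# Note: micrograd desugars subtraction, division, and negation into "+", "*",
# and "**" (see engine.py), so the cases above cover every op the engine emits.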


def _get_args_num(net: Union[Value, Neuron, Layer, MLP]) -> int:
    if isinstance(net, Neuron):
        # A Neuron's parameters are its input weights plus one bias term.
        return len(net.parameters()) - 1
    if isinstance(net, Layer):
        return _get_args_num(net.neurons[0])
    if isinstance(net, MLP):
        return _get_args_num(net.layers[0])
    assert isinstance(net, Value)
    return 0


def _get_results_num(net: Union[Value, Neuron, Layer, MLP]) -> int:
    if isinstance(net, Layer):
        return len(net.neurons)
    if isinstance(net, MLP):
        return _get_results_num(net.layers[-1])
    assert isinstance(net, (Value, Neuron))
    return 1


def _compile(net: Union[Value, Neuron, Layer, MLP]):
    """Adds the main function to an MLIR module.

    This function assumes it is called within a context and insertion point.
    """
    args_num = _get_args_num(net)
    args_types = [ir.F32Type.get()] * args_num
    args_values = [Value(0) for _ in range(args_num)]

    @func.func(*args_types)
    def main(*args):
        # This is a bit of a hack to figure out the computation graph.
        # Rather than model the remaining types (Neuron, Layer, MLP)
        # explicitly, we execute the computation; since the result is a
        # Value, it encodes the whole graph. This is fine because the
        # point of a JIT is to speed up subsequent executions.
        net_value = net if isinstance(net, Value) else net(args_values)
        # The computation graph above was created with seed values of Value(0).
        # We now need to replace these with the actual arguments provided to
        # the MLIR main function. We do this by mapping each seed value to its
        # compiled argument (cv); walk() substitutes the arguments for the
        # seed values as it traverses the graph.
        compiled_values = {v: cv for v, cv in zip(args_values, args)}
        compiler = Compiler(compiled_values)
        if isinstance(net_value, list):
            return [compiler.walk(value) for value in net_value]
        return compiler.walk(net_value)

    main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
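    # As an illustration: a bare Value graph has args_num == 0, so main() takes
    # no parameters, while MLP(nin=2, nouts=[1]) is emitted as
    # @main(%arg0: f32, %arg1: f32) -> f32 before lowering.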


def _compile_standalone(net: Union[Value, Neuron, Layer, MLP]) -> ir.Module:
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            _compile(net)
        return module


def _lower_to_llvm(mod: ir.Module) -> ir.Module:
    """Lower the MLIR module to LLVM.

    The assumption is that the module only uses standard
    dialects that can be lowered to LLVM.
    """
    pm = PassManager.parse("builtin.module(convert-to-llvm)", context=mod.context)
    pm.run(mod.operation)
    return mod


class JittedNet:
    def __init__(
        self,
        net: Union[Value, Neuron, Layer, MLP],
        m: ir.Module,
        execution_engine: ExecutionEngine,
    ):
        self.net = net
        self.m = m
        self.execution_engine = execution_engine

    def __call__(self, x: Optional[list[float]] = None):
        if isinstance(self.net, Value) and x is not None:
            raise ValueError("You should not pass any arguments to a Value.")
        xs = [] if isinstance(self.net, Value) else x

        args = [byref(c_float(v)) for v in xs]

        num_results = _get_results_num(self.net)
        FloatResultArrayType = c_float * num_results
        res = FloatResultArrayType(-1)

        # ExecutionEngine has odd semantics if an argument is a pointer.
        # Some networks return a single value, others a list. Returning
        # multiple results also changes the MLIR that is lowered to LLVM:
        # the return values must then be passed via an out-argument.
        # https://github.com/llvm/llvm-project/issues/83599
        if num_results == 1:
            args = args + [byref(res)]
        else:
            args = [pointer(pointer(res))] + args

        self.execution_engine.invoke("main", *args)
        return res[0] if num_results == 1 else [res[i] for i in range(num_results)]

    def __str__(self):
        return str(self.m)


def jit(net: Union[Value, Neuron, Layer, MLP]) -> JittedNet:
    """Given a micrograd computation graph, compile it to MLIR and then to LLVM.

    You can also print the returned object to see the MLIR module.

    @return: a callable that takes the input arguments of the computation graph
    """
    m = _compile_standalone(net)
    execution_engine = ExecutionEngine(_lower_to_llvm(m))
    return JittedNet(net, m, execution_engine)
29 changes: 29 additions & 0 deletions requirements.txt
@@ -0,0 +1,29 @@
filelock==3.13.1
fsspec==2024.2.0
iniconfig==2.0.0
Jinja2==3.1.3
MarkupSafe==2.1.5
--find-links https://makslevental.github.io/wheels
mlir-python-bindings==19.0.0.2024022901+vulkan.0fe4b9da
mpmath==1.3.0
networkx==3.2.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
packaging==23.2
pluggy==1.4.0
pytest==8.0.2
sympy==1.12
torch==2.2.1
triton==2.2.0
typing_extensions==4.10.0
Empty file added test/__init__.py
88 changes: 88 additions & 0 deletions test/test_jit.py
@@ -0,0 +1,88 @@
import math
import random
import timeit
from micrograd.engine import Value
from micrograd.nn import Neuron, Layer, MLP
from micrograd.jit import jit

# helps investigate segmentation faults
import faulthandler

faulthandler.enable()


def test_value():
    a = Value(4.0)
    b = Value(2.0)
    c = a + b  # 6.
    d = a + c  # 10.
    jd = jit(d)
    assert math.isclose(d.data, jd(), abs_tol=1e-04)


def test_neuron():
    n = Neuron(nin=1, nonlin=False)
    n.w = [2.0]
    jn = jit(n)
    args = [10.0]
    assert math.isclose(n(args).data, jn(args), abs_tol=1e-04)


def test_layer():
    random.seed(10)
    l = Layer(nin=2, nout=1)
    jl = jit(l)
    args = [-30.0, -20.0]
    assert math.isclose(l(args).data, jl(args), abs_tol=1e-04)


def test_layer_multiple_out():
    random.seed(10)
    l = Layer(nin=2, nout=2)
    jl = jit(l)
    args = [-30.0, -20.0]
    for r, jr in zip(l(args), jl(args)):
        assert math.isclose(r.data, jr, abs_tol=1e-04)


def test_mlp():
    random.seed(10)
    nn = MLP(nin=2, nouts=[1])
    jnn = jit(nn)
    args = [-30.0, -20.0]
    assert math.isclose(nn(args).data, jnn(args), abs_tol=1e-04)


def test_mlp_complex():
    random.seed(10)
    nn = MLP(nin=2, nouts=[2, 1])
    jnn = jit(nn)
    args = [-30.0, -20.0]
    assert math.isclose(nn(args).data, jnn(args), abs_tol=1e-04)


def test_mlp_complex_multiple_out():
    random.seed(10)
    nn = MLP(nin=2, nouts=[2, 2])
    jnn = jit(nn)
    args = [-30.0, -20.0]
    for r, jr in zip(nn(args), jnn(args)):
        assert math.isclose(r.data, jr, abs_tol=1e-04)


def test_mlp_performance():
    random.seed(10)
    nn = MLP(nin=10, nouts=[30, 20, 10, 1])
    args = random.sample(range(-100, 100), 10)
    jnn = jit(nn)

    def slow_inference():
        return nn(args)

    def fast_inference():
        return jnn(args)

    slow_inference_time = timeit.timeit(slow_inference, number=1000)
    fast_inference_time = timeit.timeit(fast_inference, number=1000)
    print(f"\nslow: {slow_inference_time}\nfast: {fast_inference_time}")
    assert slow_inference_time > fast_inference_time