add conv kernel #292

Draft: wants to merge 60 commits into base: main

Commits (60):
c288f40  add python multiset util (jorendumoulin, Oct 22, 2024)
cfb0a81  add memory layout for conv (jorendumoulin, Oct 23, 2024)
d69c37d  stream conversions: add support for conv kernels (jorendumoulin, Oct 23, 2024)
b9342f7  add conv kernel (jorendumoulin, Oct 23, 2024)
529ee49  add genkernel (jorendumoulin, Nov 12, 2024)
6de2300  changes (jorendumoulin, Nov 12, 2024)
39dc3e9  add readonly attribute (jorendumoulin, Nov 12, 2024)
77bd68f  add new streamer opts (jorendumoulin, Nov 14, 2024)
6e87c04  add stream schedule op (jorendumoulin, Nov 14, 2024)
ac3c387  change gemmx to isca config (jorendumoulin, Nov 14, 2024)
80889ef  stream pretty printing (jorendumoulin, Nov 14, 2024)
3135f6c  tsl simplification (jorendumoulin, Nov 14, 2024)
876fd84  add autoflow passes (jorendumoulin, Nov 14, 2024)
f877fb9  change to isca config (jorendumoulin, Nov 14, 2024)
e288770  autoflow layout selection (jorendumoulin, Nov 14, 2024)
dadc6ea  add conv genkernel (jorendumoulin, Nov 14, 2024)
d02e09b  conv 72 working (jorendumoulin, Nov 15, 2024)
15235f0  enable non-output stationary flows (jorendumoulin, Nov 16, 2024)
56c4cea  enable more schedules (jorendumoulin, Nov 16, 2024)
955afff  allow for underutilization (but check for correct memory flexibility) (jorendumoulin, Nov 16, 2024)
09b8ea4  works also for underutilization (jorendumoulin, Nov 16, 2024)
685f3a6  remove debug statements (jorendumoulin, Nov 16, 2024)
bfaa0f2  improve benchmark (jorendumoulin, Nov 16, 2024)
1fe6a9b  ready to benchmark! (jorendumoulin, Nov 16, 2024)
604f8fa  initial resnet additions (jorendumoulin, Nov 17, 2024)
18cc8e3  add extreme canonicalization (jorendumoulin, Nov 17, 2024)
623cf82  add support for first layer resnet (jorendumoulin, Nov 17, 2024)
d051a2c  first resnet layer working! (jorendumoulin, Nov 17, 2024)
ff29816  ready for resnet benchmark (jorendumoulin, Nov 17, 2024)
498c314  remove breakpoint (jorendumoulin, Nov 17, 2024)
6364485  ready for conv benchmark 2! (jorendumoulin, Nov 19, 2024)
1742e7a  parallel builds (jorendumoulin, Nov 19, 2024)
bf22af3  reduce probllem size (jorendumoulin, Nov 19, 2024)
18ee182  add main (jorendumoulin, Nov 19, 2024)
c7c63aa  pointwise (jorendumoulin, Nov 19, 2024)
8cfc058  gemm (jorendumoulin, Nov 19, 2024)
5c95ef4  7_7_conv (jorendumoulin, Nov 19, 2024)
fdf91f5  strided_conv (jorendumoulin, Nov 19, 2024)
4063279  dilated_conv (jorendumoulin, Nov 19, 2024)
34f98d7  layout benchmark gemm (jorendumoulin, Nov 19, 2024)
3ced4a8  more sizes! (jorendumoulin, Nov 19, 2024)
02853f6  layout benchmark conv (jorendumoulin, Nov 19, 2024)
b8a87bc  fix gemm size (jorendumoulin, Nov 19, 2024)
4ae69db  true output stationary (jorendumoulin, Nov 19, 2024)
5467180  fix conv size (jorendumoulin, Nov 19, 2024)
708dbf5  fix size again (jorendumoulin, Nov 19, 2024)
270e721  new layout strategy (jorendumoulin, Nov 19, 2024)
e9e6677  remove breakpoint (jorendumoulin, Nov 19, 2024)
c2b93cf  run conv test (jorendumoulin, Nov 21, 2024)
bf502b4  not pure output stationary (jorendumoulin, Nov 21, 2024)
970ecd8  fix resnet (jorendumoulin, Nov 21, 2024)
3de0fc1  pure output stationary (jorendumoulin, Nov 21, 2024)
aa28981  ready for resnet (jorendumoulin, Nov 21, 2024)
94b3524  updated conv (jorendumoulin, Nov 21, 2024)
69634d9  generality (jorendumoulin, Nov 21, 2024)
e93cc5e  add memory layout idx 1 (jorendumoulin, Nov 21, 2024)
7d25623  temp mapping 2 (jorendumoulin, Nov 21, 2024)
f2285fa  fix 7x7 conv (jorendumoulin, Nov 21, 2024)
d6dd3dc  change to strided (jorendumoulin, Nov 21, 2024)
fc2a9d8  dilated conv (jorendumoulin, Nov 21, 2024)
18 changes: 18 additions & 0 deletions compiler/accelerators/snax.py
@@ -169,6 +169,11 @@ def _generate_streamer_setup_vals(
                 cst = arith.Constant.from_int_and_width(stride.data, i32)
                 result.append(([cst], cst.result))
 
+            # address remap:
+            if StreamerOpts.HasAddressRemap in streamer.opts:
+                c0 = arith.Constant.from_int_and_width(0, i32)
+                result.append(([c0], c0.result))
+
             # channel mask option
             if StreamerOpts.HasChannelMask in streamer.opts:
                 if is_zero_pattern:
@@ -189,6 +194,11 @@ def _generate_streamer_setup_vals(
                 c1 = arith.Constant.from_int_and_width(1, i32)
                 result.append(([c1], c1.result))
 
+        for operand, streamer in enumerate(self.streamer_config.data.streamers):
+            if StreamerOpts.HasBroadcast in streamer.opts:
+                c1 = arith.Constant.from_int_and_width(1, i32)
+                result.append(([c1], c1.result))
+
         return result
 
     def get_streamer_setup_fields(self) -> Sequence[str]:
@@ -206,6 +216,8 @@ def get_streamer_setup_fields(self) -> Sequence[str]:
             # temporal strides
            result.extend([f"{name}_tstride_{i}" for i in range(streamer.temporal_dim)])
             # options
+            if StreamerOpts.HasAddressRemap in streamer.opts:
+                result.append(f"{name}_address_remap")
             if StreamerOpts.HasChannelMask in streamer.opts:
                 result.append(f"{name}_channel_mask")
 
@@ -216,6 +228,12 @@ def get_streamer_setup_fields(self) -> Sequence[str]:
             if StreamerOpts.HasTranspose in streamer.opts:
                 result.append(f"{name}_transpose")
 
+        for streamer, name in zip(
+            self.streamer_config.data.streamers, self.streamer_names
+        ):
+            if StreamerOpts.HasBroadcast in streamer.opts:
+                result.append(f"{name}_broadcast")
+
         return result
 
     def get_streamer_launch_fields(self) -> Sequence[str]:
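Note on the two new options above: each option contributes one CSR setup value and one matching setup field, and the two lists must be emitted in the same order. A minimal, self-contained sketch of the field-name side of that contract; the Streamer stand-in and the streamer names are invented for illustration, not the real classes from compiler/accelerators/streamers:

from dataclasses import dataclass
from enum import Enum


class StreamerOpts(str, Enum):
    HasTranspose = "t"
    HasChannelMask = "c"
    HasAddressRemap = "r"
    HasBroadcast = "b"


@dataclass
class Streamer:
    opts: tuple[StreamerOpts, ...] = ()


def setup_fields(streamers: dict[str, Streamer]) -> list[str]:
    result: list[str] = []
    for name, s in streamers.items():
        # per-streamer option fields, in declaration order
        if StreamerOpts.HasAddressRemap in s.opts:
            result.append(f"{name}_address_remap")
        if StreamerOpts.HasChannelMask in s.opts:
            result.append(f"{name}_channel_mask")
    for name, s in streamers.items():
        # broadcast fields come after all other per-streamer fields,
        # mirroring the separate trailing loop in get_streamer_setup_fields
        if StreamerOpts.HasBroadcast in s.opts:
            result.append(f"{name}_broadcast")
    return result


print(setup_fields({
    "a": Streamer((StreamerOpts.HasAddressRemap,)),
    "c": Streamer((StreamerOpts.HasChannelMask, StreamerOpts.HasBroadcast)),
}))
# ['a_address_remap', 'c_channel_mask', 'c_broadcast']

Keeping the broadcast fields in a separate trailing loop matches the ordering of _generate_streamer_setup_vals, which also appends its broadcast values after all per-streamer values.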
83 changes: 65 additions & 18 deletions compiler/accelerators/snax_gemmx.py
@@ -26,31 +26,33 @@
     [
         Streamer(  # A
             StreamerType.Reader,
-            temporal_dims=("n", "n", "n", "n", "n", "n"),
+            temporal_dims=("n", "n", "n", "n", "n", "n", "n"),
             spatial_dims=("n",),
-            opts=(StreamerOpts.HasTranspose,),
+            opts=(StreamerOpts.HasTranspose, StreamerOpts.HasAddressRemap),
         ),
         Streamer(  # B
             StreamerType.Reader,
-            temporal_dims=("n", "n", "n"),
+            temporal_dims=("n", "n", "n", "n", "n", "n", "n"),
             spatial_dims=("n",),
-            opts=(StreamerOpts.HasTranspose,),
+            opts=(StreamerOpts.HasTranspose, StreamerOpts.HasAddressRemap),
         ),
         Streamer(  # D8
             StreamerType.Writer,
-            temporal_dims=("r", "n", "n"),
+            temporal_dims=("n", "n", "n"),
             spatial_dims=("n",),
+            opts=(StreamerOpts.HasAddressRemap,),
         ),
         Streamer(  # C
             StreamerType.Reader,
-            temporal_dims=("r", "n", "n"),
-            spatial_dims=("n",),
-            opts=(StreamerOpts.HasChannelMask,),
+            temporal_dims=("n", "n", "n", "n", "n", "n", "n"),
+            spatial_dims=("n", "n"),
+            opts=(StreamerOpts.HasChannelMask, StreamerOpts.HasAddressRemap, StreamerOpts.HasBroadcast),
         ),
         Streamer(  # D32
             StreamerType.Writer,
-            temporal_dims=("r", "n", "n"),
-            spatial_dims=("n",),
+            temporal_dims=("n", "n", "n", "n", "n", "n", "n"),
+            spatial_dims=("n", "n"),
+            opts=(StreamerOpts.HasAddressRemap,),
         ),
     ],
 )
@@ -67,6 +69,7 @@ class SNAXGEMMXAccelerator(
 
     supported_kernels = (
         SupportedKernel(kernel.QMacOp, (i8, i8, i32, i32, i32)),
+        SupportedKernel(kernel.MacOp, (i8, i8, i32)),
         SupportedKernel(kernel.AddOp, (i32, i32, i32)),
         SupportedKernel(kernel.RescaleOp, (i32, i8)),
     )
@@ -162,10 +165,6 @@ def _generate_setup_vals(
 
         c0 = arith.Constant.from_int_and_width(0, 32)
         c1 = arith.Constant.from_int_and_width(1, 32)
-        knm: list = [
-            (((cst := arith.Constant.from_int_and_width(val.data, 32)),), cst.result)
-            for val in op.stride_patterns.data[0].upper_bounds
-        ]
 
         streamer_setup_vals = list(self._generate_streamer_setup_vals(op))
 
@@ -174,6 +173,16 @@ def _generate_setup_vals(
         assert isinstance(generic_op := op.body.block.first_op, stream.GenericOp)
 
         if isinstance(qmac := generic_op.body.block.first_op, kernel.QMacOp):
+            # compute knm: fix n = 1
+            n = 1
+            m = prod(x.data for x in op.stride_patterns.data[-1].upper_bounds) // n
+            k = prod(x.data for x in op.stride_patterns.data[0].upper_bounds) // m
+
+            knm: list = [
+                (((cst := arith.Constant.from_int_and_width(val, 32)),), cst.result)
+                for val in (k, n, m)
+            ]
+
             # gemm
             # bypass simd and set all related values to 0
             bypassSIMD = c1.result  # bypass simd
@@ -201,10 +210,16 @@ def _generate_setup_vals(
 
         elif isinstance(rescale := generic_op.body.block.first_op, kernel.RescaleOp):
             # extract and compute correct value for csr's based on kernel rescale op
-            # set k to 1
-            knm.insert(
-                0, ((cst := arith.Constant.from_int_and_width(1, 32),), cst.result)
-            )
+            # set k and n to 1
+            k = 1
+            n = 1
+            m = prod(x.data for x in op.stride_patterns.data[0].upper_bounds)
+
+            knm: list = [
+                (((cst := arith.Constant.from_int_and_width(val, 32)),), cst.result)
+                for val in (k, n, m)
+            ]
+
             # simd
             bypassSIMD = c0.result
             subtractions = c0.result
@@ -247,6 +262,38 @@ def _generate_setup_vals(
             loop_bound = arith.Constant.from_int_and_width(loop_bound, i32)
             ops_to_add.append(loop_bound)
 
+        elif isinstance(mac := generic_op.body.block.first_op, kernel.MacOp):
+            # compute knm: fix n = 1
+            n = 1
+            m = prod(x.data for x in op.stride_patterns.data[-1].upper_bounds) // n
+            k = prod(x.data for x in op.stride_patterns.data[0].upper_bounds) // m
+
+            knm: list = [
+                (((cst := arith.Constant.from_int_and_width(val, 32)),), cst.result)
+                for val in (k, n, m)
+            ]
+
+            # gemm
+            # bypass simd and set all related values to 0
+            bypassSIMD = c1.result  # bypass simd
+            loop_bound = c0
+            csr0 = c0.result
+            csr1 = c0.result
+            shift_vals = (c0.result for _ in range(2))
+            mult_vals = (c0.result for _ in range(8))
+
+            # get zero points for gemm
+            zp_a = c0
+            zp_b = c0
+
+            # bitwise and with 8b'11111111 to avoid the sign bits extending the
+            # 8-bit field when bitlist packing
+            ops_to_add.append(cst255 := arith.Constant.from_int_and_width(255, 32))
+
+            bitlist = list(pack_bitlist((zp_a, zp_b), [0, 8]))
+            ops_to_add.extend(bitlist)
+            subtractions = bitlist[-1].results[0]
+
         else:
             raise NotImplementedError()
 
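For reference, the k/n/m computation above reduces to two products of upper bounds: n is pinned to 1, m comes from the output pattern, and k is whatever remains of input A's iteration space. A standalone sketch with invented bounds, assuming (as the code does) that stride pattern 0 describes input A and the last pattern describes the output:

from math import prod

a_upper_bounds = [2, 4, 8]   # iteration space of input A: 64 points
out_upper_bounds = [2, 4]    # iteration space of the output: 8 points

n = 1                             # fixed by the setup code above
m = prod(out_upper_bounds) // n   # 8
k = prod(a_upper_bounds) // m     # 8
print(k, n, m)                    # 8 1 8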
5 changes: 5 additions & 0 deletions compiler/accelerators/streamers/streamers.py
@@ -17,6 +17,11 @@ class StreamerOpts(StrEnum):
     HasTranspose = "t"
     # Streamer with channel mask capabilities
     HasChannelMask = "c"
+    # Streamer with address remap capabilities
+    HasAddressRemap = "r"
+    # Streamer with broadcast capabilities
+    HasBroadcast = "b"
+
 
 
 class StreamerFlag(StrEnum):
5 changes: 4 additions & 1 deletion compiler/dialects/kernel.py
@@ -101,7 +101,10 @@ def equivalent_region(self) -> Region:
             )
         )
         def equivalent_region(args: tuple[BlockArgument, ...]) -> None:
-            mul = arith.Muli(args[0], args[1])
+            assert isinstance(output_type := args[2].type, IntegerType)
+            inp1 = arith.ExtSIOp(args[0], output_type)
+            inp2 = arith.ExtSIOp(args[1], output_type)
+            mul = arith.Muli(inp1, inp2)
             mac = arith.Addi(args[2], mul)
             linalg.YieldOp(mac)
 
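The rewritten MacOp region is a correctness fix, not a cosmetic one: an i8 x i8 product needs up to 16 bits, so the operands must be widened to the i32 output type before arith.Muli. A quick illustration, simulating 8-bit two's-complement truncation:

# Why the inputs are sign-extended before multiplying: the 8-bit product wraps.
def wrap_i8(x: int) -> int:
    """Truncate an integer to 8-bit two's complement."""
    x &= 0xFF
    return x - 256 if x >= 128 else x

a, b = -100, 100
print(wrap_i8(a * b))  # -16: the product truncated to 8 bits is wrong
print(a * b)           # -10000: the correct product, representable in i32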
27 changes: 27 additions & 0 deletions compiler/dialects/snax_stream.py
@@ -1,5 +1,6 @@
 from collections.abc import Sequence
 
+from typing_extensions import Self
 from xdsl.dialects.builtin import ArrayAttr, IndexType, IntAttr, StringAttr
 from xdsl.ir import (
     Attribute,
@@ -111,6 +112,32 @@ def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
         )
         return (ub, ts, ss)
 
+    def collapse_dimensions(self) -> Self:
+        """
+        Collapses multiple perfectly nested compatible for loops into one:
+        for i = 0 .. 4
+            for j = 0 .. 4
+                a = 4 * i + j
+
+        will turn into:
+        for i = 0 .. 16
+            a = i
+        """
+        new_temporal_strides: list[int] = []
+        new_upper_bounds: list[int] = []
+
+        for stride, bound in zip(self.temporal_strides.data, self.upper_bounds.data):
+            if bound.data == 1:
+                # unused dim
+                continue
+            if len(new_temporal_strides) > 0:
+                if new_temporal_strides[-1] * new_upper_bounds[-1] == stride.data:
+                    new_upper_bounds[-1] *= bound.data
+                    continue
+            new_upper_bounds.append(bound.data)
+            new_temporal_strides.append(stride.data)
+        return type(self)(new_upper_bounds, new_temporal_strides, self.spatial_strides)
+
 
 @irdl_op_definition
 class StreamingRegionOp(IRDLOperation):
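The merge test new_temporal_strides[-1] * new_upper_bounds[-1] == stride.data implies the dimensions are listed innermost first: an outer loop folds into the previous one exactly when its stride equals the previous stride times the previous bound. The same logic on plain integers, as a runnable sketch:

# The collapsing logic above on plain ints (assumes innermost-first ordering,
# which is what the merge test implies).
def collapse(upper_bounds: list[int], strides: list[int]):
    new_strides: list[int] = []
    new_bounds: list[int] = []
    for stride, bound in zip(strides, upper_bounds):
        if bound == 1:
            continue  # a bound of 1 contributes no iteration
        if new_strides and new_strides[-1] * new_bounds[-1] == stride:
            # contiguous with the previous dimension: merge into one loop
            new_bounds[-1] *= bound
            continue
        new_bounds.append(bound)
        new_strides.append(stride)
    return new_bounds, new_strides

# the docstring example, a = 4 * i + j with j innermost:
print(collapse([4, 4], [1, 4]))  # ([16], [1])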
58 changes: 54 additions & 4 deletions compiler/dialects/stream.py
@@ -6,6 +6,9 @@
     AnyShapedType,
     ArrayAttr,
     ContainerType,
+    IndexType,
+    IntAttr,
+    IntegerAttr,
     ShapedType,
     StringAttr,
 )
@@ -70,8 +73,7 @@ def get_element_type(self) -> _StreamTypeElement:
         return self.element_type
 
 
-@irdl_op_definition
-class StreamingRegionOp(IRDLOperation):
+class StreamingRegionOpBase(IRDLOperation):
     """
     An operation that creates streams from tensors or memrefs, which are only available to
     read from within the body of the operation.
@@ -80,8 +82,6 @@ class StreamingRegionOp(IRDLOperation):
     via any other access means, including extraction (e.g.: memref.view).
     """
 
-    name = "stream.streaming_region"
-
     inputs = var_operand_def(AnyShapedType())
     outputs = var_operand_def(AnyShapedType())
     result_tensors = var_result_def()
@@ -147,6 +147,55 @@ def get_static_pattern_bounds(self) -> Iterable[int]:
             tuple(self.get_static_shapes()), []
         )
 
+
+@irdl_op_definition
+class StreamingRegionOp(StreamingRegionOpBase):
+    name = "stream.streaming_region"
+
+
+@irdl_op_definition
+class ScheduleOp(IRDLOperation):
+    name = "stream.schedule"
+
+    inputs = var_operand_def(AnyShapedType())
+    outputs = var_operand_def(AnyShapedType())
+    result_tensors = var_result_def()
+    patterns = prop_def(ArrayAttr[AffineMapAttr])
+    bounds = prop_def(ArrayAttr[ArrayAttr[IntAttr]])
+
+    body = region_def("single_block")
+
+    accelerator = opt_prop_def(StringAttr)
+
+    irdl_options = [AttrSizedOperandSegments(as_property=True)]
+
+    def __init__(
+        self,
+        inputs: Sequence[SSAValue | Operation],
+        outputs: Sequence[SSAValue | Operation],
+        patterns: ArrayAttr[AffineMapAttr],
+        bounds: Sequence[Sequence[int]],
+        body: Region,
+        accelerator: str | StringAttr | None = None,
+        result_types: Sequence[Attribute] = (),
+    ) -> None:
+        if isinstance(accelerator, str):
+            accelerator = StringAttr(accelerator)
+
+        bounds_attr = ArrayAttr(ArrayAttr(IntAttr(x) for x in y) for y in bounds)
+        super().__init__(
+            operands=[inputs, outputs],
+            regions=[body],
+            properties={
+                "patterns": patterns,
+                "accelerator": accelerator,
+                "bounds": bounds_attr,
+            },
+            result_types=[result_types],
+        )
+
 
 @irdl_op_definition
 class YieldOp(AbstractYieldOperation[Attribute]):
@@ -197,6 +246,7 @@ def __init__(
     "stream",
     [
         StreamingRegionOp,
+        ScheduleOp,
         GenericOp,
         YieldOp,
     ],
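ScheduleOp stores one list of loop bounds per pattern as nested ArrayAttrs; the constructor builds that nesting from plain ints. A small sketch of just that conversion, using only the xdsl builtins already imported above (the bounds values are invented):

from xdsl.dialects.builtin import ArrayAttr, IntAttr

bounds = [[4, 16, 3], [8, 8]]  # per-pattern loop bounds
bounds_attr = ArrayAttr(ArrayAttr(IntAttr(x) for x in y) for y in bounds)
print(bounds_attr)  # shows the nested array-of-array attribute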
1 change: 1 addition & 0 deletions compiler/ir/autoflow/__init__.py
@@ -0,0 +1 @@
+from .scheduler import *