Skip to content

Commit

Permalink
dialects: (csl) Small incremental fixes to dialect structure (stacked…
Browse files Browse the repository at this point in the history
… PR) (#2711)

- Make `@get_color` take an operand instead of property.
- Make `ParamOp` take a default value as operand instead of property.
  • Loading branch information
AntonLydike authored Jun 12, 2024
1 parent 8e02c21 commit c872084
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 31 deletions.
18 changes: 12 additions & 6 deletions tests/filecheck/backend/csl/print_csl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@

"csl.export"() <{type = (i32, i32) -> (), var_name = @args_no_return}> : () -> ()

%col = "csl.get_color"() <{id = 15 : i5}> : () -> !csl.color
%cst15 = arith.constant 15 : i32
%col = "csl.get_color"(%cst15) : (i32) -> !csl.color

"csl.rpc"(%col) : (!csl.color) -> ()

Expand Down Expand Up @@ -249,7 +250,8 @@ csl.func @builtins() {
%u32_value = arith.constant 120 : ui32
%f16_value = arith.constant 7.0 : f16
%f32_value = arith.constant 8.0 : f32
%col_1 = "csl.get_color"() <{id = 3 : i5}> : () -> !csl.color
%three = arith.constant 3 : i16
%col_1 = "csl.get_color"(%three) : (i16) -> !csl.color
%f16_pointer = "csl.addressof"(%f16_value) : (f16) -> !csl.ptr<f16, #csl<ptr_kind single>, #csl<ptr_const var>>
%f32_pointer = "csl.addressof"(%f32_value) : (f32) -> !csl.ptr<f32, #csl<ptr_kind single>, #csl<ptr_const var>>
%i16_pointer = "csl.addressof"(%i16_value) : (si16) -> !csl.ptr<si16, #csl<ptr_kind single>, #csl<ptr_const var>>
Expand Down Expand Up @@ -320,7 +322,8 @@ csl.func @builtins() {

"csl.module"() <{kind=#csl<module_kind layout>}> ({
%p1 = "csl.param"() <{param_name = "param_1"}> : () -> i32
%p2 = "csl.param"() <{param_name = "param_2", init_value = 1.3 : f16}> : () -> f16
%init = arith.constant 3.14 : f16
%p2 = "csl.param"(%init) <{param_name = "param_2"}> : (f16) -> f16

csl.layout {
%x_dim = arith.constant 4 : i32
Expand Down Expand Up @@ -443,7 +446,8 @@ csl.func @builtins() {
// CHECK-NEXT: comptime {
// CHECK-NEXT: @export_symbol(args_no_return, "args_no_return");
// CHECK-NEXT: }
// CHECK-NEXT: const col : color = @get_color(15);
// CHECK-NEXT: const cst15 : i32 = 15;
// CHECK-NEXT: const col : color = @get_color(cst15);
// CHECK-NEXT: comptime {
// CHECK-NEXT: @rpc(@get_data_task_id(col));
// CHECK-NEXT: }
Expand Down Expand Up @@ -541,7 +545,8 @@ csl.func @builtins() {
// CHECK-NEXT: const u32_value : u32 = 120;
// CHECK-NEXT: const f16_value : f16 = 7.0;
// CHECK-NEXT: const f32_value : f32 = 8.0;
// CHECK-NEXT: const col1 : color = @get_color(3);
// CHECK-NEXT: const three : i16 = 3;
// CHECK-NEXT: const col1 : color = @get_color(three);
// CHECK-NEXT: var f16_pointer : *f16 = &f16_value;
// CHECK-NEXT: var f32_pointer : *f32 = &f32_value;
// CHECK-NEXT: var i16_pointer : *i16 = &i16_value;
Expand Down Expand Up @@ -621,7 +626,8 @@ csl.func @builtins() {
// CHECK-NEXT: // -----
// CHECK-NEXT: // FILE: layout.csl
// CHECK-NEXT: param param_1 : i32;
// CHECK-NEXT: param param_2 : f16 = 1.3;
// CHECK-NEXT: const init : f16 = 3.14;
// CHECK-NEXT: param param_2 : f16 = init;
// CHECK-NEXT: layout {
// CHECK-NEXT: const x_dim : i32 = 4;
// CHECK-NEXT: const y_dim : i32 = 6;
Expand Down
12 changes: 8 additions & 4 deletions tests/filecheck/dialects/csl/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ csl.func @initialize() {
items = {i = 42 : i32, f = 3.7 : f32 }
}> : (i32, i16, !csl.color) -> !csl.comptime_struct

%col_1 = "csl.get_color"() <{id = 3 : i5}> : () -> !csl.color
%three = arith.constant 3 : i16
%col_1 = "csl.get_color"(%three) : (i16) -> !csl.color


%arr, %scalar, %tens = "test.op"() : () -> (memref<10xf32>, i32, tensor<510xf32>)
Expand Down Expand Up @@ -297,7 +298,8 @@ csl.func @builtins() {

"csl.module"() <{kind = #csl<module_kind layout>}> ({
%comp_const = "csl.param"() <{param_name = "comp_constant"}> : () -> i32
%comp_const_with_def = "csl.param"() <{param_name = "comp_constant", init_value = 1 : i32}> : () -> i32
%init = arith.constant 3.14 : f16
%p2 = "csl.param"(%init) <{param_name = "param_2"}> : (f16) -> f16
csl.layout {
%x_dim, %y_dim = "test.op"() : () -> (i32, i32)
"csl.set_rectangle"(%x_dim, %y_dim) : (i32, i32) -> ()
Expand Down Expand Up @@ -345,7 +347,8 @@ csl.func @builtins() {
// CHECK-NEXT: %attr_struct = "csl.const_struct"() <{"items" = {"i" = 42 : i32, "f" = 3.700000e+00 : f32}}> : () -> !csl.comptime_struct
// CHECK-NEXT: %ssa_struct = "csl.const_struct"(%arg1_1, %arg2_1, %col) <{"ssa_fields" = ["i32_", "i16_", "col"]}> : (i32, i16, !csl.color) -> !csl.comptime_struct
// CHECK-NEXT: %mixed_struct = "csl.const_struct"(%arg1_1, %arg2_1, %col) <{"ssa_fields" = ["i32_", "i16_", "col"], "items" = {"i" = 42 : i32, "f" = 3.700000e+00 : f32}}> : (i32, i16, !csl.color) -> !csl.comptime_struct
// CHECK-NEXT: %col_1 = "csl.get_color"() <{"id" = 3 : i5}> : () -> !csl.color
// CHECK-NEXT: %three = arith.constant 3 : i16
// CHECK-NEXT: %col_1 = "csl.get_color"(%three) : (i16) -> !csl.color
// CHECK-NEXT: %arr, %scalar, %tens = "test.op"() : () -> (memref<10xf32>, i32, tensor<510xf32>)
// CHECK-NEXT: %int8, %int16, %u16 = "test.op"() : () -> (si8, si16, ui16)
// CHECK-NEXT: %scalar_ptr = "csl.addressof"(%scalar) : (i32) -> !csl.ptr<i32, #csl<ptr_kind single>, #csl<ptr_const const>>
Expand Down Expand Up @@ -523,7 +526,8 @@ csl.func @builtins() {
// CHECK-NEXT: }) {"sym_name" = "program"} : () -> ()
// CHECK-NEXT: "csl.module"() <{"kind" = #csl<module_kind layout>}> ({
// CHECK-NEXT: %comp_const = "csl.param"() <{"param_name" = "comp_constant"}> : () -> i32
// CHECK-NEXT: %comp_const_with_def = "csl.param"() <{"param_name" = "comp_constant", "init_value" = 1 : i32}> : () -> i32
// CHECK-NEXT: %init = arith.constant 3.140000e+00 : f16
// CHECK-NEXT: %p2 = "csl.param"(%init) <{"param_name" = "param_2"}> : (f16) -> f16
// CHECK-NEXT: csl.layout {
// CHECK-NEXT: x_dim, %y_dim = "test.op"() : () -> (i32, i32)
// CHECK-NEXT: "csl.set_rectangle"(%x_dim, %y_dim) : (i32, i32) -> ()
Expand Down
4 changes: 2 additions & 2 deletions xdsl/backend/csl/print_csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def print_block(self, body: Block):
if init is None:
init = ""
else:
init = f" = { self.attribute_value_to_str(init)}"
init = f" = { self._get_variable_name_for(init)}"
ty = self.mlir_type_to_csl_type(res.type)
self.print(f"param {name.data} : {ty}{init};")
case csl.ConstStructOp(
Expand Down Expand Up @@ -482,7 +482,7 @@ def print_block(self, body: Block):
y = self._get_variable_name_for(y_dim)
self.print(f"@set_rectangle({x}, {y});")
case csl.GetColorOp(id=id, res=res):
id = self.attribute_value_to_str(id)
id = self._get_variable_name_for(id)
self.print(f"{self._var_use(res)} = @get_color({id});")
case csl.RpcOp(id=id):
id = self._get_variable_name_for(id)
Expand Down
29 changes: 10 additions & 19 deletions xdsl/dialects/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
TypeAttribute,
)
from xdsl.irdl import (
ConstraintVar,
IRDLOperation,
ParameterDef,
ParametrizedAttribute,
Expand Down Expand Up @@ -331,19 +332,11 @@ class ColorType(ParametrizedAttribute, TypeAttribute):
name = "csl.color"


ColorIdAttr: TypeAlias = (
IntegerAttr[Annotated[IntegerType, IntegerType(5)]]
| IntegerAttr[Annotated[IntegerType, IntegerType(6)]]
)
ColorIdAttr: TypeAlias = IntegerAttr[IntegerType]

QueueIdAttr: TypeAlias = IntegerAttr[Annotated[IntegerType, IntegerType(3)]]

ParamAttr: TypeAlias = AnyFloatAttr | AnyIntegerAttr
# NOTE: Some of these values cannot be set by default, because we don't have
# corresponding attrinutes for them.
ParamType: TypeAlias = (
Float16Type | Float32Type | IntegerType | ColorType | FunctionType | StructLike
)


@irdl_op_definition
Expand Down Expand Up @@ -410,7 +403,7 @@ def verify_(self) -> None:
class GetColorOp(IRDLOperation):
name = "csl.get_color"

id = prop_def(ColorIdAttr)
id = operand_def(IntegerType)
res = result_def(ColorType)


Expand Down Expand Up @@ -1479,21 +1472,19 @@ class ParamOp(IRDLOperation):
command line by passing params to the compiler.
"""

T = Annotated[
Float16Type | Float32Type | IntegerType | ColorType | FunctionType | StructLike,
ConstraintVar("T"),
]

name = "csl.param"

traits = frozenset([HasParent(CslModuleOp)]) # has to be at top level

param_name = prop_def(StringAttr)
init_value = opt_prop_def(ParamAttr)

res = result_def(ParamType)
init_value = opt_operand_def(T)

def verify_(self) -> None:
if self.init_value is not None and self.init_value.type != self.res.type:
raise VerifyException(
"If init_value is specified, it has to have the same type as the op result"
)
return super().verify_()
res = result_def(T)


CSL = Dialect(
Expand Down

0 comments on commit c872084

Please sign in to comment.