From c87208422f18b2edab6e25ff8cfe8b0e7ea2b7d6 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Wed, 12 Jun 2024 17:59:31 +0100 Subject: [PATCH] dialects: (csl) Small incremental fixes to dialect structure (stacked PR) (#2711) - Make `@get_color` take an operand instead of property. - Make `ParamOp` take a default value as operand instead of property. --- tests/filecheck/backend/csl/print_csl.mlir | 18 +++++++++----- tests/filecheck/dialects/csl/ops.mlir | 12 ++++++--- xdsl/backend/csl/print_csl.py | 4 +-- xdsl/dialects/csl.py | 29 ++++++++-------------- 4 files changed, 32 insertions(+), 31 deletions(-) diff --git a/tests/filecheck/backend/csl/print_csl.mlir b/tests/filecheck/backend/csl/print_csl.mlir index f77c842570..f8449a06c9 100644 --- a/tests/filecheck/backend/csl/print_csl.mlir +++ b/tests/filecheck/backend/csl/print_csl.mlir @@ -111,7 +111,8 @@ "csl.export"() <{type = (i32, i32) -> (), var_name = @args_no_return}> : () -> () - %col = "csl.get_color"() <{id = 15 : i5}> : () -> !csl.color + %cst15 = arith.constant 15 : i32 + %col = "csl.get_color"(%cst15) : (i32) -> !csl.color "csl.rpc"(%col) : (!csl.color) -> () @@ -249,7 +250,8 @@ csl.func @builtins() { %u32_value = arith.constant 120 : ui32 %f16_value = arith.constant 7.0 : f16 %f32_value = arith.constant 8.0 : f32 - %col_1 = "csl.get_color"() <{id = 3 : i5}> : () -> !csl.color + %three = arith.constant 3 : i16 + %col_1 = "csl.get_color"(%three) : (i16) -> !csl.color %f16_pointer = "csl.addressof"(%f16_value) : (f16) -> !csl.ptr, #csl> %f32_pointer = "csl.addressof"(%f32_value) : (f32) -> !csl.ptr, #csl> %i16_pointer = "csl.addressof"(%i16_value) : (si16) -> !csl.ptr, #csl> @@ -320,7 +322,8 @@ csl.func @builtins() { "csl.module"() <{kind=#csl}> ({ %p1 = "csl.param"() <{param_name = "param_1"}> : () -> i32 - %p2 = "csl.param"() <{param_name = "param_2", init_value = 1.3 : f16}> : () -> f16 + %init = arith.constant 3.14 : f16 + %p2 = "csl.param"(%init) <{param_name = "param_2"}> : (f16) -> f16 csl.layout { %x_dim = arith.constant 4 : i32 @@ -443,7 +446,8 @@ csl.func @builtins() { // CHECK-NEXT: comptime { // CHECK-NEXT: @export_symbol(args_no_return, "args_no_return"); // CHECK-NEXT: } -// CHECK-NEXT: const col : color = @get_color(15); +// CHECK-NEXT: const cst15 : i32 = 15; +// CHECK-NEXT: const col : color = @get_color(cst15); // CHECK-NEXT: comptime { // CHECK-NEXT: @rpc(@get_data_task_id(col)); // CHECK-NEXT: } @@ -541,7 +545,8 @@ csl.func @builtins() { // CHECK-NEXT: const u32_value : u32 = 120; // CHECK-NEXT: const f16_value : f16 = 7.0; // CHECK-NEXT: const f32_value : f32 = 8.0; -// CHECK-NEXT: const col1 : color = @get_color(3); +// CHECK-NEXT: const three : i16 = 3; +// CHECK-NEXT: const col1 : color = @get_color(three); // CHECK-NEXT: var f16_pointer : *f16 = &f16_value; // CHECK-NEXT: var f32_pointer : *f32 = &f32_value; // CHECK-NEXT: var i16_pointer : *i16 = &i16_value; @@ -621,7 +626,8 @@ csl.func @builtins() { // CHECK-NEXT: // ----- // CHECK-NEXT: // FILE: layout.csl // CHECK-NEXT: param param_1 : i32; -// CHECK-NEXT: param param_2 : f16 = 1.3; +// CHECK-NEXT: const init : f16 = 3.14; +// CHECK-NEXT: param param_2 : f16 = init; // CHECK-NEXT: layout { // CHECK-NEXT: const x_dim : i32 = 4; // CHECK-NEXT: const y_dim : i32 = 6; diff --git a/tests/filecheck/dialects/csl/ops.mlir b/tests/filecheck/dialects/csl/ops.mlir index a0027ac283..779df26ecb 100644 --- a/tests/filecheck/dialects/csl/ops.mlir +++ b/tests/filecheck/dialects/csl/ops.mlir @@ -61,7 +61,8 @@ csl.func @initialize() { items = {i = 42 : i32, f = 3.7 : f32 } }> : (i32, i16, !csl.color) -> !csl.comptime_struct - %col_1 = "csl.get_color"() <{id = 3 : i5}> : () -> !csl.color + %three = arith.constant 3 : i16 + %col_1 = "csl.get_color"(%three) : (i16) -> !csl.color %arr, %scalar, %tens = "test.op"() : () -> (memref<10xf32>, i32, tensor<510xf32>) @@ -297,7 +298,8 @@ csl.func @builtins() { "csl.module"() <{kind = #csl}> ({ %comp_const = "csl.param"() <{param_name = "comp_constant"}> : () -> i32 - %comp_const_with_def = "csl.param"() <{param_name = "comp_constant", init_value = 1 : i32}> : () -> i32 + %init = arith.constant 3.14 : f16 + %p2 = "csl.param"(%init) <{param_name = "param_2"}> : (f16) -> f16 csl.layout { %x_dim, %y_dim = "test.op"() : () -> (i32, i32) "csl.set_rectangle"(%x_dim, %y_dim) : (i32, i32) -> () @@ -345,7 +347,8 @@ csl.func @builtins() { // CHECK-NEXT: %attr_struct = "csl.const_struct"() <{"items" = {"i" = 42 : i32, "f" = 3.700000e+00 : f32}}> : () -> !csl.comptime_struct // CHECK-NEXT: %ssa_struct = "csl.const_struct"(%arg1_1, %arg2_1, %col) <{"ssa_fields" = ["i32_", "i16_", "col"]}> : (i32, i16, !csl.color) -> !csl.comptime_struct // CHECK-NEXT: %mixed_struct = "csl.const_struct"(%arg1_1, %arg2_1, %col) <{"ssa_fields" = ["i32_", "i16_", "col"], "items" = {"i" = 42 : i32, "f" = 3.700000e+00 : f32}}> : (i32, i16, !csl.color) -> !csl.comptime_struct -// CHECK-NEXT: %col_1 = "csl.get_color"() <{"id" = 3 : i5}> : () -> !csl.color +// CHECK-NEXT: %three = arith.constant 3 : i16 +// CHECK-NEXT: %col_1 = "csl.get_color"(%three) : (i16) -> !csl.color // CHECK-NEXT: %arr, %scalar, %tens = "test.op"() : () -> (memref<10xf32>, i32, tensor<510xf32>) // CHECK-NEXT: %int8, %int16, %u16 = "test.op"() : () -> (si8, si16, ui16) // CHECK-NEXT: %scalar_ptr = "csl.addressof"(%scalar) : (i32) -> !csl.ptr, #csl> @@ -523,7 +526,8 @@ csl.func @builtins() { // CHECK-NEXT: }) {"sym_name" = "program"} : () -> () // CHECK-NEXT: "csl.module"() <{"kind" = #csl}> ({ // CHECK-NEXT: %comp_const = "csl.param"() <{"param_name" = "comp_constant"}> : () -> i32 -// CHECK-NEXT: %comp_const_with_def = "csl.param"() <{"param_name" = "comp_constant", "init_value" = 1 : i32}> : () -> i32 +// CHECK-NEXT: %init = arith.constant 3.140000e+00 : f16 +// CHECK-NEXT: %p2 = "csl.param"(%init) <{"param_name" = "param_2"}> : (f16) -> f16 // CHECK-NEXT: csl.layout { // CHECK-NEXT: x_dim, %y_dim = "test.op"() : () -> (i32, i32) // CHECK-NEXT: "csl.set_rectangle"(%x_dim, %y_dim) : (i32, i32) -> () diff --git a/xdsl/backend/csl/print_csl.py b/xdsl/backend/csl/print_csl.py index 498bb31395..68c3b6994c 100644 --- a/xdsl/backend/csl/print_csl.py +++ b/xdsl/backend/csl/print_csl.py @@ -451,7 +451,7 @@ def print_block(self, body: Block): if init is None: init = "" else: - init = f" = { self.attribute_value_to_str(init)}" + init = f" = { self._get_variable_name_for(init)}" ty = self.mlir_type_to_csl_type(res.type) self.print(f"param {name.data} : {ty}{init};") case csl.ConstStructOp( @@ -482,7 +482,7 @@ def print_block(self, body: Block): y = self._get_variable_name_for(y_dim) self.print(f"@set_rectangle({x}, {y});") case csl.GetColorOp(id=id, res=res): - id = self.attribute_value_to_str(id) + id = self._get_variable_name_for(id) self.print(f"{self._var_use(res)} = @get_color({id});") case csl.RpcOp(id=id): id = self._get_variable_name_for(id) diff --git a/xdsl/dialects/csl.py b/xdsl/dialects/csl.py index 3cd3c35d1f..f9ea2df4bc 100644 --- a/xdsl/dialects/csl.py +++ b/xdsl/dialects/csl.py @@ -46,6 +46,7 @@ TypeAttribute, ) from xdsl.irdl import ( + ConstraintVar, IRDLOperation, ParameterDef, ParametrizedAttribute, @@ -331,19 +332,11 @@ class ColorType(ParametrizedAttribute, TypeAttribute): name = "csl.color" -ColorIdAttr: TypeAlias = ( - IntegerAttr[Annotated[IntegerType, IntegerType(5)]] - | IntegerAttr[Annotated[IntegerType, IntegerType(6)]] -) +ColorIdAttr: TypeAlias = IntegerAttr[IntegerType] QueueIdAttr: TypeAlias = IntegerAttr[Annotated[IntegerType, IntegerType(3)]] ParamAttr: TypeAlias = AnyFloatAttr | AnyIntegerAttr -# NOTE: Some of these values cannot be set by default, because we don't have -# corresponding attrinutes for them. -ParamType: TypeAlias = ( - Float16Type | Float32Type | IntegerType | ColorType | FunctionType | StructLike -) @irdl_op_definition @@ -410,7 +403,7 @@ def verify_(self) -> None: class GetColorOp(IRDLOperation): name = "csl.get_color" - id = prop_def(ColorIdAttr) + id = operand_def(IntegerType) res = result_def(ColorType) @@ -1479,21 +1472,19 @@ class ParamOp(IRDLOperation): command line by passing params to the compiler. """ + T = Annotated[ + Float16Type | Float32Type | IntegerType | ColorType | FunctionType | StructLike, + ConstraintVar("T"), + ] + name = "csl.param" traits = frozenset([HasParent(CslModuleOp)]) # has to be at top level param_name = prop_def(StringAttr) - init_value = opt_prop_def(ParamAttr) - - res = result_def(ParamType) + init_value = opt_operand_def(T) - def verify_(self) -> None: - if self.init_value is not None and self.init_value.type != self.res.type: - raise VerifyException( - "If init_value is specified, it has to have the same type as the op result" - ) - return super().verify_() + res = result_def(T) CSL = Dialect(