diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index fe3c33384e..88b0a95617 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -74,6 +74,7 @@ SymbolTable, ) from xdsl.utils.exceptions import DiagnosticException, VerifyException +from xdsl.utils.hints import isa from xdsl.utils.isattr import isattr if TYPE_CHECKING: @@ -1715,11 +1716,11 @@ def get_element_type(self) -> _UnrankedMemrefTypeElems: VectorType[AttributeCovT] | TensorType[AttributeCovT] | MemRefType[AttributeCovT] ) +AnyDenseElement: TypeAlias = IntegerType | IndexType | AnyFloat + @irdl_attr_definition -class DenseIntOrFPElementsAttr( - ParametrizedAttribute, ContainerType[IntegerType | IndexType | AnyFloat] -): +class DenseIntOrFPElementsAttr(TypedAttribute, ContainerType[AnyDenseElement]): name = "dense" type: ParameterDef[ RankedStructure[IntegerType] @@ -1871,6 +1872,59 @@ def tensor_from_list( t = TensorType(data_type, shape) return DenseIntOrFPElementsAttr.from_list(t, data) + @staticmethod + def parse_with_type(parser: AttrParser, type: Attribute) -> TypedAttribute: + assert isa(type, RankedStructure[AnyDenseElement]) + return parser.parse_dense_int_or_fp_elements_attr(type) + + @staticmethod + def _print_one_elem(val: Attribute, printer: Printer): + if isinstance(val, IntegerAttr): + printer.print_string(f"{val.value.data}") + elif isinstance(val, FloatAttr): + printer.print_float(cast(AnyFloatAttr, val)) + else: + raise Exception( + "unexpected attribute type " + "in DenseIntOrFPElementsAttr: " + f"{type(val)}" + ) + + @staticmethod + def _print_dense_list( + array: Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr], + shape: Sequence[int], + printer: Printer, + ): + printer.print_string("[") + if len(shape) > 1: + k = len(array) // shape[0] + printer.print_list( + (array[i : i + k] for i in range(0, len(array), k)), + lambda subarray: DenseIntOrFPElementsAttr._print_dense_list( + subarray, shape[1:], printer + ), + ) + else: + printer.print_list( + array, + lambda val: DenseIntOrFPElementsAttr._print_one_elem(val, printer), + ) + printer.print_string("]") + + def print_without_type(self, printer: Printer): + printer.print_string("dense<") + data = self.data.data + shape = self.get_shape() if self.shape_is_complete else (len(data),) + assert shape is not None, "If shape is complete, then it cannot be None" + if len(data) == 0: + pass + elif data.count(data[0]) == len(data): + DenseIntOrFPElementsAttr._print_one_elem(data[0], printer) + else: + DenseIntOrFPElementsAttr._print_dense_list(data, shape, printer) + printer.print_string(">") + Builtin = Dialect( "builtin", diff --git a/xdsl/parser/attribute_parser.py b/xdsl/parser/attribute_parser.py index 45cf56d2d6..031ba150b8 100644 --- a/xdsl/parser/attribute_parser.py +++ b/xdsl/parser/attribute_parser.py @@ -14,6 +14,7 @@ AffineMapAttr, AffineSetAttr, AnyArrayAttr, + AnyDenseElement, AnyFloat, AnyFloatAttr, AnyFloatConstr, @@ -698,11 +699,7 @@ def _parse_optional_builtin_parametrized_attr(self) -> Attribute | None: def _parse_builtin_dense_attr_hex( self, hex_string: str, - type: ( - RankedStructure[IntegerType] - | RankedStructure[IndexType] - | RankedStructure[AnyFloat] - ), + type: RankedStructure[AnyDenseElement], ) -> tuple[list[int] | list[float], list[int]]: """ Parse a hex string literal e.g. dense<"0x82F5AB00">, and returns its flattened data @@ -795,7 +792,9 @@ def _parse_dense_literal_type( self.raise_error("Dense literal attribute should have a static shape.") return type - def _parse_builtin_dense_attr(self, _name: Span) -> DenseIntOrFPElementsAttr: + def parse_dense_int_or_fp_elements_attr( + self, type: RankedStructure[AnyDenseElement] | None + ) -> DenseIntOrFPElementsAttr: dense_contents: ( tuple[list[AttrParser._TensorLiteralElement], list[int]] | str | None ) @@ -821,8 +820,9 @@ def _parse_builtin_dense_attr(self, _name: Span) -> DenseIntOrFPElementsAttr: self.parse_punctuation(">", " in dense attribute") # Parse the dense type and check for correctness - self.parse_punctuation(":", " in dense attribute") - type = self._parse_dense_literal_type() + if type is None: + self.parse_punctuation(":", " in dense attribute") + type = self._parse_dense_literal_type() type_shape = list(type.get_shape()) type_num_values = math.prod(type_shape) @@ -866,6 +866,9 @@ def _parse_builtin_dense_attr(self, _name: Span) -> DenseIntOrFPElementsAttr: return DenseIntOrFPElementsAttr.from_list(type, data_values) + def _parse_builtin_dense_attr(self, _name: Span) -> DenseIntOrFPElementsAttr: + return self.parse_dense_int_or_fp_elements_attr(None) + def _parse_builtin_opaque_attr(self, _name: Span): str_lit_list = self.parse_comma_separated_list( self.Delimiter.ANGLE, self.parse_str_literal diff --git a/xdsl/printer.py b/xdsl/printer.py index 826ed3d05e..2468a96800 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -14,7 +14,6 @@ AffineMapAttr, AffineSetAttr, AnyFloatAttr, - AnyIntegerAttr, AnyUnrankedMemrefType, AnyUnrankedTensorType, AnyVectorType, @@ -23,7 +22,6 @@ BytesAttr, ComplexType, DenseArrayBase, - DenseIntOrFPElementsAttr, DenseResourceAttr, DictionaryAttr, Float16Type, @@ -31,7 +29,6 @@ Float64Type, Float80Type, Float128Type, - FloatAttr, FloatData, FunctionType, IndexType, @@ -612,51 +609,6 @@ def print_attribute(self, attribute: Attribute) -> None: self.print_string(")") return - if isinstance(attribute, DenseIntOrFPElementsAttr): - - def print_one_elem(val: Attribute): - if isinstance(val, IntegerAttr): - self.print_string(f"{val.value.data}") - elif isinstance(val, FloatAttr): - self.print_float(cast(AnyFloatAttr, val)) - else: - raise Exception( - "unexpected attribute type " - "in DenseIntOrFPElementsAttr: " - f"{type(val)}" - ) - - def print_dense_list( - array: Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr], - shape: Sequence[int], - ): - self.print_string("[") - if len(shape) > 1: - k = len(array) // shape[0] - self.print_list( - (array[i : i + k] for i in range(0, len(array), k)), - lambda subarray: print_dense_list(subarray, shape[1:]), - ) - else: - self.print_list(array, print_one_elem) - self.print_string("]") - - self.print_string("dense<") - data = attribute.data.data - shape = ( - attribute.get_shape() if attribute.shape_is_complete else (len(data),) - ) - assert shape is not None, "If shape is complete, then it cannot be None" - if len(data) == 0: - pass - elif data.count(data[0]) == len(data): - print_one_elem(data[0]) - else: - print_dense_list(data, shape) - self.print_string("> : ") - self.print_attribute(attribute.type) - return - if isinstance(attribute, DenseResourceAttr): handle = attribute.resource_handle.data self.print_string(f"dense_resource<{handle}> : ")