Skip to content

Commit

Permalink
[xls][mlir] Add missing MLIR dialect ops
Browse files Browse the repository at this point in the history
Adds the following XLS-supported ops to the MLIR dialect:

- nand, nor
- and_reduce, or_reduce, xor_reduce
- umulp, smulp
- gate

Note that for the partial product operations, operands of different bit
widths and operands of different bit width than the result are
explicitly allowed.
  • Loading branch information
schilkp committed Jan 15, 2025
1 parent 5f84820 commit f728967
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 8 deletions.
158 changes: 156 additions & 2 deletions xls/contrib/mlir/IR/xls_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,19 @@ def Xls_AndOp : Xls_NaryOp<"and", [Commutative, Pure, SameOperandsAndResultType]
);
}

def Xls_NandOp : Xls_NaryOp<"nand", [Commutative, Pure, SameOperandsAndResultType]> {
let summary = "Bitwise NAND operation";
let description = [{
result = ~(operand_0 & ... & operand_{N-1})
}];
let arguments = (ins
Variadic<Xls_Bits>:$inputs
);
let results = (outs
Xls_Bits:$result
);
}

def Xls_OrOp : Xls_NaryOp<"or", [Commutative, Pure, SameOperandsAndResultType]> {
let summary = "Bitwise OR operation";
let description = [{
Expand All @@ -252,6 +265,20 @@ def Xls_OrOp : Xls_NaryOp<"or", [Commutative, Pure, SameOperandsAndResultType]>
);
}


def Xls_NorOp : Xls_NaryOp<"nor", [Commutative, Pure, SameOperandsAndResultType]> {
let summary = "Bitwise NOR operation";
let description = [{
result = ~(operand_0 | ... | operand_{N-1})
}];
let arguments = (ins
Variadic<Xls_Bits>:$inputs
);
let results = (outs
Xls_Bits:$result
);
}

def Xls_XorOp : Xls_NaryOp<"xor", [Commutative, Pure, SameOperandsAndResultType]> {
let summary = "Bitwise XOR operation";
let description = [{
Expand All @@ -265,6 +292,48 @@ def Xls_XorOp : Xls_NaryOp<"xor", [Commutative, Pure, SameOperandsAndResultType]
);
}

//===----------------------------------------------------------------------===//
// Bitwise reduction operations
//===----------------------------------------------------------------------===//

class Xls_ReductionOp<string name, list<Trait> traits = []> :
Xls_Op<name, !listconcat(traits, [
ShapesAreConsistent<["input", "result"]>,
Elementwise,
Scalarizable,
])> {
let arguments = (ins
Xls_Bits:$input
);
let results = (outs
Xls_Bits:$result
);
let assemblyFormat = [{
$input attr-dict `:` functional-type($input, $result)
}];
}

def Xls_AndReductionOp : Xls_ReductionOp<"and_reduce", [Commutative, Pure]> {
let summary = "Unary AND reduction operation";
let description = [{
result = operand[0] & operand[1] & ...
}];
}

def Xls_OrReductionOp : Xls_ReductionOp<"or_reduce", [Commutative, Pure]> {
let summary = "Unary OR reduction operation";
let description = [{
result = operand[0] | operand[1] | ...
}];
}

def Xls_XorReductionOp : Xls_ReductionOp<"xor_reduce", [Commutative, Pure]> {
let summary = "Unary XOR reduction operation";
let description = [{
result = operand[0] ^ operand[1] ^ ...
}];
}

//===----------------------------------------------------------------------===//
// Arithmetic unary operations
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -305,22 +374,65 @@ class Xls_ArithBinaryOp<string name, string desc>
);
}

class Xls_PartialProdOp<string name, string desc> : Xls_Op<name, [
Pure,
ShapesAreConsistent<["lhs", "rhs", "result_lhs", "result_rhs"]>,
Elementwise,
Scalarizable,
AllTypesMatch<["result_lhs", "result_rhs"]>
]> {
let summary = !strconcat(desc, " operation");
let description = summary;
let arguments = (ins
Xls_Bits:$lhs,
Xls_Bits:$rhs
);
let results = (outs
Xls_Bits:$result_lhs,
Xls_Bits:$result_rhs
);
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` functional-type(operands, results)
}];
let builders = [
OpBuilder<(ins "::mlir::Type":$result_type, "::mlir::Value":$lhs, "::mlir::Value":$rhs), [{
build($_builder, $_state, result_type, result_type, lhs, rhs);
}]>,
];
}

def Xls_AddOp : Xls_ArithBinaryOp<"add", "addition"> {
let hasFolder = 1;
}
def Xls_SmulOp : Xls_ArithBinaryOp<"smul", "signed multiplication">;
def Xls_SmulpOp : Xls_PartialProdOp<"smulp", [{
Signed partial product multiply.

Returns two results such that:
result_lhs + result_rhs = lhs * rhs

The two results have the same type, but are not fully constrained: Any
two values that fulfill the above are valid.
}]>;
def Xls_UmulOp : Xls_ArithBinaryOp<"umul", "unsigned multiplication"> {
let hasFolder = 1;
let hasCanonicalizeMethod = 1;
}
def Xls_UmulpOp : Xls_PartialProdOp<"umulp", [{
Unsigned partial product multiply.

Returns two results such that:
result_lhs + result_rhs = lhs * rhs

The two results have the same type, but are not fully constrained: Any
two values that fulfill the above are valid.
}]>;
def Xls_SdivOp : Xls_ArithBinaryOp<"sdiv", "signed division">;
def Xls_SmodOp : Xls_ArithBinaryOp<"smod", "signed modulo">;
def Xls_SubOp : Xls_ArithBinaryOp<"sub", "subtraction">;
def Xls_UdivOp : Xls_ArithBinaryOp<"udiv", "unsigned division">;
def Xls_UmodOp : Xls_ArithBinaryOp<"umod", "unsigned modulo">;

// TODO(jmolloy): Smulp/Umulp - partial product multiplication.

//===----------------------------------------------------------------------===//
// Comparison operations
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1881,6 +1993,48 @@ def Xls_TraceOp : Xls_Op<"trace", [TensorArrayTypeFungible,
];
}

//===----------------------------------------------------------------------===//
// Other side-effecting operations
//===----------------------------------------------------------------------===//

def Xls_GateOp :
Xls_Op<"gate", [
ShapesAreConsistent<["data", "result"]>,
Elementwise,
Scalarizable,
]> {
let summary = "Gates an arbitrarily-typed value based on a condition";
let description = [{
The result of the operation is the data operand if the condition is true,
otherwise the result is a zero value of the type of the data operand
(i.e., the value is gated off).

A helpful mnemonic is to think of this as analogous to an AND gate: if the
condition is true, the value passes through, otherwise it's zeroed.

This operation intended for use in operand gating for power reduction.

The operation is considered side-effecting to prevent removal of the operation
when the gated result (condition is false) is not observable. The 'side-effect'
of this operation is the effect it can have on power consumption.
}];
let arguments = (ins
I1:$condition,
Xls_BitsOrTuple:$data
);
let results = (outs
Xls_BitsOrTuple:$result
);
let assemblyFormat = [{
$condition `,` $data attr-dict `:` functional-type(operands, results)
}];
let builders = [
OpBuilder<(ins "::mlir::Value":$condition, "::mlir::Value":$data), [{
build($_builder, $_state, data.getType(), condition, data);
}]>
];
}

//===----------------------------------------------------------------------===//
// Lowering internal operations
//===----------------------------------------------------------------------===//
Expand Down
40 changes: 40 additions & 0 deletions xls/contrib/mlir/testdata/ops_translate.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,41 @@ func.func @and(%arg0: i8, %arg1: i8) -> i8 {
return %0 : i8
}

func.func @nand(%arg0: i8, %arg1: i8) -> i8 {
%0 = xls.nand %arg0, %arg1 : i8
return %0 : i8
}

func.func @or(%arg0: i8, %arg1: i8) -> i8 {
%0 = xls.or %arg0, %arg1 : i8
return %0 : i8
}

func.func @nor(%arg0: i8, %arg1: i8) -> i8 {
%0 = xls.nor %arg0, %arg1 : i8
return %0 : i8
}

func.func @xor(%arg0: i8, %arg1: i8) -> i8 {
%0 = xls.xor %arg0, %arg1 : i8
return %0 : i8
}

func.func @and_reduce(%arg0: i8) -> i1 {
%0 = xls.and_reduce %arg0 : (i8) ->i1
return %0 : i1
}

func.func @or_reduce(%arg0: i8) -> i1 {
%0 = xls.or_reduce %arg0 : (i8) ->i1
return %0 : i1
}

func.func @xor_reduce(%arg0: i8) -> i1 {
%0 = xls.xor_reduce %arg0 : (i8) ->i1
return %0 : i1
}

func.func @neg(%arg0: i8) -> i8 {
%0 = xls.neg %arg0 : i8
return %0 : i8
Expand All @@ -42,11 +67,21 @@ func.func @smul(%arg0: i8, %arg1: i8) -> i8 {
return %0 : i8
}

func.func @smulp(%arg0: i8, %arg1: i7) -> i9 {
%0, %1 = xls.smulp %arg0, %arg1 : (i8, i7) -> (i9, i9)
return %0 : i9
}

func.func @umul(%arg0: i8, %arg1: i8) -> i8 {
%0 = xls.umul %arg0, %arg1 : i8
return %0 : i8
}

func.func @umulp(%arg0: i8, %arg1: i7) -> i9 {
%0, %1 = xls.umulp %arg0, %arg1 : (i8, i7) -> (i9, i9)
return %0 : i9
}

func.func @sdiv(%arg0: i8, %arg1: i8) -> i8 {
%0 = xls.sdiv %arg0, %arg1 : i8
return %0 : i8
Expand Down Expand Up @@ -263,6 +298,11 @@ func.func @bitcast(%arg0: f32) -> i32 {
return %0 : i32
}

func.func @gate(%arg0: i32, %condition: i1) -> i32 {
%0 = xls.gate %condition, %arg0 : (i1, i32) -> i32
return %0 : i32
}

// TODO
// func.func @constant_tensor() -> tensor<3xi8> {
// %0 = "xls.constant_tensor"() { value = dense<[0, 1, 2]> : tensor<3xi8> } : () -> tensor<3xi8>
Expand Down
44 changes: 38 additions & 6 deletions xls/contrib/mlir/tools/xls_translate/xls_translate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,16 @@ XLS_UNARY_OP(NotOp, Not);
return fb.BUILDER(args, state.getLoc(op)); \
}
XLS_VARIADIC_BINARY_OP(AndOp, And);
XLS_VARIADIC_BINARY_OP(NandOp, Nand);
XLS_VARIADIC_BINARY_OP(OrOp, Or);
XLS_VARIADIC_BINARY_OP(NorOp, Nor);
XLS_VARIADIC_BINARY_OP(XorOp, Xor);

// Bitwise reduction operations
XLS_UNARY_OP(AndReductionOp, AndReduce);
XLS_UNARY_OP(OrReductionOp, OrReduce);
XLS_UNARY_OP(XorReductionOp, XorReduce);

// Arithmetic unary operations
XLS_UNARY_OP(NegOp, Negate);

Expand All @@ -324,6 +331,21 @@ XLS_BINARY_OP(UdivOp, UDiv);
XLS_BINARY_OP(UmodOp, UMod);
XLS_BINARY_OP(UmulOp, UMul);

// Partial Products
#define XLS_PARTIAL_PROD_OP(TYPE, BUILDER) \
BValue convertOp(TYPE op, TranslationState& state, BuilderBase& fb) { \
auto element_type = \
cast<IntegerType>(mlir::getElementTypeOrSelf(op.getResultLhs())); \
\
BValue out = fb.BUILDER(state.getXlsValue(op.getLhs()), \
state.getXlsValue(op.getRhs()), \
element_type.getWidth(), state.getLoc(op)); \
state.setMultiResultValue(op, out, fb); \
return out; \
}
XLS_PARTIAL_PROD_OP(SmulpOp, SMulp);
XLS_PARTIAL_PROD_OP(UmulpOp, UMulp);

// Comparison operations
XLS_BINARY_OP(EqOp, Eq);
XLS_BINARY_OP(NeOp, Ne);
Expand Down Expand Up @@ -861,6 +883,11 @@ BValue convertOp(NonblockingReceiveOp op, TranslationState& state,
return out;
}

BValue convertOp(GateOp op, TranslationState& state, BuilderBase& fb) {
return fb.Gate(state.getXlsValue(op.getCondition()),
state.getXlsValue(op.getData()), state.getLoc(op));
}

FailureOr<PackageInfo> importDslxInstantiation(
ImportDslxFilePackageOp file_import_op, llvm::StringRef dslx_snippet,
Package& package) {
Expand Down Expand Up @@ -1027,11 +1054,14 @@ FailureOr<BValue> convertFunction(TranslationState& translation_state,
// Unary bitwise ops.
IdentityOp, NotOp,
// Variadic bitwise operations
AndOp, OrOp, XorOp,
AndOp, NandOp, OrOp, NorOp, XorOp,
// Bitwise reduction operations
AndReductionOp, OrReductionOp, XorReductionOp,
// Arithmetic unary operations
NegOp,
// Binary ops.
AddOp, SdivOp, SmodOp, SmulOp, SubOp, UdivOp, UmodOp, UmulOp,
AddOp, SdivOp, SmodOp, SmulOp, SmulpOp, SubOp, UdivOp, UmodOp,
UmulOp, UmulpOp,
// Comparison operations
EqOp, NeOp, SgeOp, SgtOp, SleOp, SltOp, UgeOp, UgtOp, UleOp, UltOp,
// Shift operations
Expand All @@ -1057,8 +1087,9 @@ FailureOr<BValue> convertFunction(TranslationState& translation_state,
// CSP ops
AfterAllOp, SendOp, BlockingReceiveOp, NonblockingReceiveOp,
// Debugging ops
TraceOp>(
[&](auto t) { return convertOp(t, translation_state, fb); })
TraceOp,
// Misc. side-effecting ops
GateOp>([&](auto t) { return convertOp(t, translation_state, fb); })
.Case<mlir::func::ReturnOp, YieldOp>([&](auto ret) {
if (ret.getNumOperands() == 1) {
return out = value_map[ret.getOperand(0)];
Expand Down Expand Up @@ -1091,8 +1122,9 @@ FailureOr<BValue> convertFunction(TranslationState& translation_state,
if (op == xls_region) {
return WalkResult::skip();
}
// Receives have multiple results but are explicitly supported.
if (!isa<BlockingReceiveOp, NonblockingReceiveOp>(op)) {
// Receives and partial products have multiple results but are explicitly
// supported.
if (!isa<BlockingReceiveOp, NonblockingReceiveOp, UmulpOp, SmulpOp>(op)) {
assert(op->getNumResults() <= 1 && "Multiple results not supported");
}

Expand Down

0 comments on commit f728967

Please sign in to comment.