Add Ch.5 of the toy tutorial.
This chapter adds a partial lowering of Toy operations, all but PrintOp, to a combination of the Affine and Std dialects. It focuses on introducing the conversion framework, the benefits of partial lowering, and how easily dialects can coexist in the IR.

PiperOrigin-RevId: 275150649
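To make the commit message concrete: with the dialect conversion framework, a partial lowering declares which dialects and ops are legal and rewrites everything else. Below is a minimal sketch of such a pass, assuming the MLIR API of this era; the pass name `ToyToAffineLoweringPass` follows the chapter's conventions but is an assumption, not the verbatim contents of mlir/LowerToAffineLoops.cpp.

#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "toy/Dialect.h"

namespace {
/// Sketch of a partial lowering pass: every Toy op except `toy.print` must be
/// converted; Affine and Standard ops are legal lowering targets.
struct ToyToAffineLoweringPass
    : public mlir::FunctionPass<ToyToAffineLoweringPass> {
  void runOnFunction() final {
    mlir::ConversionTarget target(getContext());
    target.addLegalDialect<mlir::AffineOpsDialect, mlir::StandardOpsDialect>();
    target.addIllegalDialect<toy::ToyDialect>();
    // Keeping PrintOp legal is what makes this lowering *partial*: it can
    // coexist with the lowered ops and be handled in a later chapter.
    target.addLegalOp<toy::PrintOp>();

    mlir::OwningRewritePatternList patterns;
    // ... insert the per-op lowering patterns into `patterns` here ...
    if (mlir::failed(
            mlir::applyPartialConversion(getFunction(), target, patterns)))
      signalPassFailure();
  }
};
} // end anonymous namespace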
River707 authored and tensorflower-gardener committed Oct 17, 2019
1 parent 48f7ec8 commit 288ac1b
Showing 36 changed files with 1,956 additions and 2,681 deletions.
9 changes: 7 additions & 2 deletions examples/toy/Ch3/mlir/Dialect.cpp
@@ -55,22 +55,27 @@ static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
   ConstantOp::build(builder, state, dataType, dataAttribute);
 }
 
-/// Verifier for constant operation.
+/// Verifier for the constant operation. This corresponds to the `::verify(...)`
+/// in the op definition.
 static mlir::LogicalResult verify(ConstantOp op) {
   // If the return type of the constant is not an unranked tensor, the shape
   // must match the shape of the attribute holding the data.
   auto resultType = op.getResult()->getType().dyn_cast<RankedTensorType>();
   if (!resultType)
     return success();
 
+  // Check that the rank of the attribute type matches the rank of the constant
+  // result type.
   auto attrType = op.value().getType().cast<mlir::TensorType>();
   if (attrType.getRank() != resultType.getRank()) {
     return op.emitOpError(
                "return type must match the one of the attached value "
                "attribute: ")
            << attrType.getRank() << " != " << resultType.getRank();
   }
-  for (int dim = 0; dim < attrType.getRank(); ++dim) {
+
+  // Check that each of the dimensions match between the two types.
+  for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
     if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
       return op.emitOpError(
                  "return type shape mismatches its attribute at dimension ")
6 changes: 5 additions & 1 deletion examples/toy/Ch4/mlir/Dialect.cpp
@@ -118,14 +118,18 @@ static mlir::LogicalResult verify(ConstantOp op) {
   if (!resultType)
     return success();
 
+  // Check that the rank of the attribute type matches the rank of the constant
+  // result type.
   auto attrType = op.value().getType().cast<mlir::TensorType>();
   if (attrType.getRank() != resultType.getRank()) {
     return op.emitOpError(
                "return type must match the one of the attached value "
                "attribute: ")
            << attrType.getRank() << " != " << resultType.getRank();
   }
-  for (int dim = 0; dim < attrType.getRank(); ++dim) {
+
+  // Check that each of the dimensions match between the two types.
+  for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
     if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
       return op.emitOpError(
                  "return type shape mismatches its attribute at dimension ")
9 changes: 5 additions & 4 deletions examples/toy/Ch4/mlir/MLIRGen.cpp
@@ -79,7 +79,7 @@ class MLIRGenImpl {
     // the structural properties of the IR and invoke any specific verifiers we
     // have on the Toy operations.
     if (failed(mlir::verify(theModule))) {
-      theModule.emitError("Module verification error");
+      theModule.emitError("module verification error");
       return nullptr;
     }
 
@@ -229,7 +229,7 @@ class MLIRGenImpl {
     if (auto *variable = symbolTable.lookup(expr.getName()))
       return variable;
 
-    emitError(loc(expr.loc()), "Error: unknown variable '")
+    emitError(loc(expr.loc()), "error: unknown variable '")
         << expr.getName() << "'";
     return nullptr;
   }
@@ -289,7 +289,8 @@ class MLIRGenImpl {
     auto dataAttribute =
         mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
 
-    // Build the MLIR op `toy.constant`.
+    // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
+    // method.
     return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
   }
 
@@ -389,7 +390,7 @@ class MLIRGenImpl {
     auto init = vardecl.getInitVal();
     if (!init) {
       emitError(loc(vardecl.loc()),
-                "Missing initializer in variable declaration");
+                "missing initializer in variable declaration");
       return nullptr;
     }
 
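The `ConstantOp::build` named in the new comment above is the ODS-generated builder that the chapter's `buildConstantOp` helper (its tail is visible at the top of the Ch3 hunk) wraps. For reference, a lightly abridged sketch of that helper:

// Builds a toy.constant holding a single f64 value: materialize a rank-0
// tensor attribute, then delegate to the generated ConstantOp::build to
// populate the OperationState.
static void buildConstantOp(mlir::Builder *builder,
                            mlir::OperationState &state, double value) {
  auto dataType = mlir::RankedTensorType::get({}, builder->getF64Type());
  auto dataAttribute = mlir::DenseElementsAttr::get(dataType, value);
  ConstantOp::build(builder, state, dataType, dataAttribute);
}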
40 changes: 21 additions & 19 deletions examples/toy/Ch5/CMakeLists.txt
@@ -1,40 +1,42 @@
+add_subdirectory(include)
+
 set(LLVM_LINK_COMPONENTS
   Support
   )
 
 set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td)
 mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include")
 add_public_tablegen_target(ToyCh5CombineIncGen)
 
 add_toy_chapter(toyc-ch5
   toyc.cpp
   parser/AST.cpp
-  mlir/EarlyLowering.cpp
-  mlir/LateLowering.cpp
   mlir/MLIRGen.cpp
+  mlir/Dialect.cpp
+  mlir/DeadFunctionEliminationPass.cpp
+  mlir/LowerToAffineLoops.cpp
   mlir/ShapeInferencePass.cpp
-  mlir/ToyDialect.cpp
   mlir/ToyCombine.cpp
   )
 
+add_dependencies(toyc-ch5 ToyCh5ShapeInferenceInterfaceIncGen)
+add_dependencies(toyc-ch5 ToyCh5OpsIncGen)
 add_dependencies(toyc-ch5 ToyCh5CombineIncGen)
+add_dependencies(toyc-ch5 MLIRCallOpInterfacesIncGen)
 include_directories(include/)
-include_directories(../../Linalg/Linalg1/include/)
-include_directories(../../Linalg/Linalg2/include/)
-include_directories(../../Linalg/Linalg3/include/)
 include_directories(${CMAKE_CURRENT_BINARY_DIR})
+include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
 target_link_libraries(toyc-ch5
   PRIVATE
-    Linalg3DialectConstruction
-    Linalg3
-    Linalg2
-    Linalg1
     MLIRAffineOps
     MLIRAnalysis
-    MLIREDSC
-    MLIRExecutionEngine
     MLIRIR
-    MLIRLLVMIR
     MLIRParser
     MLIRPass
-    MLIRTargetLLVMIR
-    MLIRTransforms
-    MLIRSupport
-  )
-
-whole_archive_link(toyc-ch5
-  MLIRAffineOps
-  MLIRStandardOps
-  )
+    MLIRStandardOps
+    MLIRTransforms)
1 change: 1 addition & 0 deletions examples/toy/Ch5/include/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(toy)
15 changes: 6 additions & 9 deletions examples/toy/Ch5/include/toy/AST.h
@@ -33,10 +33,9 @@
 
 namespace toy {
 
-/// A variable
+/// A variable type with shape information.
 struct VarType {
-  enum { TY_FLOAT, TY_INT } elt_ty;
-  std::vector<int> shape;
+  std::vector<int64_t> shape;
 };
 
 /// Base class for all expression nodes.
@@ -50,9 +49,7 @@ class ExprAST {
     Expr_Var,
     Expr_BinOp,
     Expr_Call,
-    Expr_Print, // builtin
-    Expr_If,
-    Expr_For,
+    Expr_Print,
   };
 
   ExprAST(ExprASTKind kind, Location location)
@@ -85,7 +82,7 @@ class NumberExprAST : public ExprAST {
   static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
 };
 
-///
+/// Expression class for a literal value.
 class LiteralExprAST : public ExprAST {
   std::vector<std::unique_ptr<ExprAST>> values;
   std::vector<int64_t> dims;
@@ -116,7 +113,7 @@ class VariableExprAST : public ExprAST {
   static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
 };
 
-///
+/// Expression class for defining a variable.
 class VarDeclExprAST : public ExprAST {
   std::string name;
   VarType type;
@@ -136,7 +133,7 @@ class VarDeclExprAST : public ExprAST {
   static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
 };
 
-///
+/// Expression class for a return operator.
 class ReturnExprAST : public ExprAST {
   llvm::Optional<std::unique_ptr<ExprAST>> expr;
 
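The `classof` methods recurring in these hunks are the hooks for LLVM-style RTTI; a small usage sketch, assuming `toy/AST.h` is on the include path:

#include "toy/AST.h"
#include "llvm/Support/Casting.h"

// `classof` lets llvm::isa/llvm::dyn_cast dispatch on the kind tag stored in
// the ExprAST base class, without vtables or dynamic_cast.
static void describe(toy::ExprAST *expr) {
  if (llvm::isa<toy::NumberExprAST>(expr)) {
    // NumberExprAST::classof matched Expr_Num.
  } else if (auto *lit = llvm::dyn_cast<toy::LiteralExprAST>(expr)) {
    (void)lit; // Non-null exactly when LiteralExprAST::classof returns true.
  }
}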
9 changes: 9 additions & 0 deletions examples/toy/Ch5/include/toy/CMakeLists.txt
@@ -0,0 +1,9 @@
+set(LLVM_TARGET_DEFINITIONS Ops.td)
+mlir_tablegen(Ops.h.inc -gen-op-decls "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
+mlir_tablegen(Ops.cpp.inc -gen-op-defs "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
+add_public_tablegen_target(ToyCh5OpsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
+mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(ToyCh5ShapeInferenceInterfaceIncGen)
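These `mlir_tablegen` rules only generate the `.inc` files; consuming them follows the standard ODS pattern sketched below (the exact include sites in this chapter are an assumption):

// In the dialect header: pull in the generated op declarations.
#define GET_OP_CLASSES
#include "toy/Ops.h.inc"

// In Dialect.cpp: register the generated ops with the dialect, then pull in
// the generated definitions.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
  addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
      >();
}

#define GET_OP_CLASSES
#include "toy/Ops.cpp.inc"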
