diff --git a/include/sql/SQLOps.td b/include/sql/SQLOps.td index 2f49e0a9cffa..eafeaac1c5d5 100644 --- a/include/sql/SQLOps.td +++ b/include/sql/SQLOps.td @@ -14,6 +14,7 @@ include "SQLDialect.td" include "SQLTypes.td" + class SQL_Op traits = []> : Op; @@ -61,7 +62,7 @@ def CalcBoolOp: SQL_Op<"calc_bool", [Pure]> { def AndOp: SQL_Op<"and", [Pure]> { let summary = "and op"; - let arguments = (ins Variadic:$expr); + let arguments = (ins SQLBoolType:$left, SQLBoolType:$right); let results = (outs SQLBoolType:$result); let hasFolder = 0; @@ -71,7 +72,7 @@ def AndOp: SQL_Op<"and", [Pure]> { def OrOp: SQL_Op<"or", [Pure]> { let summary = "or op"; - let arguments = (ins Variadic:$expr); + let arguments = (ins SQLBoolType:$left, SQLBoolType:$right); let results = (outs SQLBoolType:$result); let hasFolder = 0; @@ -110,6 +111,17 @@ def SQLToStringOp : SQL_Op<"to_string", [Pure]> { } +def SQLBoolToStringOp : SQL_Op<"bool_to_string", [Pure]> { + let summary = "bool_to_string"; + + let arguments = (ins SQLBoolType:$expr); + let results = (outs AnyType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + + def SQLStringConcatOp : SQL_Op<"string_concat", [Pure]> { let summary = "string_concat"; @@ -123,7 +135,6 @@ def SQLStringConcatOp : SQL_Op<"string_concat", [Pure]> { def ConstantBoolOp : SQL_Op <"constant_bool", [Pure]> { let summary = "constant_bool"; let results = (outs SQLBoolType:$result); - } @@ -132,9 +143,8 @@ def SelectOp : SQL_Op<"select", [Pure]> { // i need to specify the size of a Variadic? let arguments = (ins Variadic:$columns, SQLExprType:$table, - // SQLBoolType:$where, - BoolAttr:$selectAll - IntAttr:$limit); + SQLExprType:$where, + SI64Attr:$limit); // attribute limit if >= 0 then its the real thing, otherwise its infinity let results = (outs SQLExprType:$result); diff --git a/lib/sql/Ops.cpp b/lib/sql/Ops.cpp index bfc4c28200c1..4f2c2b34cddb 100644 --- a/lib/sql/Ops.cpp +++ b/lib/sql/Ops.cpp @@ -5,10 +5,10 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +#include +#include #include #include -#include -#include #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -18,11 +18,11 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "polygeist/Ops.h" +#include "sql/Parser.h" #include "sql/SQLDialect.h" #include "sql/SQLOps.h" #include "sql/SQLTypes.h" -#include "sql/Parser.h" -#include "polygeist/Ops.h" #define GET_OP_CLASSES #include "sql/SQLOps.cpp.inc" @@ -41,13 +41,12 @@ #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" - -#include "mlir/IR/Value.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" -#include "mlir/IR/Attributes.h" +#include "mlir/IR/Value.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/BuiltinTypes.h" #define DEBUG_TYPE "sql" @@ -55,7 +54,6 @@ using namespace mlir; using namespace sql; using namespace mlir::arith; - class GetValueOpTypeFix final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -67,38 +65,38 @@ class GetValueOpTypeFix final : public OpRewritePattern { Value handle = op.getOperand(0); if (!handle.getType().isa()) { - handle = rewriter.create(op.getLoc(), - rewriter.getIndexType(), handle); - changed = true; + handle = rewriter.create(op.getLoc(), + rewriter.getIndexType(), handle); + changed = true; } Value row = op.getOperand(1); if (!row.getType().isa()) { - row = rewriter.create(op.getLoc(), - rewriter.getIndexType(), row); - changed = true; + row = rewriter.create(op.getLoc(), rewriter.getIndexType(), + row); + changed = true; } Value column = op.getOperand(2); if (!column.getType().isa()) { - column = rewriter.create(op.getLoc(), - rewriter.getIndexType(), column); - changed = true; + column = rewriter.create(op.getLoc(), + rewriter.getIndexType(), column); + changed = true; } - if (!changed) return failure(); + if (!changed) + return failure(); - rewriter.replaceOpWithNewOp(op, op.getType(), handle, row, column); + rewriter.replaceOpWithNewOp(op, op.getType(), handle, row, + column); return success(changed); } }; void GetValueOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { + MLIRContext *context) { results.insert(context); } - - class NumResultsOpTypeFix final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -108,21 +106,24 @@ class NumResultsOpTypeFix final : public OpRewritePattern { bool changed = false; Value handle = op->getOperand(0); - if (handle.getType().isa() && op->getResultTypes()[0].isa()) - return failure(); + if (handle.getType().isa() && + op->getResultTypes()[0].isa()) + return failure(); if (!handle.getType().isa()) { - handle = rewriter.create(op.getLoc(), - rewriter.getIndexType(), handle); - changed = true; + handle = rewriter.create(op.getLoc(), + rewriter.getIndexType(), handle); + changed = true; } - mlir::Value res = rewriter.create(op.getLoc(), rewriter.getIndexType(), handle); + mlir::Value res = rewriter.create( + op.getLoc(), rewriter.getIndexType(), handle); if (op->getResultTypes()[0].isa()) { - rewriter.replaceOp(op, res); + rewriter.replaceOp(op, res); } else { - rewriter.replaceOpWithNewOp(op, op->getResultTypes()[0], res); + rewriter.replaceOpWithNewOp(op, op->getResultTypes()[0], + res); } return success(changed); @@ -130,12 +131,10 @@ class NumResultsOpTypeFix final : public OpRewritePattern { }; void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { + MLIRContext *context) { results.insert(context); } - - // class ExecuteOpTypeFix final : public OpRewritePattern { // public: // using OpRewritePattern::OpRewritePattern; @@ -147,39 +146,44 @@ void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results, // Value conn = op->getOperand(0); // Value command = op->getOperand(1); -// if (conn.getType().isa() && command.getType().isa() && op->getResultTypes()[0].isa()) +// if (conn.getType().isa() && command.getType().isa() +// && op->getResultTypes()[0].isa()) // return failure(); // if (!conn.getType().isa()) { // conn = rewriter.create(op.getLoc(), -// rewriter.getIndexType(), conn); +// rewriter.getIndexType(), +// conn); // changed = true; // } // if (command.getType().isa()) { -// command = rewriter.create(op.getLoc(), -// LLVM::LLVMPointerType::get(rewriter.getI8Type()), command); +// command = rewriter.create(op.getLoc(), +// LLVM::LLVMPointerType::get(rewriter.getI8Type()), +// command); // changed = true; // } - // if (command.getType().isa()) { -// command = rewriter.create(op.getLoc(), -// rewriter.getI64Type(), command); +// command = rewriter.create(op.getLoc(), +// rewriter.getI64Type(), +// command); // changed = true; // } // if (!command.getType().isa()) { -// command = rewriter.create(op.getLoc(), -// rewriter.getIndexType(), command); +// command = rewriter.create(op.getLoc(), +// rewriter.getIndexType(), +// command); // changed = true; // } // if (!changed) return failure(); -// mlir::Value res = rewriter.create(op.getLoc(), rewriter.getIndexType(), conn, command); -// rewriter.replaceOp(op, res); +// mlir::Value res = rewriter.create(op.getLoc(), +// rewriter.getIndexType(), conn, command); rewriter.replaceOp(op, res); // // if (op->getResultTypes()[0].isa()) { // // rewriter.replaceOp(op, res); // // } else { -// // rewriter.replaceOpWithNewOp(op, op->getResultTypes()[0], res); +// // rewriter.replaceOpWithNewOp(op, +// op->getResultTypes()[0], res); // // } // return success(changed); // } @@ -190,8 +194,7 @@ void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results, // results.insert(context); // } - -template +template class UnparsedOpInnerCast final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -200,9 +203,10 @@ class UnparsedOpInnerCast final : public OpRewritePattern { PatternRewriter &rewriter) const override { Value input = op->getOperand(0); - + auto cst = input.getDefiningOp(); - if (!cst) return failure(); + if (!cst) + return failure(); rewriter.replaceOpWithNewOp(op, op.getType(), cst.getOperand()); return success(); @@ -210,29 +214,80 @@ class UnparsedOpInnerCast final : public OpRewritePattern { }; void UnparsedOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.insert >(context); + MLIRContext *context) { + results.insert>(context); } - -class SQLStringConcatOpCanonicalization final : public OpRewritePattern { +class SQLStringConcatOpCanonicalization final + : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SQLStringConcatOp op, PatternRewriter &rewriter) const override { - - auto input1 = op->getOperand(0).getDefiningOp(); - auto input2 = op->getOperand(1).getDefiningOp(); - - if (!input1 || !input2) return failure(); - - rewriter.replaceOpWithNewOp(op, op.getType(), (input1.getInput() + input2.getInput()).str()); - return success(); + // Whether we changed the state. If we make no simplifications we need to + // return failure otherwise we will infinite loop + bool changed = false; + // Operands to the simplified concat + SmallVector operands; + // Constants that we will merge, "current running constant" + SmallVector constants; + for (auto op : op->getOperands()) { + if (auto constOp = op.getDefiningOp()) { + constants.push_back(constOp); + continue; + } + if (constants.size() != 0) { + if (constants.size() == 1) { + operands.push_back(constants[0]); + } else { + std::string nextStr; + changed = true; + for (auto str : constants) + nextStr += str.getInput().str(); + + operands.push_back(rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), nextStr)); + } + } + constants.clear(); + if (auto concat = op.getDefiningOp()) { + changed = true; + for (auto op2 : concat->getOperands()) + operands.push_back(op2); + continue; + } + operands.push_back(op); + } + if (constants.size() != 0) { + if (constants.size() == 1) { + operands.push_back(constants[0]); + } else { + std::string nextStr; + changed = true; + for (auto str : constants) + nextStr = nextStr + str.getInput().str(); + operands.push_back(rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), nextStr)); + } + } + if (operands.size() == 0) { + rewriter.replaceOpWithNewOp(op, MemRefType::get({-1}, rewriter.getI8Type()), ""); + return success(); + } + if (operands.size() == 1) { + rewriter.replaceOp(op, operands[0]); + return success(); + } + if (changed) { + rewriter.replaceOpWithNewOp(op, MemRefType::get({-1}, rewriter.getI8Type()), operands); + return success(); + } + return failure(); } }; void SQLStringConcatOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { + MLIRContext *context) { results.insert(context); -} \ No newline at end of file +} diff --git a/lib/sql/Parser.cpp b/lib/sql/Parser.cpp index 25a4776368e2..86125e1b7fc1 100644 --- a/lib/sql/Parser.cpp +++ b/lib/sql/Parser.cpp @@ -1,277 +1,320 @@ +#include +#include #include #include -#include -#include -#include "mlir/IR/Value.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" -#include "mlir/IR/Attributes.h" +#include "mlir/IR/Value.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/BuiltinTypes.h" - #include "sql/SQLDialect.h" #include "sql/SQLOps.h" #include "sql/SQLTypes.h" - using namespace mlir; using namespace sql; enum class ParseType { - Nothing = 0, - Value = 1, - Attribute = 2, + Nothing = 0, + Value = 1, + Attribute = 2, }; struct ParseValue { private: - ParseType ty; - Value value; - Attribute attr; + ParseType ty; + Value value; + Attribute attr; + public: - ParseValue() : ty(ParseType::Nothing), value(nullptr), attr(nullptr) {} - ParseValue(Value value) : ty(ParseType::Value), value(value), attr(nullptr) { - assert(value); - } - ParseValue(Attribute attr) : ty(ParseType::Attribute), value(nullptr), attr(attr) {} + ParseValue() : ty(ParseType::Nothing), value(nullptr), attr(nullptr) {} + ParseValue(Value value) : ty(ParseType::Value), value(value), attr(nullptr) { + assert(value); + } + ParseValue(Attribute attr) + : ty(ParseType::Attribute), value(nullptr), attr(attr) {} - ParseType getType() const { - return ty; - } - Value getValue() const { - // llvm::errs() << "ty: " << type_value << "\n"; - // llvm::errs() << "value: " << value << "\n"; - // llvm::errs() << "attr: " << attr << "\n"; - assert(ty == ParseType::Value); - assert(value); - return value; - } - Attribute getAttr() const { - assert(ty == ParseType::Attribute); - return attr; - } + ParseType getType() const { return ty; } + Value getValue() const { + assert(ty == ParseType::Value); + assert(value); + return value; + } + Attribute getAttr() const { + assert(ty == ParseType::Attribute); + return attr; + } }; enum class ParseMode { - None, - Column, - Table, - Bool, + None, + Column, + Table, + Bool, + Clause }; - -template -std::ostream& operator<<(typename std::enable_if::value, std::ostream>::type& stream, const T& e) -{ - return stream << static_cast::type>(e); +template +std::ostream &operator<<( + typename std::enable_if::value, std::ostream>::type &stream, + const T &e) { + return stream << static_cast::type>(e); } class SQLParser { public: - Location loc; - OpBuilder &builder; - std::string sql; - unsigned int i; + Location loc; + OpBuilder &builder; + std::string sql; + unsigned int i; - static std::vector reservedWords; + static std::vector reservedWords; - - SQLParser(Location loc, OpBuilder &builder, std::string sql, int i) : loc(loc), builder(builder), - sql(sql), i(i) { - } + SQLParser(Location loc, OpBuilder &builder, std::string sql, int i) + : loc(loc), builder(builder), sql(sql), i(i) {} - std::string peek() { - auto [peeked, _] = peekWithLength(); - return peeked; - } + std::string peek() { + auto [peeked, _] = peekWithLength(); + return peeked; + } - std::string pop() { - auto [peeked, len] = peekWithLength(); - i += len; - popWhitespace(); - return peeked; - } + std::string pop() { + auto [peeked, len] = peekWithLength(); + i += len; + popWhitespace(); + return peeked; + } - void popWhitespace() { - // it doesn't recognize - while (i < sql.size() && (sql[i] == ' ' || sql[i] == '\n' || sql[i] == '\t')) { - i++; - } + void popWhitespace() { + // it doesn't recognize + while (i < sql.size() && + (sql[i] == ' ' || sql[i] == '\n' || sql[i] == '\t')) { + i++; } + } - std::pair peekWithLength() { - if (i >= sql.size()) { - return {"", 0}; - } - for (std::string rWord : reservedWords) { - auto token = sql.substr(i, std::min(sql.size()-i, rWord.size())); - std::transform(token.begin(), token.end(), token.begin(), ::toupper); - if (token == rWord) { - return {token, static_cast(token.size())}; - } - } - if (sql[i] == '\'') { // Quoted string - return peekQuotedStringWithLength(); - } - return peekIdentifierWithLength(); + std::pair peekWithLength() { + if (i >= sql.size()) { + return {"", 0}; } - - std::pair peekQuotedStringWithLength() { - if (sql.size() < i || sql[i] != '\'') { - return {"", 0}; - } - for (unsigned int j = i + 1; j < sql.size(); j++) { - if (sql[j] == '\'' && sql[j-1] != '\\') { - return {sql.substr(i + 1, j - (i + 1)), j - i + 2}; // +2 for the two quotes - } - } - return {"", 0}; + for (std::string rWord : reservedWords) { + auto token = sql.substr(i, std::min(sql.size() - i, rWord.size())); + std::transform(token.begin(), token.end(), token.begin(), ::toupper); + if (token == rWord) { + return {token, static_cast(token.size())}; + } } - - std::pair peekIdentifierWithLength() { - std::regex e("[a-zA-Z0-9_*]"); - for (unsigned int j = i; j < sql.size(); j++) { - if (!std::regex_match(std::string(1, sql[j]), e)) { - return {sql.substr(i, j - i), j - i}; - } - } - return {sql.substr(i), static_cast(sql.size())-i}; + if (sql[i] == '\'') { // Quoted string + return peekQuotedStringWithLength(); } + return peekIdentifierWithLength(); + } + std::pair peekQuotedStringWithLength() { + if (sql.size() < i || sql[i] != '\'') { + return {"", 0}; + } + for (unsigned int j = i + 1; j < sql.size(); j++) { + if (sql[j] == '\'' && sql[j - 1] != '\\') { + return {sql.substr(i + 1, j - (i + 1)), + j - i + 2}; // +2 for the two quotes + } + } + return {"", 0}; + } - bool is_number(std::string* s){ - std::string::const_iterator it = s->begin(); - while (it != s->end() && std::isdigit(*it)) ++it; - return !s->empty() && it == s->end(); + std::pair peekIdentifierWithLength() { + std::regex e("[a-zA-Z0-9_*]"); + for (unsigned int j = i; j < sql.size(); j++) { + if (!std::regex_match(std::string(1, sql[j]), e)) { + return {sql.substr(i, j - i), j - i}; + } } + return {sql.substr(i), static_cast(sql.size()) - i}; + } + bool is_number(std::string *s) { + std::string::const_iterator it = s->begin(); + while (it != s->end() && std::isdigit(*it)) + ++it; + return !s->empty() && it == s->end(); + } - // Parse the next command, if any - ParseValue parseNext(ParseMode mode) { - // for (unsigned int j = i; j < sql.size(); j++) { - // auto peekStr = peek(); - // pop(); - // llvm::errs() << "peekStrTest: " << i << " " << peekStr << "\n"; - // } - if (i >= sql.size()) { - return ParseValue(); - } - auto peekStr = peek(); - llvm::errs() << "peekStr: " << peekStr << "\n"; - assert(peekStr.size() > 0); + // Parse the next command, if any + ParseValue parseNext(ParseMode mode) { + // for (unsigned int j = i; j < sql.size(); j++) { + // auto peekStr = peek(); + // pop(); + // llvm::errs() << "peekStrTest: " << i << " " << peekStr << "\n"; + // } + if (i >= sql.size()) { + return ParseValue(); + } + auto peekStr = peek(); + llvm::errs() << "peekStr: " << peekStr << "\n"; + assert(peekStr.size() > 0); - if (peekStr == "SELECT") { - assert(mode == ParseMode::None || mode == ParseMode::Table); + if (peekStr == "SELECT") { + assert(mode == ParseMode::None || mode == ParseMode::Table); + pop(); + peekStr = peek(); + if (peekStr == "DISTINCT") { + pop(); + // do something different here + } + llvm::SmallVector columns; + bool hasColumns = true; + bool hasWhere = false; + int limit = -1; + Value table = nullptr; + Value where = nullptr; + while (true) { + peekStr = peek(); + if (peekStr == "") + break; + if (hasColumns) { + if (peekStr == "FROM") { pop(); - peekStr = peek(); - if (peekStr == "DISTINCT") { - pop(); - // do something different here - } - llvm::SmallVector columns; - bool hasColumns = true; - bool hasWhere = false; - bool selectAll = false; - int limit = -1; - Value table = nullptr; - while (true) { - peekStr = peek(); - if (hasColumns) { - if (peekStr == "FROM") { - pop(); - table = parseNext(ParseMode::Table).getValue(); - llvm::errs() << "table: " << table << "\n"; - hasColumns = false; - break; - } - if (peekStr == "*") { - pop(); - selectAll = true; - continue; - } - if (peekStr == ",") { - pop(); - llvm::errs() << "comma\n"; - continue; - } - ParseValue col = parseNext(ParseMode::Column); - if (col.getType() == ParseType::Nothing) { - hasColumns = false; - break; - } else { - columns.push_back(col.getValue()); - } - - } else if (peekStr == "WHERE") { - pop(); - hasWhere = true; - } else if (peekStr == "LIMIT"){ - pop(); - peekStr = peek(); - if (peekStr == "ALL"){ - pop(); - } else if (is_number(&peekStr)){ - pop(); - limit = std::stoi(peekStr); - } - } else { - // break; - assert(0 && " additional clauses like where/etc not yet handled"); - } - } - if (!table){ - llvm::errs() << " table is null: " << table << "\n"; - table = builder.create(loc, ExprType::get(builder.getContext()), builder.getStringAttr("")).getResult(); + table = parseNext(ParseMode::Table).getValue(); + hasColumns = false; + } else { + if (peekStr == ",") { + pop(); + continue; } - // if (selectAll){ - // assert(table && "table cannot be null"); - // return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), table).getResult()); - // } else { - return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), columns, table, selectALl, limit).getResult()); - // } - } else if (is_number(&peekStr)){ - pop(); - return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), builder.getStringAttr(peekStr)).getResult()); - } else if (mode == ParseMode::Column) { - // do we need this?? - if (peekStr == "*") { - pop(); - return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), builder.getStringAttr(peekStr)).getResult()); + ParseValue col = parseNext(ParseMode::Column); + if (col.getType() == ParseType::Nothing) { + hasColumns = false; + break; + } else { + columns.push_back(col.getValue()); } + } + } else if (peekStr == "WHERE") { + pop(); + ParseValue clause = parseNext(ParseMode::Bool); + if (clause.getType() == ParseType::Nothing) { + assert(0 && "where clause not recognized"); + } else { + where = builder.create(loc, ExprType::get(builder.getContext()), + clause.getValue()).getResult(); + } + } else if (peekStr == "LIMIT") { + pop(); + peekStr = peek(); + llvm::errs() << "limit: " << peekStr << "\n"; + if (peekStr == "ALL") { pop(); - return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), builder.getStringAttr(peekStr)).getResult()); - } else if (mode == ParseMode::Table) { - pop(); - return ParseValue(builder.create(loc,ExprType::get(builder.getContext()), builder.getStringAttr(peekStr)).getResult()); - } else if (mode == ParseMode::Bool) { - // col = peekStr; - pop(); - - } else if (peekStr == "(") { + } else if (is_number(&peekStr)) { + llvm::errs() << "limit recognized: " + << "\n"; pop(); - ParseValue res = parseNext(ParseMode::None); - assert(peek() == ")"); - pop(); - return res; - } else if (peekStr == ")") { - return ParseValue(); - } - llvm::errs() << " Unknown token to parse: " << peekStr << "\n"; - llvm_unreachable("Unknown token to parse"); - } + limit = std::stoi(peekStr); + } else { + assert(0 && "not yet handled limit var"); + } + } else { + // break; + llvm::errs() << "peekstr that throws an error" << peekStr << "\n"; + assert(0 && " additional clauses like where/etc not yet handled"); + } + } + if (!table) { + llvm::errs() << " table is null: " << table << "\n"; + table = builder.create(loc, ExprType::get(builder.getContext()), + builder.getStringAttr("")).getResult(); + } + if (!where){ + llvm::errs() << " where is null: " << table << "\n"; + Value clause = builder.create(loc, BoolType::get(builder.getContext()), + builder.getStringAttr("")).getResult(); + where = builder.create(loc, ExprType::get(builder.getContext()), + clause).getResult(); + } + return ParseValue( + builder.create(loc, ExprType::get(builder.getContext()), + columns, table, where, limit).getResult()); + } else if (is_number(&peekStr)) { + pop(); + return ParseValue(builder.create(loc, + ExprType::get(builder.getContext()), + builder.getStringAttr(peekStr)).getResult()); + } else if (mode == ParseMode::Column) { + if (peekStr == "*") { + pop(); + return ParseValue( + builder + .create(loc, ExprType::get(builder.getContext())) + .getResult()); + } + pop(); + return ParseValue( + builder + .create(loc, ExprType::get(builder.getContext()), + builder.getStringAttr(peekStr)) + .getResult()); + } else if (mode == ParseMode::Table) { + pop(); + return ParseValue( + builder + .create(loc, ExprType::get(builder.getContext()), + builder.getStringAttr(peekStr)) + .getResult()); + } else if (mode == ParseMode::Bool) { + // col = peekStr; + ParseValue left = parseNext(ParseMode::Clause); + peekStr = peek(); + if (peekStr == "AND") { + pop(); + ParseValue right = parseNext(ParseMode::Bool); + return ParseValue( + builder.create(loc, BoolType::get(builder.getContext()), left.getValue(), right.getValue()).getResult() + ); + } else if (peekStr == "OR") { + pop(); + ParseValue right = parseNext(ParseMode::Bool); + return ParseValue( + builder.create(loc, BoolType::get(builder.getContext()), left.getValue(), + right.getValue()).getResult() + ); + } else return left; + + } else if (mode == ParseMode::Clause){ + std::string clause = pop(); + clause += " " + pop(); + clause += " " + pop(); + return ParseValue( + builder.create(loc, BoolType::get(builder.getContext()), + builder.getStringAttr(clause)).getResult() + ); + } else if (peekStr == "(") { + pop(); + ParseValue res = parseNext(ParseMode::None); + assert(peek() == ")"); + pop(); + return res; + } else if (peekStr == ")") { + return ParseValue(); + } + llvm::errs() << " Unknown token to parse: " << peekStr << "\n"; + llvm_unreachable("Unknown token to parse"); + } }; std::vector SQLParser::reservedWords = { - "(", ")", ">=", "<=", "!=", ",", "=", ">", "<", ",", "SELECT", "DISTINCT", "INSERT INTO", "VALUES", "UPDATE", "DELETE FROM", "WHERE", "FROM", "SET", "AS" -}; + "(", ")", ">=", "<=", "!=", + ",", "=", ">", "<", ",", + "SELECT", "DISTINCT", "INSERT INTO", "VALUES", "UPDATE", + "DELETE FROM", "WHERE", "FROM", "SET", "AS"}; +mlir::Value parseSQL(mlir::Location loc, mlir::OpBuilder &builder, + std::string str) { + SQLParser parser(loc, builder, std::string(str), 0); + auto resOp = parser.parseNext(ParseMode::None); -mlir::Value parseSQL(mlir::Location loc, mlir::OpBuilder& builder, std::string str) { - SQLParser parser(loc, builder, std::string(str), 0); - auto resOp = parser.parseNext(ParseMode::None); - - return resOp.getValue(); + return resOp.getValue(); } \ No newline at end of file diff --git a/lib/sql/Passes/SQLLower.cpp b/lib/sql/Passes/SQLLower.cpp index eaa26ddb2458..eb0483e4bb32 100644 --- a/lib/sql/Passes/SQLLower.cpp +++ b/lib/sql/Passes/SQLLower.cpp @@ -11,7 +11,6 @@ //===----------------------------------------------------------------------===// #include "PassDetails.h" -#include "polygeist/Ops.h" #include "mlir/Analysis/CallGraph.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Async/IR/Async.h" @@ -20,10 +19,11 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" +#include "polygeist/Ops.h" +#include "sql/Passes/Passes.h" +#include "sql/SQLOps.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" -#include "sql/SQLOps.h" -#include "sql/Passes/Passes.h" #include #include @@ -41,7 +41,6 @@ struct SQLLower : public SQLLowerBase { } // end anonymous namespace - struct NumResultsOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -53,35 +52,33 @@ struct NumResultsOpLowering : public OpRewritePattern { symbolTable.getSymbolTable(module); // 1) make sure the postgres_getresult function is declared - auto rowsfn = dyn_cast_or_null( - symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQntuples"))); + auto rowsfn = dyn_cast_or_null(symbolTable.lookupSymbolIn( + module, rewriter.getStringAttr("PQntuples"))); // 2) convert the args to valid args to postgres_getresult abi Value arg = op.getHandle(); - arg = rewriter.create(op.getLoc(), - rewriter.getI8Type(), arg); + arg = rewriter.create(op.getLoc(), rewriter.getI8Type(), + arg); arg = rewriter.create( op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), arg); - - arg = rewriter.create(op.getLoc(), - rowsfn.getFunctionType().getInput(0), arg); + + arg = rewriter.create( + op.getLoc(), rowsfn.getFunctionType().getInput(0), arg); // 3) call and replace - Value args[] = {arg}; - - Value res = - rewriter.create(op.getLoc(), rowsfn, args) - ->getResult(0); + Value args[] = {arg}; + + Value res = rewriter.create(op.getLoc(), rowsfn, args) + ->getResult(0); - rewriter.replaceOpWithNewOp( - op, rewriter.getIndexType(), res); + rewriter.replaceOpWithNewOp(op, rewriter.getIndexType(), + res); return success(); } }; - struct GetValueOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -93,8 +90,8 @@ struct GetValueOpLowering : public OpRewritePattern { symbolTable.getSymbolTable(module); // 1) make sure the postgres_getresult function is declared - auto valuefn = dyn_cast_or_null( - symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQgetvalue"))); + auto valuefn = dyn_cast_or_null(symbolTable.lookupSymbolIn( + module, rewriter.getStringAttr("PQgetvalue"))); auto atoifn = dyn_cast_or_null( symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("atoi"))); @@ -102,48 +99,48 @@ struct GetValueOpLowering : public OpRewritePattern { // 2) convert the args to valid args to postgres_getresult abi Value handle = op.getHandle(); handle = rewriter.create(op.getLoc(), - rewriter.getI64Type(), handle); + rewriter.getI64Type(), handle); handle = rewriter.create( op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), handle); - handle = rewriter.create(op.getLoc(), - valuefn.getFunctionType().getInput(0), handle); + handle = rewriter.create( + op.getLoc(), valuefn.getFunctionType().getInput(0), handle); Value row = op.getRow(); - row = rewriter.create(op.getLoc(), - valuefn.getFunctionType().getInput(1), row); - Value column = op.getColumn(); - column = rewriter.create(op.getLoc(), - valuefn.getFunctionType().getInput(2), column); - - Value args[] = {handle, row, column}; - - Value res = - rewriter.create(op.getLoc(), valuefn, args) - ->getResult(0); + row = rewriter.create( + op.getLoc(), valuefn.getFunctionType().getInput(1), row); + Value column = op.getColumn(); + column = rewriter.create( + op.getLoc(), valuefn.getFunctionType().getInput(2), column); + + Value args[] = {handle, row, column}; + + Value res = rewriter.create(op.getLoc(), valuefn, args) + ->getResult(0); + + Value args2[] = {res}; - Value args2[] = {res}; - Value res2 = rewriter.create(op.getLoc(), atoifn, args2) ->getResult(0); if (op.getType() != res2.getType()) { - if (op.getType().isa()) - res2 = rewriter.create(op.getLoc(), - op.getType(), res2); - else if (auto IT = op.getType().dyn_cast()) { - auto IT2 = res2.getType().dyn_cast(); - if (IT.getWidth() < IT2.getWidth()) { - res2 = rewriter.create(op.getLoc(), - op.getType(), res2); - } else if (IT.getWidth() > IT2.getWidth()) { - res2 = rewriter.create(op.getLoc(), - op.getType(), res2); - } else assert(0 && "illegal integer type conversion"); - } else { - assert(0 && "illegal type conversion"); - } + if (op.getType().isa()) + res2 = rewriter.create(op.getLoc(), op.getType(), + res2); + else if (auto IT = op.getType().dyn_cast()) { + auto IT2 = res2.getType().dyn_cast(); + if (IT.getWidth() < IT2.getWidth()) { + res2 = + rewriter.create(op.getLoc(), op.getType(), res2); + } else if (IT.getWidth() > IT2.getWidth()) { + res2 = + rewriter.create(op.getLoc(), op.getType(), res2); + } else + assert(0 && "illegal integer type conversion"); + } else { + assert(0 && "illegal type conversion"); + } } rewriter.replaceOp(op, res2); @@ -151,20 +148,21 @@ struct GetValueOpLowering : public OpRewritePattern { } }; +struct ConstantStringOpLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; -struct ConstantStringOpLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(sql::SQLConstantStringOp op, + LogicalResult matchAndRewrite(sql::SQLConstantStringOp op, PatternRewriter &rewriter) const final { auto module = op->getParentOfType(); SymbolTableCollection symbolTable; symbolTable.getSymbolTable(module); - for (auto u: op.getResult().getUsers()){ - if (isa(u)) return failure(); + for (auto u : op.getResult().getUsers()) { + if (isa(u)) + return failure(); } - auto expr = op.getInput().str(); + auto expr = op.getInput().str(); auto name = "str" + std::to_string((long long int)(Operation *)op); auto MT = MemRefType::get({expr.size() + 1}, rewriter.getI8Type()); // auto type = MemRefType::get(mt.getShape(), mt.getElementType(), {}); @@ -173,37 +171,93 @@ struct ConstantStringOpLowering : public OpRewritePattern data(expr.begin(), expr.end()); data.push_back('\0'); auto attr = DenseElementsAttr::get( - RankedTensorType::get(MT.getShape(), MT.getElementType()), data); - + RankedTensorType::get(MT.getShape(), MT.getElementType()), data); + auto loc = op.getLoc(); - rewriter.replaceOpWithNewOp(op, MemRefType::get({-1}, rewriter.getI8Type()), getglob.getResult()); + rewriter.replaceOpWithNewOp( + op, MemRefType::get({-1}, rewriter.getI8Type()), getglob.getResult()); rewriter.setInsertionPointToStart(module.getBody()); - auto res = rewriter.create(loc, rewriter.getStringAttr(name), - mlir::StringAttr(), mlir::TypeAttr::get(MT), attr, rewriter.getUnitAttr(), /*alignment*/nullptr); + auto res = rewriter.create( + loc, rewriter.getStringAttr(name), mlir::StringAttr(), + mlir::TypeAttr::get(MT), attr, rewriter.getUnitAttr(), + /*alignment*/ nullptr); return success(); } }; -// struct StringConcatOpLowering : public OpRewritePattern { -// using OpRewritePattern::OpRewritePattern; -// LogicalResult matchAndRewrite(sql::SQLStringConcatOp op, -// PatternRewriter &rewriter) const final { -// auto module = op->getParentOfType(); +struct BoolToStringOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; -// SymbolTableCollection symbolTable; -// symbolTable.getSymbolTable(module); - -// return success(); -// } -// }; + LogicalResult matchAndRewrite(sql::SQLBoolToStringOp op, + PatternRewriter &rewriter) const final { + auto module = op->getParentOfType(); + + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + + // 2) convert the args to valid args to postgres_getresult abi + Value expr = op.getExpr(); + auto definingOp = expr.getDefiningOp(); + if (auto andOp = dyn_cast(definingOp)) { + Value left = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + andOp.getLeft()); + Value right = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + andOp.getRight()); + Value args[] = {left, rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + "AND "), right}; + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + rewriter.replaceOp(op, res); + } else if (auto orOp = dyn_cast(definingOp)) { + Value left = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + orOp.getLeft()); + Value right = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + orOp.getRight()); + Value args[] = {left, rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + "OR "), right}; + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + rewriter.replaceOp(op, res); + } else if (auto calcBoolOp = dyn_cast(definingOp)){ + auto expr = calcBoolOp.getExpr(); + if (expr == "") { + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + ""); + rewriter.replaceOp(op, res); + return success(); + } + + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), expr); + Value args[] = {res, rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), " ")}; + res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + rewriter.replaceOp(op, res); + } else { + assert(0 && "unknown type to convert to string"); + } + + return success(); + } +}; struct ToStringOpLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(sql::SQLToStringOp op, + LogicalResult matchAndRewrite(sql::SQLToStringOp op, PatternRewriter &rewriter) const final { auto module = op->getParentOfType(); @@ -213,72 +267,128 @@ struct ToStringOpLowering : public OpRewritePattern { // 2) convert the args to valid args to postgres_getresult abi Value expr = op.getExpr(); auto definingOp = expr.getDefiningOp(); - if (auto selectOp = dyn_cast(definingOp)){ - Value current = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), "SELECT "); - bool prevColumn = false; - auto columns = selectOp.getColumns(); - for (mlir::Value v : columns) { - if (prevColumn) { - Value args[] = { current, rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), ", ") }; - current = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()),args); - } - Value col = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), v); - Value args[] = { col, rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), " ")}; - col = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), args); - Value args2[] = { current, col }; - current = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), args2); - prevColumn = true; - } - auto tableOp = selectOp.getTable().getDefiningOp(); - if (tableOp) { - Value args[] = { current, rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), "FROM ") }; - current = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()),args); - Value args2[] = { current, rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), selectOp.getTable()) }; - current = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()),args2); + if (auto selectOp = dyn_cast(definingOp)) { + Value current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), "SELECT "); + + bool prevColumn = false; + auto columns = selectOp.getColumns(); + for (mlir::Value v : columns) { + if (prevColumn) { + Value args[] = { + current, rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), ", ")}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); } - rewriter.replaceOp(op, current); - } else if (auto selectAllOp = dyn_cast(definingOp)){ - auto table = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), selectAllOp.getTable()); - Value res = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), "SELECT * FROM "); - Value args[] = { res, table }; - res = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()),args); - - rewriter.replaceOp(op, res); + Value col = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), v); + Value args[] = { + col, + rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), " ")}; + col = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + Value args2[] = {current, col}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args2); + prevColumn = true; + } + + auto tableOp = selectOp.getTable().getDefiningOp(); + if (tableOp) { + Value args[] = { + current, rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), "FROM ")}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + Value args2[] = {current, + rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + selectOp.getTable())}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args2); + } + + Value whereVal = rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + selectOp.getWhere()); + if (whereVal) { + auto whereOp = selectOp.getWhere().getDefiningOp(); + Value args[] = { + current, rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), "WHERE ")}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + Value args2[] = {current, whereVal}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args2); + } + + + auto limit = selectOp.getLimit(); + if (limit >= 0) { + Value args[] = {current, + rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + "LIMIT ")}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + Value args2[] = {current, + rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + std::to_string(limit))}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args2); + } + rewriter.replaceOp(op, current); } else if (auto tabOp = dyn_cast(definingOp)) { - Value res = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), tabOp.getExpr()); - rewriter.replaceOp(op, res); + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + tabOp.getExpr()); + Value args[] = { res, rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), " ")}; + res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + rewriter.replaceOp(op, res); + } else if (auto allColOp = dyn_cast(definingOp)) { + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), "*"); + rewriter.replaceOp(op, res); } else if (auto colOp = dyn_cast(definingOp)) { - Value res = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), colOp.getExpr()); - rewriter.replaceOp(op, res); - } else if (auto intOp = dyn_cast(definingOp)){ - Value res = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), intOp.getExpr()); - llvm::errs() << "intOp: " << intOp.getExpr() << "\n"; - rewriter.replaceOp(op, res); + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + colOp.getExpr()); + rewriter.replaceOp(op, res); + } else if (auto intOp = dyn_cast(definingOp)) { + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + intOp.getExpr()); + llvm::errs() << "intOp: " << intOp.getExpr() << "\n"; + rewriter.replaceOp(op, res); + } else if (auto whereOp = dyn_cast(definingOp)) { + Value res = rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + whereOp.getExpr()); + rewriter.replaceOp(op, res); } else { - assert(0 && "unknown type to convert to string"); + assert(0 && "unknown type to convert to string"); } - + return success(); } }; + + struct ExecuteOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -296,39 +406,37 @@ struct ExecuteOpLowering : public OpRewritePattern { // 2) convert the args to valid args to postgres_getresult abi Value conn = op.getConn(); conn = rewriter.create(op.getLoc(), - rewriter.getI8Type(), conn); + rewriter.getI8Type(), conn); conn = rewriter.create( op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), conn); - conn = rewriter.create(op.getLoc(), - executefn.getFunctionType().getInput(0), conn); - - Value command = op.getCommand(); - // auto name = "str" + std::to_string((long long int)(Operation *)command.getDefiningOp()); - command = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), command); + conn = rewriter.create( + op.getLoc(), executefn.getFunctionType().getInput(0), conn); + + Value command = op.getCommand(); + // auto name = "str" + std::to_string((long long int)(Operation + // *)command.getDefiningOp()); + command = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), command); llvm::errs() << "command: " << command << "\n"; llvm::errs() << "command type: " << command.getType() << "\n"; // auto type = MemRefType::get({-1}, rewriter.getI8Type()); - // auto globalOp = rewriter.create( - // op.getLoc(), type, /* isConstant */ true, LLVM::Linkage::Internal, name, - // mlir::Attribute(), + // op.getLoc(), type, /* isConstant */ true, LLVM::Linkage::Internal, + // name, mlir::Attribute(), // /* alignment */ 0, /* addrSpace */ 0); - // 3) call and replace - Value args[] = {conn, command}; - + Value args[] = {conn, command}; + Value res = rewriter.create(op.getLoc(), executefn, args) ->getResult(0); - res = rewriter.create(op.getLoc(), - LLVM::LLVMPointerType::get(rewriter.getI8Type()), res); - res = rewriter.create( - op.getLoc(), rewriter.getI64Type(), res); - res = rewriter.create(op.getLoc(), - op.getType(), res); + res = rewriter.create( + op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), res); + res = rewriter.create(op.getLoc(), rewriter.getI64Type(), + res); + res = rewriter.create(op.getLoc(), op.getType(), res); rewriter.replaceOp(op, res); @@ -349,28 +457,26 @@ void SQLLower::runOnOperation() { mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type())}; mlir::Type rettypes[] = {builder.getI64Type()}; - auto fn = - builder.create(module.getLoc(), "PQntuples", - builder.getFunctionType(argtypes, rettypes)); + auto fn = builder.create( + module.getLoc(), "PQntuples", + builder.getFunctionType(argtypes, rettypes)); SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); } if (!dyn_cast_or_null(symbolTable.lookupSymbolIn( module, builder.getStringAttr("PQgetvalue")))) { - mlir::Type argtypes[] = { - MemRefType::get({-1}, builder.getI8Type()), - builder.getI64Type(), - builder.getI64Type()}; + mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type()), + builder.getI64Type(), builder.getI64Type()}; mlir::Type rettypes[] = {MemRefType::get({-1}, builder.getI8Type())}; - auto fn = - builder.create(module.getLoc(), "PQgetvalue", - builder.getFunctionType(argtypes, rettypes)); + auto fn = builder.create( + module.getLoc(), "PQgetvalue", + builder.getFunctionType(argtypes, rettypes)); SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); } - if (!dyn_cast_or_null( - symbolTable.lookupSymbolIn(module, builder.getStringAttr("PQexec")))) { + if (!dyn_cast_or_null(symbolTable.lookupSymbolIn( + module, builder.getStringAttr("PQexec")))) { mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type()), MemRefType::get({-1}, builder.getI8Type())}; mlir::Type rettypes[] = {MemRefType::get({-1}, builder.getI8Type())}; @@ -383,7 +489,8 @@ void SQLLower::runOnOperation() { if (!dyn_cast_or_null( symbolTable.lookupSymbolIn(module, builder.getStringAttr("atoi")))) { mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type())}; - // mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI64Type())}; + // mlir::Type argtypes[] = + // {LLVM::LLVMPointerType::get(builder.getI64Type())}; // todo use data layout mlir::Type rettypes[] = {builder.getI64Type()}; @@ -398,13 +505,13 @@ void SQLLower::runOnOperation() { patterns.insert(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); + patterns.insert(&getContext()); patterns.insert(&getContext()); - for (auto *dialect : getContext().getLoadedDialects()) - dialect->getCanonicalizationPatterns(patterns); + dialect->getCanonicalizationPatterns(patterns); for (RegisteredOperationName op : getContext().getRegisteredOperations()) - op.getCanonicalizationPatterns(patterns, &getContext()); + op.getCanonicalizationPatterns(patterns, &getContext()); GreedyRewriteConfig config; (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), @@ -416,5 +523,5 @@ namespace sql { std::unique_ptr createSQLLowerPass() { return std::make_unique(); } -} // namespace polygeist +} // namespace sql } // namespace mlir