Skip to content

Commit

Permalink
Merge pull request #8 from Exion007/similar-functions-deletion
Browse files Browse the repository at this point in the history
Similar functions deletion
  • Loading branch information
XejgaToast authored Sep 30, 2024
2 parents 673e935 + c0affad commit 34734b1
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 88 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
def a(n: si64) -> si64 {
return 100;
}


print(a(5));
print(a(10));
211 changes: 123 additions & 88 deletions src/compiler/lowering/SpecializeGenericFunctionsPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,63 @@
#include <algorithm>
#include <iostream>
#include <regex>
#include <uuid/uuid.h>
#include <openssl/sha.h>

using namespace mlir;

/**
* @brief Specialize functions by calculating intermediate results (including return value) inside functions in compile time and
* create new specialized functions. It also makes sure to not create too many specializations (deep recursion), maintains a call graph, a hashmap of function bodies
* and a set of recursive cycles to decide whether it makes sense to specialize or not.
*/

namespace {

// Helper function to generate a hash from a string (using SHA256)
/**
* @brief Generates a SHA-512 hash of the input IR (Intermediate Representation) string.
*
* This function computes the SHA-512 hash of a given string representing an
* Intermediate Representation (IR) to create a unique identifier for the IR.
* This can be useful for detecting changes in the IR or for comparing two
* IR functions.
*
* @param irString The IR string to hash.
* @return A hexadecimal string representation of the SHA-512 hash.
*/
std::string hashIRString(const std::string &irString) {
unsigned char hash[SHA256_DIGEST_LENGTH];
SHA256((unsigned char*)irString.c_str(), irString.size(), hash);
unsigned char hash[SHA512_DIGEST_LENGTH];
SHA512((unsigned char*)irString.c_str(), irString.size(), hash);

std::stringstream ss;
for (int i = 0; i < SHA256_DIGEST_LENGTH; ++i) {
for (int i = 0; i < SHA512_DIGEST_LENGTH; ++i) {
ss << std::hex << std::setw(2) << std::setfill('0') << (int)hash[i];
}
return ss.str();
}

// Function to generate a UUID string
std::string generateUUID() {
uuid_t uuid;
uuid_generate(uuid);
char uuidStr[37]; // UUIDs are 36 characters + null terminator
uuid_unparse(uuid, uuidStr);
return std::string(uuidStr);
}

std::string getIRString(func::FuncOp funcOp) {
std::string irString;
llvm::raw_string_ostream stream(irString);
funcOp.print(stream);
return irString;
}

/**
* @brief Split any string into a vector of lines at all linebreaks of the string
*
* @param str The string to split
* @return The vector containing all lines
*/
std::vector<std::string> splitIntoLines(const std::string &str) {
std::vector<std::string> lines;
std::stringstream ss(str);
Expand Down Expand Up @@ -113,7 +146,14 @@ namespace {
);
}


/**
* @brief Create a unique specialized name for the function. If the specialized function has been created already, the same name gets returned
*
* @param functionName The name of the template Function
* @param inputTypes The input types
* @param inputValues The input values
* @return A unique function name for the specialized function
*/
std::string uniqueSpecializedFuncName(const std::string &functionName, TypeRange inputTypes, ValueRange inputValues) {
//static unsigned functionUniqueId = 0;
// Creating an empty string to store the new unique specialized function name
Expand Down Expand Up @@ -153,63 +193,6 @@ namespace {
return output;
}

/**
* @brief Extracts variable names from a given line of code.
*
* This function uses a regular expression to match variable names in the input line.
* Variable names must start with a letter or underscore and may contain letters, digits, or underscores.
*
* @param line The input line of code from which to extract variable names.
* @return A vector of strings containing the variable names found in the line.
*
* @example
* std::vector<std::string> vars = extractVariablesFromLine("int x = y + 10;");
* // vars will contain: ["int", "x", "y"]
*/
std::vector<std::string> extractVariablesFromLine(const std::string &line) {
std::vector<std::string> variables;
std::regex var_regex(R"(\b([a-zA-Z_][a-zA-Z0-9_]*)\b)");
auto words_begin = std::sregex_iterator(line.begin(), line.end(), var_regex);
auto words_end = std::sregex_iterator();

for (std::sregex_iterator it = words_begin; it != words_end; ++it) {
std::smatch match = *it;
std::string var = match.str(1);
variables.push_back(var);
}
return variables;
}

/**
* @brief Normalizes commutative operations by ensuring operands are in alphabetical order.
*
* This function scans a line of code for commutative operations (e.g., addition `+`)
* and reorders the operands so that the smaller (alphabetically) operand appears first.
* This helps to treat equivalent expressions like `a + b` and `b + a` as the same.
*
* @param line The input line of code to normalize.
* @return A string with the normalized commutative operations.
*
* @example
* std::string normalized = normalizeCommutativeOperations("y + x");
* // normalized will be: "x + y"
*/
std::string normalizeCommutativeOperations(const std::string &line) {
std::regex add_regex(R"((\w+)\s*\+\s*(\w+))");
std::smatch match;
std::string normalizedLine = line;

while (std::regex_search(normalizedLine, match, add_regex)) {
std::string operand1 = match[1];
std::string operand2 = match[2];
if (operand1 > operand2) std::swap(operand1, operand2);
std::string replacement = operand1 + " + " + operand2;
normalizedLine = match.prefix().str() + replacement + match.suffix().str();
}

return normalizedLine;
}

/**
* @brief Get argument types for the specialized version of a template function.
* @param functionType The types of the template function.
Expand Down Expand Up @@ -315,6 +298,9 @@ namespace {
std::unordered_map<std::string, func::FuncOp> hashToFuncMap;
std::map<func::FuncOp, std::vector<std::string>> IRRepresentations;
// Can store all callOps in a list

// Map to track all functions that share the same base hash (including those with newHash)
std::unordered_map<std::string, std::set<func::FuncOp>> baseHashMap;

const DaphneUserConfig& userConfig;
std::shared_ptr<spdlog::logger> logger;
Expand All @@ -325,6 +311,15 @@ namespace {
}

private:

/**
* @brief Detects recursion in the callStack (Note: This function gets utilized by findRecursions)
*
* @param func The function name to find it in the callStack
* @param visitedInGraph already visited functions
* @param callStack the callStack
*
*/
void detectRecursion(const std::string &func, std::set<std::string> &visitedInGraph, std::vector<std::string> &callStack) {
// If function is already on the stack, we found a recursion
auto it = std::find(callStack.begin(), callStack.end(), func);
Expand Down Expand Up @@ -355,7 +350,9 @@ namespace {
callStack.pop_back();
}

// Function to initiate recursion detection for all functions in the call graph
/** @brief Find recursion in the whole call graph (using DFS algorithm) and adds them to recursiveCalls
*
*/
void findRecursions() {
// Clear previous recursive calls
recursiveCalls.clear();
Expand All @@ -371,13 +368,10 @@ namespace {
}
}




/**
* @brief Update the callGraph map
*
* @param func The specialized function
* @return Nothing (could return error code?)
*/
void updateCallGraph(func::FuncOp func) {
// Get the module containing this function
Expand Down Expand Up @@ -454,7 +448,7 @@ namespace {
if(Operation * co = CompilerUtils::constantOfAnyType(v)) {
// Clone the constant operation into the function body.
Operation * coNew = co->clone();
+ builder.insert(coNew);
builder.insert(coNew);
// Replace all uses of the corresponding block argument by the newly inserted constant.
specializedFuncBodyBlock.getArgument(i).replaceAllUsesWith(coNew->getResult(0));
// TODO We could even remove the corresponding function argument.
Expand Down Expand Up @@ -517,7 +511,9 @@ namespace {
}

/**
* @brief Recursively specializes all functions within a `FuncOp` based on calls to the functions
* @brief Recursively specializes all functions within a `FuncOp` based on calls to the functions. And also ...
* ... keep track of how often recursive calls had happened and prevent recursion if too many recursive calls happened in a cycle.
* ... keep track of a function hashmap which is used to find and delete duplicate functions with different function names (prevent non beneficial specialization)
* @param function The `FuncOp` to scan for function specializations
*/
void specializeCallsInFunction(func::FuncOp function) {
Expand Down Expand Up @@ -576,20 +572,32 @@ namespace {


if (specialize) {
// Canonicalize the called function using MLIR's built-in canonicalizer
func::FuncOp specializedFunc = createOrReuseSpecialization(callOp.getOperandTypes(), callOp.getOperands(), calledFunction, callOp.getLoc());

// Canonicalize and add template function to baseHashMap (because specialization could be non beneficial)
mlir::PassManager pm(function.getContext());
pm.addPass(mlir::createCanonicalizerPass());
if (failed(pm.run(specializedFunc))) {
if (failed(pm.run(calledFunction))) {
throw std::runtime_error("Failed to canonicalize called function.");
}
std::vector<std::string> templateFunctionLines = splitIntoLines(getIRString(calledFunction));
size_t firstQuote = templateFunctionLines[0].find('"');
if(firstQuote != std::string::npos) {
size_t secondQuote = templateFunctionLines[0].find('"', firstQuote+1);
if(secondQuote != std::string::npos) {
templateFunctionLines[0] = templateFunctionLines[0].substr(secondQuote+1);
}
}
std::string templateFunctionIR = std::accumulate(templateFunctionLines.begin(), templateFunctionLines.end(), std::string());
std::string templateFuncHash = hashIRString(templateFunctionIR);
baseHashMap[templateFuncHash].insert(calledFunction);
IRRepresentations[calledFunction] = templateFunctionLines;
// Canonicalize the called function using MLIR's built-in canonicalizer
func::FuncOp specializedFunc = createOrReuseSpecialization(callOp.getOperandTypes(), callOp.getOperands(), calledFunction, callOp.getLoc());

// Get the canonicalized IR as a string
std::vector<std::string> canonicalIRLines = splitIntoLines(getIRString(specializedFunc));

// Only keep function signature, ignore name
size_t firstQuote = canonicalIRLines[0].find('"');
firstQuote = canonicalIRLines[0].find('"');
if(firstQuote != std::string::npos) {
size_t secondQuote = canonicalIRLines[0].find('"', firstQuote+1);
if(secondQuote != std::string::npos) {
Expand All @@ -599,33 +607,60 @@ namespace {

std::string canonicalIR = std::accumulate(canonicalIRLines.begin(), canonicalIRLines.end(), std::string());

// Hash the canonicalized IR string
std::string hash = hashIRString(canonicalIR);
std::string baseHash = hashIRString(canonicalIR);
std::string finalHash = baseHash;
bool isDuplicate = false;

// Check if the base hash already exists in baseHashMap
if (baseHashMap.find(baseHash) != baseHashMap.end()) {
// Iterate through functions that share the same base hash
for (func::FuncOp existingFunc : baseHashMap[baseHash]) {
std::vector<std::string> existingIRLines = IRRepresentations[existingFunc];
std::string existingCanonicalIR = std::accumulate(existingIRLines.begin(), existingIRLines.end(), std::string());

if (canonicalIR == existingCanonicalIR && existingFunc.getName().str() != specializedFunc.getName().str()) {
// Functions are identical, reuse the existing one
callOp.setCalleeAttr(existingFunc.getSymNameAttr());
if (fixResultTypes(callOp->getResults(), existingFunc.getFunctionType())) {
inferTypesInFunction(function);
}
specializedFunc.erase();
functions.erase(specializedName);
called.insert(existingFunc);
isDuplicate = true;
break;
}
}

// Check if this specialization already exists using the hash
if (hashToFuncMap.count(hash) && hashToFuncMap[hash].getName().str() != specializedFunc.getName().str()) {
// Reuse existing function
func::FuncOp existingFunc = hashToFuncMap[hash];
callOp.setCalleeAttr(existingFunc.getSymNameAttr());
if (fixResultTypes(callOp->getResults(), existingFunc.getFunctionType())) {
inferTypesInFunction(function);
// If no identical function found, generate a UUID for a new unique hash
if (!isDuplicate) {
std::string uuid = generateUUID(); // Use UUID for uniqueness
finalHash = baseHash + "_" + uuid;
}
specializedFunc.erase();
functions.erase(specializedName);
called.insert(existingFunc);
} else {
// If the baseHash does not exist in baseHashMap, add the function directly
finalHash = baseHash; // No need for UUID as there's no collision
}

// If no duplicate found, store the new function with its unique hash
if (!isDuplicate) {
callOp.setCalleeAttr(specializedFunc.getSymNameAttr());
if (fixResultTypes(callOp->getResults(), specializedFunc.getFunctionType())) {
inferTypesInFunction(function);
}
specializeCallsInFunction(specializedFunc);
called.insert(specializedFunc);

// Store the hash and the canonicalized IR for future reuse
hashToFuncMap[hash] = specializedFunc;
// Store the function in both hashToFuncMap and baseHashMap
hashToFuncMap[finalHash] = specializedFunc;
baseHashMap[baseHash].insert(specializedFunc);
IRRepresentations[specializedFunc] = canonicalIRLines;
}
} else {
functions.insert({calledFuncName, calledFunction});

// Keep track of the hash from template function because specialization could be non beneficial

callOp.setCalleeAttr(calledFunction.getSymNameAttr());
if (fixResultTypes(callOp->getResults(), calledFunction.getFunctionType())) {
inferTypesInFunction(function);
Expand Down

0 comments on commit 34734b1

Please sign in to comment.