Skip to content

Commit

Permalink
safety checks
Browse files Browse the repository at this point in the history
  • Loading branch information
lmangani committed Nov 18, 2024
1 parent a2e5ed3 commit 2aea329
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 18 deletions.
13 changes: 6 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@ Create a DuckDB SQL Macro and save it somewhere. Here's an [example](https://gis
Load your remote macro onto your system using a URL:

```sql
D SELECT load_macro_from_url('https://gist.github.com/lmangani/518215a68e674ac662537d518799b893/raw/5f305480fdd7468f4ecda3686011bab8e8e711bf/bsky.sql') as res;
┌─────────────────────────────┐
│ res │
varchar
├─────────────────────────────┤
│ Successfully loaded macro │
└─────────────────────────────┘
┌─────────────────────────────────────────┐
│ res │
varchar
├─────────────────────────────────────────┤
│ Successfully loaded macro: search_posts │
└─────────────────────────────────────────┘
```

Use your new macro and have fun:
Expand Down
60 changes: 49 additions & 11 deletions src/webxtension_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,39 +83,71 @@ static bool ContainsMacroDefinition(const std::string &content) {
// Parse Function Name
static std::string ExtractMacroName(const std::string &macro_sql) {
try {
// Convert to uppercase for case-insensitive matching
std::string upper_sql = StringUtil::Upper(macro_sql);

// Find the MACRO keyword
size_t macro_pos = upper_sql.find("MACRO");
if (macro_pos == std::string::npos) {
return "unknown";
}

// Find the start of the name (after MACRO and any whitespace)

size_t name_start = macro_pos + 5; // length of "MACRO"
while (name_start < upper_sql.length() && std::isspace(upper_sql[name_start])) {
name_start++;
}

// Find the end of the name (before the opening parenthesis)

size_t name_end = upper_sql.find('(', name_start);
if (name_end == std::string::npos) {
return "unknown";
}

// Trim any trailing whitespace

while (name_end > name_start && std::isspace(upper_sql[name_end - 1])) {
name_end--;
}

// Get the original case version of the name from the input string

return macro_sql.substr(name_start, name_end - name_start);
} catch (...) {
return "unknown";
}
}

// Helper function to check for potentially dangerous SQL commands
static std::pair<bool, std::string> ContainsDangerousCommands(const std::string &sql) {
const std::vector<std::string> dangerous_commands = {
"DELETE", "DROP", "TRUNCATE", "ALTER", "GRANT", "REVOKE",
"CREATE USER", "ALTER USER", "DROP USER",
"CREATE DATABASE", "DROP DATABASE",
"EXEC", "EXECUTE",
"SHUTDOWN", "RESTART",
"SET GLOBAL", "SET SYSTEM",
"LOAD EXTENSION", "UNLOAD EXTENSION",
"ATTACH", "DETACH",
"COPY", "EXPORT",
"UPDATE", "MERGE"
};

std::string upper_sql = StringUtil::Upper(sql);
std::vector<std::string> found_commands;

for (const auto& cmd : dangerous_commands) {
if (upper_sql.find(cmd) != std::string::npos) {
found_commands.push_back(cmd);
}
}

if (!found_commands.empty()) {
std::string warning = "Warning: SQL contains potentially dangerous commands: ";
for (size_t i = 0; i < found_commands.size(); i++) {
warning += found_commands[i];
if (i < found_commands.size() - 1) {
warning += ", ";
}
}
warning += ". Please review the macro carefully before using it.";
return std::make_pair(true, warning);
}

return std::make_pair(false, "");
}

// Function to fetch and create macro from URL
static void LoadMacroFromUrlFunction(DataChunk &args, ExpressionState &state, Vector &result, DatabaseInstance *db_instance) {
auto &context = state.GetContext();
Expand All @@ -142,6 +174,12 @@ static void LoadMacroFromUrlFunction(DataChunk &args, ExpressionState &state, Ve
// Get the SQL content
std::string macro_sql = res->body;

// Check for dangerous commands
auto dangerous_check = ContainsDangerousCommands(macro_sql);
if (dangerous_check.first) {
throw std::runtime_error(dangerous_check.second);
}

// Replace all \r\n with \n
macro_sql = StringUtil::Replace(macro_sql, "\r\n", "\n");
// Replace any remaining \r with \n
Expand Down

0 comments on commit 2aea329

Please sign in to comment.