Skip to content

Commit

Permalink
feat: implement whitelist (#119)
Browse files Browse the repository at this point in the history
### TL;DR

Implemented query validation to enhance security and prevent potential SQL injection attacks.

### What changed?

- Added a `ValidateQuery` function in `utils.go` to check for disallowed patterns and ensure only allowed functions are used in queries.
- Integrated query validation in the `ClickHouseConnector` methods for executing queries.
- Updated error handling in `logs_handlers.go` and `transactions_handlers.go` to potentially use `BadRequestError` for disallowed functions.

### How to test?

1. Try running queries with allowed functions (e.g., `sum`, `count`, `reinterpretAsUInt256`) and ensure they work as expected.
2. Attempt to use disallowed patterns or functions in queries and verify that they are rejected with appropriate error messages.
3. Test different types of queries (SELECT, INSERT, UPDATE, etc.) to confirm that only SELECT queries are allowed.

### Why make this change?

This change enhances the security of the application by preventing potential SQL injection attacks and restricting the use of potentially harmful functions or query patterns. It ensures that only safe, pre-approved functions can be used in queries, reducing the risk of unauthorized data access or manipulation.
  • Loading branch information
catalyst17 authored Oct 30, 2024
2 parents 6d744ac + 9cb0aac commit 32f151a
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 1 deletion.
49 changes: 49 additions & 0 deletions internal/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package common
import (
"fmt"
"math/big"
"regexp"
"strings"
"unicode"
)
Expand Down Expand Up @@ -169,3 +170,51 @@ func isType(word string) bool {

return types[word]
}

var allowedFunctions = map[string]struct{}{
"sum": {},
"count": {},
"reinterpretAsUInt256": {},
"reverse": {},
"unhex": {},
"substring": {},
"length": {},
"toUInt256": {},
"if": {},
}

var disallowedPatterns = []string{
`(?i)\b(UNION|INSERT|DELETE|UPDATE|DROP|CREATE|ALTER|TRUNCATE|EXEC|;|--)`,
}

// validateQuery checks the query for disallowed patterns and ensures only allowed functions are used.
func ValidateQuery(query string) error {
// Check for disallowed patterns
for _, pattern := range disallowedPatterns {
matched, err := regexp.MatchString(pattern, query)
if err != nil {
return fmt.Errorf("error checking disallowed patterns: %v", err)
}
if matched {
return fmt.Errorf("query contains disallowed keywords or patterns")
}
}

// Ensure the query is a SELECT statement
trimmedQuery := strings.TrimSpace(strings.ToUpper(query))
if !strings.HasPrefix(trimmedQuery, "SELECT") {
return fmt.Errorf("only SELECT queries are allowed")
}

// Extract function names and validate them
functionPattern := regexp.MustCompile(`(?i)(\b\w+\b)\s*\(`)
matches := functionPattern.FindAllStringSubmatch(query, -1)
for _, match := range matches {
funcName := match[1]
if _, ok := allowedFunctions[funcName]; !ok {
return fmt.Errorf("function '%s' is not allowed", funcName)
}
}

return nil
}
2 changes: 2 additions & 0 deletions internal/handlers/logs_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ func handleLogsRequest(c *gin.Context, contractAddress, signature string) {
aggregatesResult, err := mainStorage.GetAggregations("logs", qf)
if err != nil {
log.Error().Err(err).Msg("Error querying aggregates")
// TODO: might want to choose BadRequestError if it's due to not-allowed functions
api.InternalErrorHandler(c)
return
}
Expand All @@ -180,6 +181,7 @@ func handleLogsRequest(c *gin.Context, contractAddress, signature string) {
logsResult, err := mainStorage.GetLogs(qf)
if err != nil {
log.Error().Err(err).Msg("Error querying logs")
// TODO: might want to choose BadRequestError if it's due to not-allowed functions
api.InternalErrorHandler(c)
return
}
Expand Down
4 changes: 3 additions & 1 deletion internal/handlers/transactions_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ func handleTransactionsRequest(c *gin.Context, contractAddress, signature string
aggregatesResult, err := mainStorage.GetAggregations("transactions", qf)
if err != nil {
log.Error().Err(err).Msg("Error querying aggregates")
// TODO: might want to choose BadRequestError if it's due to not-allowed functions
api.InternalErrorHandler(c)
return
}
Expand All @@ -181,7 +182,8 @@ func handleTransactionsRequest(c *gin.Context, contractAddress, signature string
// Retrieve logs data
transactionsResult, err := mainStorage.GetTransactions(qf)
if err != nil {
log.Error().Err(err).Msg("Error querying tran")
log.Error().Err(err).Msg("Error querying transactions")
// TODO: might want to choose BadRequestError if it's due to not-allowed functions
api.InternalErrorHandler(c)
return
}
Expand Down
12 changes: 12 additions & 0 deletions internal/storage/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ func (c *ClickHouseConnector) GetBlocks(qf QueryFilter) (blocks []common.Block,

query += getLimitClause(int(qf.Limit))

if err := common.ValidateQuery(query); err != nil {
return nil, err
}
rows, err := c.conn.Query(context.Background(), query)
if err != nil {
return nil, err
Expand Down Expand Up @@ -369,6 +372,9 @@ func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (Que
query += fmt.Sprintf(" GROUP BY %s", groupByColumns)
}

if err := common.ValidateQuery(query); err != nil {
return QueryResult[interface{}]{}, err
}
// Execute the query
rows, err := c.conn.Query(context.Background(), query)
if err != nil {
Expand Down Expand Up @@ -421,6 +427,9 @@ func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (Que
func executeQuery[T any](c *ClickHouseConnector, table, columns string, qf QueryFilter, scanFunc func(driver.Rows) (T, error)) (QueryResult[T], error) {
query := c.buildQuery(table, columns, qf)

if err := common.ValidateQuery(query); err != nil {
return QueryResult[T]{}, err
}
rows, err := c.conn.Query(context.Background(), query)
if err != nil {
return QueryResult[T]{}, err
Expand Down Expand Up @@ -856,6 +865,9 @@ func (c *ClickHouseConnector) GetTraces(qf QueryFilter) (traces []common.Trace,

query += getLimitClause(int(qf.Limit))

if err := common.ValidateQuery(query); err != nil {
return nil, err
}
rows, err := c.conn.Query(context.Background(), query)
if err != nil {
return nil, err
Expand Down

0 comments on commit 32f151a

Please sign in to comment.