diff --git a/internal/common/utils.go b/internal/common/utils.go index a0ffc7c..9812cf6 100644 --- a/internal/common/utils.go +++ b/internal/common/utils.go @@ -3,6 +3,7 @@ package common import ( "fmt" "math/big" + "regexp" "strings" "unicode" ) @@ -169,3 +170,51 @@ func isType(word string) bool { return types[word] } + +var allowedFunctions = map[string]struct{}{ + "sum": {}, + "count": {}, + "reinterpretAsUInt256": {}, + "reverse": {}, + "unhex": {}, + "substring": {}, + "length": {}, + "toUInt256": {}, + "if": {}, +} + +var disallowedPatterns = []string{ + `(?i)\b(UNION|INSERT|DELETE|UPDATE|DROP|CREATE|ALTER|TRUNCATE|EXEC|;|--)`, +} + +// validateQuery checks the query for disallowed patterns and ensures only allowed functions are used. +func ValidateQuery(query string) error { + // Check for disallowed patterns + for _, pattern := range disallowedPatterns { + matched, err := regexp.MatchString(pattern, query) + if err != nil { + return fmt.Errorf("error checking disallowed patterns: %v", err) + } + if matched { + return fmt.Errorf("query contains disallowed keywords or patterns") + } + } + + // Ensure the query is a SELECT statement + trimmedQuery := strings.TrimSpace(strings.ToUpper(query)) + if !strings.HasPrefix(trimmedQuery, "SELECT") { + return fmt.Errorf("only SELECT queries are allowed") + } + + // Extract function names and validate them + functionPattern := regexp.MustCompile(`(?i)(\b\w+\b)\s*\(`) + matches := functionPattern.FindAllStringSubmatch(query, -1) + for _, match := range matches { + funcName := match[1] + if _, ok := allowedFunctions[funcName]; !ok { + return fmt.Errorf("function '%s' is not allowed", funcName) + } + } + + return nil +} diff --git a/internal/handlers/logs_handlers.go b/internal/handlers/logs_handlers.go index 3e1ddd7..c67d605 100644 --- a/internal/handlers/logs_handlers.go +++ b/internal/handlers/logs_handlers.go @@ -170,6 +170,7 @@ func handleLogsRequest(c *gin.Context, contractAddress, signature string) { aggregatesResult, err := mainStorage.GetAggregations("logs", qf) if err != nil { log.Error().Err(err).Msg("Error querying aggregates") + // TODO: might want to choose BadRequestError if it's due to not-allowed functions api.InternalErrorHandler(c) return } @@ -180,6 +181,7 @@ func handleLogsRequest(c *gin.Context, contractAddress, signature string) { logsResult, err := mainStorage.GetLogs(qf) if err != nil { log.Error().Err(err).Msg("Error querying logs") + // TODO: might want to choose BadRequestError if it's due to not-allowed functions api.InternalErrorHandler(c) return } diff --git a/internal/handlers/transactions_handlers.go b/internal/handlers/transactions_handlers.go index 62e4a50..8d7e985 100644 --- a/internal/handlers/transactions_handlers.go +++ b/internal/handlers/transactions_handlers.go @@ -172,6 +172,7 @@ func handleTransactionsRequest(c *gin.Context, contractAddress, signature string aggregatesResult, err := mainStorage.GetAggregations("transactions", qf) if err != nil { log.Error().Err(err).Msg("Error querying aggregates") + // TODO: might want to choose BadRequestError if it's due to not-allowed functions api.InternalErrorHandler(c) return } @@ -181,7 +182,8 @@ func handleTransactionsRequest(c *gin.Context, contractAddress, signature string // Retrieve logs data transactionsResult, err := mainStorage.GetTransactions(qf) if err != nil { - log.Error().Err(err).Msg("Error querying tran") + log.Error().Err(err).Msg("Error querying transactions") + // TODO: might want to choose BadRequestError if it's due to not-allowed functions api.InternalErrorHandler(c) return } diff --git a/internal/storage/clickhouse.go b/internal/storage/clickhouse.go index dbbd237..7083d1c 100644 --- a/internal/storage/clickhouse.go +++ b/internal/storage/clickhouse.go @@ -301,6 +301,9 @@ func (c *ClickHouseConnector) GetBlocks(qf QueryFilter) (blocks []common.Block, query += getLimitClause(int(qf.Limit)) + if err := common.ValidateQuery(query); err != nil { + return nil, err + } rows, err := c.conn.Query(context.Background(), query) if err != nil { return nil, err @@ -369,6 +372,9 @@ func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (Que query += fmt.Sprintf(" GROUP BY %s", groupByColumns) } + if err := common.ValidateQuery(query); err != nil { + return QueryResult[interface{}]{}, err + } // Execute the query rows, err := c.conn.Query(context.Background(), query) if err != nil { @@ -421,6 +427,9 @@ func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (Que func executeQuery[T any](c *ClickHouseConnector, table, columns string, qf QueryFilter, scanFunc func(driver.Rows) (T, error)) (QueryResult[T], error) { query := c.buildQuery(table, columns, qf) + if err := common.ValidateQuery(query); err != nil { + return QueryResult[T]{}, err + } rows, err := c.conn.Query(context.Background(), query) if err != nil { return QueryResult[T]{}, err @@ -856,6 +865,9 @@ func (c *ClickHouseConnector) GetTraces(qf QueryFilter) (traces []common.Trace, query += getLimitClause(int(qf.Limit)) + if err := common.ValidateQuery(query); err != nil { + return nil, err + } rows, err := c.conn.Query(context.Background(), query) if err != nil { return nil, err