Skip to content

Commit

Permalink
feat: dynamic resolving of aggrergates results types
Browse files Browse the repository at this point in the history
  • Loading branch information
catalyst17 committed Oct 25, 2024
1 parent 6a03b6a commit 7698103
Show file tree
Hide file tree
Showing 2 changed files with 325 additions and 15 deletions.
206 changes: 191 additions & 15 deletions internal/storage/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"math/big"
"reflect"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -349,24 +350,20 @@ func (c *ClickHouseConnector) GetLogs(qf QueryFilter) (QueryResult[common.Log],

func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (QueryResult[interface{}], error) {
// Build the SELECT clause with aggregates
columns := strings.Join(append(qf.GroupBy, qf.Aggregates...), ", ")
query := fmt.Sprintf("SELECT %s FROM %s.%s WHERE is_deleted = 0", columns, c.cfg.Database, table)
selectColumns := strings.Join(append(qf.GroupBy, qf.Aggregates...), ", ")
query := fmt.Sprintf("SELECT %s FROM %s.%s WHERE is_deleted = 0", selectColumns, c.cfg.Database, table)

// Apply filters
if qf.ChainId != nil && qf.ChainId.Sign() > 0 {
query = addFilterParams("chain_id", qf.ChainId.String(), query)
}
query = addContractAddress(table, query, qf.ContractAddress)

if qf.Signature != "" {
query += fmt.Sprintf(" AND topic_0 = '%s'", qf.Signature)
}

for key, value := range qf.FilterParams {
query = addFilterParams(key, strings.ToLower(value), query)
}

// Add GROUP BY clause if specified
if len(qf.GroupBy) > 0 {
groupByColumns := strings.Join(qf.GroupBy, ", ")
query += fmt.Sprintf(" GROUP BY %s", groupByColumns)
Expand All @@ -379,28 +376,45 @@ func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (Que
}
defer rows.Close()

columnNames := rows.Columns()
columnTypes := rows.ColumnTypes()

// Collect results
var aggregates []map[string]interface{}
for rows.Next() {
columns := rows.Columns()
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
values := make([]interface{}, len(columnNames))

// Assign Go types based on ClickHouse types
for i, colType := range columnTypes {
dbType := colType.DatabaseTypeName()
values[i] = mapClickHouseTypeToGoType(dbType)
}

if err := rows.Scan(valuePtrs...); err != nil {
return QueryResult[interface{}]{}, err
if err := rows.Scan(values...); err != nil {
return QueryResult[interface{}]{}, fmt.Errorf("failed to scan row: %w", err)
}

// Prepare the result map for the current row
result := make(map[string]interface{})
for i, col := range columns {
result[col] = values[i]
for i, colName := range columnNames {
valuePtr := values[i]
value := getUnderlyingValue(valuePtr)

// Convert *big.Int to string
if bigIntValue, ok := value.(big.Int); ok {
result[colName] = BigInt{Int: bigIntValue}
} else {
result[colName] = value
}
}

aggregates = append(aggregates, result)
}

if err := rows.Err(); err != nil {
return QueryResult[interface{}]{}, fmt.Errorf("row iteration error: %w", err)
}

return QueryResult[interface{}]{Data: nil, Aggregates: aggregates}, nil
}

Expand Down Expand Up @@ -1056,3 +1070,165 @@ func (c *ClickHouseConnector) InsertBlockData(data *[]common.BlockData) error {
}
return nil
}

func mapClickHouseTypeToGoType(dbType string) interface{} {
// Handle LowCardinality types
if strings.HasPrefix(dbType, "LowCardinality(") {
dbType = dbType[len("LowCardinality(") : len(dbType)-1]
}

// Handle Nullable types
isNullable := false
if strings.HasPrefix(dbType, "Nullable(") {
isNullable = true
dbType = dbType[len("Nullable(") : len(dbType)-1]
}

// Handle Array types
if strings.HasPrefix(dbType, "Array(") {
elementType := dbType[len("Array(") : len(dbType)-1]
// For arrays, we'll use slices of pointers to the element type
switch elementType {
case "String", "FixedString":
return new([]*string)
case "Int8", "Int16", "Int32", "Int64":
return new([]*int64)
case "UInt8", "UInt16", "UInt32", "UInt64":
return new([]*uint64)
case "Float32", "Float64":
return new([]*float64)
case "Decimal", "Decimal32", "Decimal64", "Decimal128", "Decimal256":
return new([]*big.Float)
// Add more cases as needed
default:
return new([]interface{})
}
}

// Handle parameterized types by extracting the base type
baseType := dbType
if idx := strings.Index(dbType, "("); idx != -1 {
baseType = dbType[:idx]
}

// Map basic data types
switch baseType {
// Signed integers
case "Int8":
if isNullable {
return new(*int8)
}
return new(int8)
case "Int16":
if isNullable {
return new(*int16)
}
return new(int16)
case "Int32":
if isNullable {
return new(*int32)
}
return new(int32)
case "Int64":
if isNullable {
return new(*int64)
}
return new(int64)
// Unsigned integers
case "UInt8":
if isNullable {
return new(*uint8)
}
return new(uint8)
case "UInt16":
if isNullable {
return new(*uint16)
}
return new(uint16)
case "UInt32":
if isNullable {
return new(*uint32)
}
return new(uint32)
case "UInt64":
if isNullable {
return new(*uint64)
}
return new(uint64)
// Floating-point numbers
case "Float32":
if isNullable {
return new(*float32)
}
return new(float32)
case "Float64":
if isNullable {
return new(*float64)
}
return new(float64)
// Decimal types
case "Decimal", "Decimal32", "Decimal64", "Decimal128", "Decimal256":
if isNullable {
return new(*big.Float)
}
return new(big.Float)
// String types
case "String", "FixedString", "UUID", "IPv4", "IPv6":
if isNullable {
return new(*string)
}
return new(string)
// Enums
case "Enum8", "Enum16":
if isNullable {
return new(*string)
}
return new(string)
// Date and time types
case "Date", "Date32", "DateTime", "DateTime64":
if isNullable {
return new(*time.Time)
}
return new(time.Time)
// Big integers
case "Int128", "UInt128", "Int256", "UInt256":
if isNullable {
return new(*big.Int)
}
return new(big.Int)
default:
// For unknown types, use interface{}
return new(interface{})
}
}

type BigInt struct {
big.Int
}

func (b BigInt) MarshalJSON() ([]byte, error) {
return []byte(`"` + b.String() + `"`), nil
}

func getUnderlyingValue(valuePtr interface{}) interface{} {
v := reflect.ValueOf(valuePtr)

// Handle nil values
if !v.IsValid() {
return nil
}

// Handle pointers and interfaces
for {
if v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
if v.IsNil() {
return nil
}
v = v.Elem()
continue
}
break
}

return v.Interface()
}
134 changes: 134 additions & 0 deletions internal/storage/clickhouse_connector_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package storage

import (
"math/big"
"reflect"
"testing"
"time"
)

// TestMapClickHouseTypeToGoType tests the mapClickHouseTypeToGoType function
func TestMapClickHouseTypeToGoType(t *testing.T) {
testCases := []struct {
dbType string
expectedType interface{}
}{
// Signed integers
{"Int8", int8(0)},
{"Nullable(Int8)", (**int8)(nil)},
{"Int16", int16(0)},
{"Nullable(Int16)", (**int16)(nil)},
{"Int32", int32(0)},
{"Nullable(Int32)", (**int32)(nil)},
{"Int64", int64(0)},
{"Nullable(Int64)", (**int64)(nil)},
// Unsigned integers
{"UInt8", uint8(0)},
{"Nullable(UInt8)", (**uint8)(nil)},
{"UInt16", uint16(0)},
{"Nullable(UInt16)", (**uint16)(nil)},
{"UInt32", uint32(0)},
{"Nullable(UInt32)", (**uint32)(nil)},
{"UInt64", uint64(0)},
{"Nullable(UInt64)", (**uint64)(nil)},
// Big integers
{"Int128", big.NewInt(0)},
{"Nullable(Int128)", (**big.Int)(nil)},
{"UInt128", big.NewInt(0)},
{"Nullable(UInt128)", (**big.Int)(nil)},
{"Int256", big.NewInt(0)},
{"Nullable(Int256)", (**big.Int)(nil)},
{"UInt256", big.NewInt(0)},
{"Nullable(UInt256)", (**big.Int)(nil)},
// Floating-point numbers
{"Float32", float32(0)},
{"Nullable(Float32)", (**float32)(nil)},
{"Float64", float64(0)},
{"Nullable(Float64)", (**float64)(nil)},
// Decimal types
{"Decimal", big.NewFloat(0)},
{"Nullable(Decimal)", (**big.Float)(nil)},
{"Decimal32", big.NewFloat(0)},
{"Nullable(Decimal32)", (**big.Float)(nil)},
{"Decimal64", big.NewFloat(0)},
{"Nullable(Decimal64)", (**big.Float)(nil)},
{"Decimal128", big.NewFloat(0)},
{"Nullable(Decimal128)", (**big.Float)(nil)},
{"Decimal256", big.NewFloat(0)},
{"Nullable(Decimal256)", (**big.Float)(nil)},
// String types
{"String", ""},
{"Nullable(String)", (**string)(nil)},
{"FixedString(42)", ""},
{"Nullable(FixedString(42))", (**string)(nil)},
{"UUID", ""},
{"Nullable(UUID)", (**string)(nil)},
{"IPv4", ""},
{"Nullable(IPv4)", (**string)(nil)},
{"IPv6", ""},
{"Nullable(IPv6)", (**string)(nil)},
// Date and time types
{"Date", time.Time{}},
{"Nullable(Date)", (**time.Time)(nil)},
{"DateTime", time.Time{}},
{"Nullable(DateTime)", (**time.Time)(nil)},
{"DateTime64", time.Time{}},
{"Nullable(DateTime64)", (**time.Time)(nil)},
// Enums
{"Enum8('a' = 1, 'b' = 2)", ""},
{"Nullable(Enum8('a' = 1, 'b' = 2))", (**string)(nil)},
{"Enum16('a' = 1, 'b' = 2)", ""},
{"Nullable(Enum16('a' = 1, 'b' = 2))", (**string)(nil)},
// Arrays
{"Array(Int32)", &[]*int64{}},
{"Array(String)", &[]*string{}},
{"Array(Float64)", &[]*float64{}},
// LowCardinality
{"LowCardinality(String)", ""},
{"LowCardinality(Nullable(String))", (**string)(nil)},
// Unknown type
{"UnknownType", new(interface{})},
{"Nullable(UnknownType)", new(interface{})},
}

for _, tc := range testCases {
t.Run(tc.dbType, func(t *testing.T) {
result := mapClickHouseTypeToGoType(tc.dbType)

expectedType := reflect.TypeOf(tc.expectedType)
resultType := reflect.TypeOf(result)

// Handle pointers
if expectedType.Kind() == reflect.Ptr {
if resultType.Kind() != reflect.Ptr {
t.Errorf("Expected pointer type for dbType %s, got %s", tc.dbType, resultType.Kind())
return
}
expectedElemType := expectedType.Elem()
resultElemType := resultType.Elem()
if expectedElemType.Kind() == reflect.Ptr {
// Expected pointer to pointer
if resultElemType.Kind() != reflect.Ptr {
t.Errorf("Expected pointer to pointer for dbType %s, got %s", tc.dbType, resultElemType.Kind())
return
}
expectedElemType = expectedElemType.Elem()
resultElemType = resultElemType.Elem()
}
if expectedElemType != resultElemType {
t.Errorf("Type mismatch for dbType %s: expected %s, got %s", tc.dbType, expectedElemType, resultElemType)
}
} else {
// Non-pointer types
if resultType.Kind() != reflect.Ptr {
t.Errorf("Expected pointer type for dbType %s, got %s", tc.dbType, resultType.Kind())
return
}
resultElemType := resultType.Elem()
if expectedType != resultElemType {
t.Errorf("Type mismatch for dbType %s: expected %s, got %s", tc.dbType, expectedType, resultElemType)
}
}
})
}
}

0 comments on commit 7698103

Please sign in to comment.