From 76981031900d1d5371108680ea698460570c35a7 Mon Sep 17 00:00:00 2001 From: catalyst17 <37663786+catalyst17@users.noreply.github.com> Date: Thu, 24 Oct 2024 15:49:08 +0200 Subject: [PATCH] feat: dynamic resolving of aggregates results types --- internal/storage/clickhouse.go | 206 ++++++++++++++++-- internal/storage/clickhouse_connector_test.go | 134 ++++++++++++ 2 files changed, 325 insertions(+), 15 deletions(-) create mode 100644 internal/storage/clickhouse_connector_test.go diff --git a/internal/storage/clickhouse.go b/internal/storage/clickhouse.go index 6e3d99d..b2c567c 100644 --- a/internal/storage/clickhouse.go +++ b/internal/storage/clickhouse.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "math/big" + "reflect" "strings" "sync" "time" @@ -349,24 +350,20 @@ func (c *ClickHouseConnector) GetLogs(qf QueryFilter) (QueryResult[common.Log], func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (QueryResult[interface{}], error) { // Build the SELECT clause with aggregates - columns := strings.Join(append(qf.GroupBy, qf.Aggregates...), ", ") - query := fmt.Sprintf("SELECT %s FROM %s.%s WHERE is_deleted = 0", columns, c.cfg.Database, table) + selectColumns := strings.Join(append(qf.GroupBy, qf.Aggregates...), ", ") + query := fmt.Sprintf("SELECT %s FROM %s.%s WHERE is_deleted = 0", selectColumns, c.cfg.Database, table) // Apply filters if qf.ChainId != nil && qf.ChainId.Sign() > 0 { query = addFilterParams("chain_id", qf.ChainId.String(), query) } query = addContractAddress(table, query, qf.ContractAddress) - if qf.Signature != "" { query += fmt.Sprintf(" AND topic_0 = '%s'", qf.Signature) } - for key, value := range qf.FilterParams { query = addFilterParams(key, strings.ToLower(value), query) } - - // Add GROUP BY clause if specified if len(qf.GroupBy) > 0 { groupByColumns := strings.Join(qf.GroupBy, ", ") query += fmt.Sprintf(" GROUP BY %s", groupByColumns) @@ -379,28 +376,45 @@ func (c *ClickHouseConnector) GetAggregations(table 
string, qf QueryFilter) (Que } defer rows.Close() + columnNames := rows.Columns() + columnTypes := rows.ColumnTypes() + // Collect results var aggregates []map[string]interface{} for rows.Next() { - columns := rows.Columns() - values := make([]interface{}, len(columns)) - valuePtrs := make([]interface{}, len(columns)) - for i := range columns { - valuePtrs[i] = &values[i] + values := make([]interface{}, len(columnNames)) + + // Assign Go types based on ClickHouse types + for i, colType := range columnTypes { + dbType := colType.DatabaseTypeName() + values[i] = mapClickHouseTypeToGoType(dbType) } - if err := rows.Scan(valuePtrs...); err != nil { - return QueryResult[interface{}]{}, err + if err := rows.Scan(values...); err != nil { + return QueryResult[interface{}]{}, fmt.Errorf("failed to scan row: %w", err) } + // Prepare the result map for the current row result := make(map[string]interface{}) - for i, col := range columns { - result[col] = values[i] + for i, colName := range columnNames { + valuePtr := values[i] + value := getUnderlyingValue(valuePtr) + + // Convert *big.Int to string + if bigIntValue, ok := value.(big.Int); ok { + result[colName] = BigInt{Int: bigIntValue} + } else { + result[colName] = value + } } aggregates = append(aggregates, result) } + if err := rows.Err(); err != nil { + return QueryResult[interface{}]{}, fmt.Errorf("row iteration error: %w", err) + } + return QueryResult[interface{}]{Data: nil, Aggregates: aggregates}, nil } @@ -1056,3 +1070,165 @@ func (c *ClickHouseConnector) InsertBlockData(data *[]common.BlockData) error { } return nil } + +func mapClickHouseTypeToGoType(dbType string) interface{} { + // Handle LowCardinality types + if strings.HasPrefix(dbType, "LowCardinality(") { + dbType = dbType[len("LowCardinality(") : len(dbType)-1] + } + + // Handle Nullable types + isNullable := false + if strings.HasPrefix(dbType, "Nullable(") { + isNullable = true + dbType = dbType[len("Nullable(") : len(dbType)-1] + } + + // Handle Array 
types + if strings.HasPrefix(dbType, "Array(") { + elementType := dbType[len("Array(") : len(dbType)-1] + // For arrays, we'll use slices of pointers to the element type + switch elementType { + case "String", "FixedString": + return new([]*string) + case "Int8", "Int16", "Int32", "Int64": + return new([]*int64) + case "UInt8", "UInt16", "UInt32", "UInt64": + return new([]*uint64) + case "Float32", "Float64": + return new([]*float64) + case "Decimal", "Decimal32", "Decimal64", "Decimal128", "Decimal256": + return new([]*big.Float) + // Add more cases as needed + default: + return new([]interface{}) + } + } + + // Handle parameterized types by extracting the base type + baseType := dbType + if idx := strings.Index(dbType, "("); idx != -1 { + baseType = dbType[:idx] + } + + // Map basic data types + switch baseType { + // Signed integers + case "Int8": + if isNullable { + return new(*int8) + } + return new(int8) + case "Int16": + if isNullable { + return new(*int16) + } + return new(int16) + case "Int32": + if isNullable { + return new(*int32) + } + return new(int32) + case "Int64": + if isNullable { + return new(*int64) + } + return new(int64) + // Unsigned integers + case "UInt8": + if isNullable { + return new(*uint8) + } + return new(uint8) + case "UInt16": + if isNullable { + return new(*uint16) + } + return new(uint16) + case "UInt32": + if isNullable { + return new(*uint32) + } + return new(uint32) + case "UInt64": + if isNullable { + return new(*uint64) + } + return new(uint64) + // Floating-point numbers + case "Float32": + if isNullable { + return new(*float32) + } + return new(float32) + case "Float64": + if isNullable { + return new(*float64) + } + return new(float64) + // Decimal types + case "Decimal", "Decimal32", "Decimal64", "Decimal128", "Decimal256": + if isNullable { + return new(*big.Float) + } + return new(big.Float) + // String types + case "String", "FixedString", "UUID", "IPv4", "IPv6": + if isNullable { + return new(*string) + } + return 
new(string) + // Enums + case "Enum8", "Enum16": + if isNullable { + return new(*string) + } + return new(string) + // Date and time types + case "Date", "Date32", "DateTime", "DateTime64": + if isNullable { + return new(*time.Time) + } + return new(time.Time) + // Big integers + case "Int128", "UInt128", "Int256", "UInt256": + if isNullable { + return new(*big.Int) + } + return new(big.Int) + default: + // For unknown types, use interface{} + return new(interface{}) + } +} + +type BigInt struct { + big.Int +} + +func (b BigInt) MarshalJSON() ([]byte, error) { + return []byte(`"` + b.String() + `"`), nil +} + +func getUnderlyingValue(valuePtr interface{}) interface{} { + v := reflect.ValueOf(valuePtr) + + // Handle nil values + if !v.IsValid() { + return nil + } + + // Handle pointers and interfaces + for { + if v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface { + if v.IsNil() { + return nil + } + v = v.Elem() + continue + } + break + } + + return v.Interface() +} diff --git a/internal/storage/clickhouse_connector_test.go b/internal/storage/clickhouse_connector_test.go new file mode 100644 index 0000000..cad416a --- /dev/null +++ b/internal/storage/clickhouse_connector_test.go @@ -0,0 +1,134 @@ +package storage + +import ( + "math/big" + "reflect" + "testing" + "time" +) + +// TestMapClickHouseTypeToGoType tests the mapClickHouseTypeToGoType function +func TestMapClickHouseTypeToGoType(t *testing.T) { + testCases := []struct { + dbType string + expectedType interface{} + }{ + // Signed integers + {"Int8", int8(0)}, + {"Nullable(Int8)", (**int8)(nil)}, + {"Int16", int16(0)}, + {"Nullable(Int16)", (**int16)(nil)}, + {"Int32", int32(0)}, + {"Nullable(Int32)", (**int32)(nil)}, + {"Int64", int64(0)}, + {"Nullable(Int64)", (**int64)(nil)}, + // Unsigned integers + {"UInt8", uint8(0)}, + {"Nullable(UInt8)", (**uint8)(nil)}, + {"UInt16", uint16(0)}, + {"Nullable(UInt16)", (**uint16)(nil)}, + {"UInt32", uint32(0)}, + {"Nullable(UInt32)", (**uint32)(nil)}, + 
{"UInt64", uint64(0)}, + {"Nullable(UInt64)", (**uint64)(nil)}, + // Big integers + {"Int128", big.NewInt(0)}, + {"Nullable(Int128)", (**big.Int)(nil)}, + {"UInt128", big.NewInt(0)}, + {"Nullable(UInt128)", (**big.Int)(nil)}, + {"Int256", big.NewInt(0)}, + {"Nullable(Int256)", (**big.Int)(nil)}, + {"UInt256", big.NewInt(0)}, + {"Nullable(UInt256)", (**big.Int)(nil)}, + // Floating-point numbers + {"Float32", float32(0)}, + {"Nullable(Float32)", (**float32)(nil)}, + {"Float64", float64(0)}, + {"Nullable(Float64)", (**float64)(nil)}, + // Decimal types + {"Decimal", big.NewFloat(0)}, + {"Nullable(Decimal)", (**big.Float)(nil)}, + {"Decimal32", big.NewFloat(0)}, + {"Nullable(Decimal32)", (**big.Float)(nil)}, + {"Decimal64", big.NewFloat(0)}, + {"Nullable(Decimal64)", (**big.Float)(nil)}, + {"Decimal128", big.NewFloat(0)}, + {"Nullable(Decimal128)", (**big.Float)(nil)}, + {"Decimal256", big.NewFloat(0)}, + {"Nullable(Decimal256)", (**big.Float)(nil)}, + // String types + {"String", ""}, + {"Nullable(String)", (**string)(nil)}, + {"FixedString(42)", ""}, + {"Nullable(FixedString(42))", (**string)(nil)}, + {"UUID", ""}, + {"Nullable(UUID)", (**string)(nil)}, + {"IPv4", ""}, + {"Nullable(IPv4)", (**string)(nil)}, + {"IPv6", ""}, + {"Nullable(IPv6)", (**string)(nil)}, + // Date and time types + {"Date", time.Time{}}, + {"Nullable(Date)", (**time.Time)(nil)}, + {"DateTime", time.Time{}}, + {"Nullable(DateTime)", (**time.Time)(nil)}, + {"DateTime64", time.Time{}}, + {"Nullable(DateTime64)", (**time.Time)(nil)}, + // Enums + {"Enum8('a' = 1, 'b' = 2)", ""}, + {"Nullable(Enum8('a' = 1, 'b' = 2))", (**string)(nil)}, + {"Enum16('a' = 1, 'b' = 2)", ""}, + {"Nullable(Enum16('a' = 1, 'b' = 2))", (**string)(nil)}, + // Arrays + {"Array(Int32)", &[]*int64{}}, + {"Array(String)", &[]*string{}}, + {"Array(Float64)", &[]*float64{}}, + // LowCardinality + {"LowCardinality(String)", ""}, + {"LowCardinality(Nullable(String))", (**string)(nil)}, + // Unknown type + {"UnknownType", 
new(interface{})}, + {"Nullable(UnknownType)", new(interface{})}, + } + + for _, tc := range testCases { + t.Run(tc.dbType, func(t *testing.T) { + result := mapClickHouseTypeToGoType(tc.dbType) + + expectedType := reflect.TypeOf(tc.expectedType) + resultType := reflect.TypeOf(result) + + // Handle pointers + if expectedType.Kind() == reflect.Ptr { + if resultType.Kind() != reflect.Ptr { + t.Errorf("Expected pointer type for dbType %s, got %s", tc.dbType, resultType.Kind()) + return + } + expectedElemType := expectedType.Elem() + resultElemType := resultType.Elem() + if expectedElemType.Kind() == reflect.Ptr { + // Expected pointer to pointer + if resultElemType.Kind() != reflect.Ptr { + t.Errorf("Expected pointer to pointer for dbType %s, got %s", tc.dbType, resultElemType.Kind()) + return + } + expectedElemType = expectedElemType.Elem() + resultElemType = resultElemType.Elem() + } + if expectedElemType != resultElemType { + t.Errorf("Type mismatch for dbType %s: expected %s, got %s", tc.dbType, expectedElemType, resultElemType) + } + } else { + // Non-pointer types + if resultType.Kind() != reflect.Ptr { + t.Errorf("Expected pointer type for dbType %s, got %s", tc.dbType, resultType.Kind()) + return + } + resultElemType := resultType.Elem() + if expectedType != resultElemType { + t.Errorf("Type mismatch for dbType %s: expected %s, got %s", tc.dbType, expectedType, resultElemType) + } + } + }) + } +}