From 3a295547f478bda564c3943b5e1f5d61a525437b Mon Sep 17 00:00:00 2001 From: Dariusz Stempniak Date: Tue, 10 Oct 2023 11:58:53 +0200 Subject: [PATCH] SNOW-911238: Add WithOriginalTimestamp context switch (#918) SNOW-911238: Add WithOriginalTimestamp context switch --- arrow_chunk.go | 2 +- chunk_downloader.go | 3 + cmd/arrow/batches/arrow_batches.go | 19 +- connection.go | 2 +- converter.go | 295 ++++++++-------- converter_test.go | 521 +++++++++++++++++++++++++---- datatype.go | 3 +- errors.go | 2 + rows.go | 2 +- util.go | 28 +- 10 files changed, 645 insertions(+), 232 deletions(-) diff --git a/arrow_chunk.go b/arrow_chunk.go index 344774af8..15851a80f 100644 --- a/arrow_chunk.go +++ b/arrow_chunk.go @@ -57,7 +57,7 @@ func (arc *arrowResultChunk) decodeArrowBatch(scd *snowflakeChunkDownloader) (*[ for arc.reader.Next() { rawRecord := arc.reader.Record() - record, err := arrowToRecord(rawRecord, arc.allocator, scd.RowSet.RowType, arc.loc) + record, err := arrowToRecord(scd.ctx, rawRecord, arc.allocator, scd.RowSet.RowType, arc.loc) if err != nil { return nil, err } diff --git a/chunk_downloader.go b/chunk_downloader.go index a32fd1628..6806e5e32 100644 --- a/chunk_downloader.go +++ b/chunk_downloader.go @@ -279,6 +279,7 @@ func (scd *snowflakeChunkDownloader) startArrowBatches() error { idx: 0, scd: scd, funcDownloadHelper: scd.FuncDownloadHelper, + loc: loc, } // decode first chunk if possible if firstArrowChunk.allocator != nil { @@ -293,6 +294,7 @@ func (scd *snowflakeChunkDownloader) startArrowBatches() error { idx: i, scd: scd, funcDownloadHelper: scd.FuncDownloadHelper, + loc: loc, } } return nil @@ -708,6 +710,7 @@ type ArrowBatch struct { scd *snowflakeChunkDownloader funcDownloadHelper func(context.Context, *snowflakeChunkDownloader, int) error ctx context.Context + loc *time.Location } // WithContext sets the context which will be used for this ArrowBatch. diff --git a/cmd/arrow/batches/arrow_batches.go b/cmd/arrow/batches/arrow_batches.go index 90ab7a698..198d612df 100644 --- a/cmd/arrow/batches/arrow_batches.go +++ b/cmd/arrow/batches/arrow_batches.go @@ -11,6 +11,7 @@ import ( "github.com/apache/arrow/go/v12/arrow/memory" "log" "sync" + "time" sf "github.com/snowflakedb/gosnowflake" ) @@ -20,10 +21,11 @@ type sampleRecord struct { workerID int number int32 string string + ts *time.Time } func (s sampleRecord) String() string { - return fmt.Sprintf("batchID: %v, workerID: %v, number: %v, string: %v", s.batchID, s.workerID, s.number, s.string) + return fmt.Sprintf("batchID: %v, workerID: %v, number: %v, string: %v, ts: %v", s.batchID, s.workerID, s.number, s.string, s.ts) } func main() { @@ -48,8 +50,14 @@ func main() { log.Fatalf("failed to create DSN from Config: %v, err: %v", cfg, err) } - ctx := sf.WithArrowAllocator(sf.WithArrowBatches(context.Background()), memory.DefaultAllocator) - query := "SELECT SEQ4(), 'example ' || (SEQ4() * 2) FROM TABLE(GENERATOR(ROWCOUNT=>30000))" + ctx := + sf.WithOriginalTimestamp( + sf.WithArrowAllocator( + sf.WithArrowBatches(context.Background()), memory.DefaultAllocator)) + + query := "SELECT SEQ4(), 'example ' || (SEQ4() * 2), " + + " TO_TIMESTAMP_NTZ('9999-01-01 13:13:13.' || LPAD(SEQ4(),9,'0')) ltz " + + " FROM TABLE(GENERATOR(ROWCOUNT=>30000))" db, err := sql.Open("snowflake", dsn) if err != nil { @@ -88,7 +96,7 @@ func main() { } sampleRecordsPerBatch[batchID] = make([]sampleRecord, batches[batchID].GetRowCount()) totalRowID := 0 - convertFromColumnsToRows(records, sampleRecordsPerBatch, batchID, workerId, totalRowID) + convertFromColumnsToRows(records, sampleRecordsPerBatch, batchID, workerId, totalRowID, batches[batchID]) } }(&waitGroup, batchIds, workerID) } @@ -110,7 +118,7 @@ func main() { } func convertFromColumnsToRows(records *[]arrow.Record, sampleRecordsPerBatch [][]sampleRecord, batchID int, - workerID int, totalRowID int) { + workerID int, totalRowID int, batch *sf.ArrowBatch) { for _, record := range *records { for rowID, intColumn := range record.Column(0).(*array.Int32).Int32Values() { sampleRecord := sampleRecord{ @@ -118,6 +126,7 @@ func convertFromColumnsToRows(records *[]arrow.Record, sampleRecordsPerBatch [][ workerID: workerID, number: intColumn, string: record.Column(1).(*array.String).Value(rowID), + ts: batch.ArrowSnowflakeTimestampToTime(record, 2, rowID), } sampleRecordsPerBatch[batchID][totalRowID] = sampleRecord totalRowID++ diff --git a/connection.go b/connection.go index 769253d46..c9d760327 100644 --- a/connection.go +++ b/connection.go @@ -409,7 +409,7 @@ func (sc *snowflakeConn) queryContextInternal( rows.addDownloader(populateChunkDownloader(ctx, sc, data.Data)) } - rows.ChunkDownloader.start() + err = rows.ChunkDownloader.start() return rows, err } diff --git a/converter.go b/converter.go index 44c0afbfa..88f64baa6 100644 --- a/converter.go +++ b/converter.go @@ -354,6 +354,79 @@ func decimalToBigFloat(num decimal128.Num, scale int64) *big.Float { return new(big.Float).Quo(f, s) } +// ArrowSnowflakeTimestampToTime converts original timestamp returned by Snowflake to time.Time +func (rb *ArrowBatch) ArrowSnowflakeTimestampToTime(rec arrow.Record, colIdx int, recIdx int) *time.Time { + scale := int(rb.scd.RowSet.RowType[colIdx].Scale) + dbType := rb.scd.RowSet.RowType[colIdx].Type + return arrowSnowflakeTimestampToTime(rec.Column(colIdx), getSnowflakeType(dbType), scale, recIdx, rb.loc) +} + +func arrowSnowflakeTimestampToTime( + column arrow.Array, + sfType snowflakeType, + scale int, + recIdx int, + loc *time.Location) *time.Time { + + if column.IsNull(recIdx) { + return nil + } + var ret time.Time + switch sfType { + case timestampNtzType: + if column.DataType().ID() == arrow.STRUCT { + structData := column.(*array.Struct) + epoch := structData.Field(0).(*array.Int64).Int64Values() + fraction := structData.Field(1).(*array.Int32).Int32Values() + ret = time.Unix(epoch[recIdx], int64(fraction[recIdx])).UTC() + } else { + intData := column.(*array.Int64) + value := intData.Value(recIdx) + epoch := extractEpoch(value, scale) + fraction := extractFraction(value, scale) + ret = time.Unix(epoch, fraction).UTC() + } + case timestampLtzType: + if column.DataType().ID() == arrow.STRUCT { + structData := column.(*array.Struct) + epoch := structData.Field(0).(*array.Int64).Int64Values() + fraction := structData.Field(1).(*array.Int32).Int32Values() + ret = time.Unix(epoch[recIdx], int64(fraction[recIdx])).In(loc) + } else { + intData := column.(*array.Int64) + value := intData.Value(recIdx) + epoch := extractEpoch(value, scale) + fraction := extractFraction(value, scale) + ret = time.Unix(epoch, fraction).In(loc) + } + case timestampTzType: + structData := column.(*array.Struct) + if structData.NumField() == 2 { + value := structData.Field(0).(*array.Int64).Int64Values() + timezone := structData.Field(1).(*array.Int32).Int32Values() + epoch := extractEpoch(value[recIdx], scale) + fraction := extractFraction(value[recIdx], scale) + locTz := Location(int(timezone[recIdx]) - 1440) + ret = time.Unix(epoch, fraction).In(locTz) + } else { + epoch := structData.Field(0).(*array.Int64).Int64Values() + fraction := structData.Field(1).(*array.Int32).Int32Values() + timezone := structData.Field(2).(*array.Int32).Int32Values() + locTz := Location(int(timezone[recIdx]) - 1440) + ret = time.Unix(epoch[recIdx], int64(fraction[recIdx])).In(locTz) + } + } + return &ret +} + +func extractEpoch(value int64, scale int) int64 { + return value / int64(math.Pow10(scale)) +} + +func extractFraction(value int64, scale int) int64 { + return (value % int64(math.Pow10(scale))) * int64(math.Pow10(9-scale)) +} + // Arrow Interface (Column) converter. This is called when Arrow chunks are // downloaded to convert to the corresponding row type. func arrowToValue( @@ -369,7 +442,8 @@ func arrowToValue( } logger.Debugf("snowflake data type: %v, arrow data type: %v", srcColumnMeta.Type, srcValue.DataType()) - switch getSnowflakeType(strings.ToUpper(srcColumnMeta.Type)) { + snowflakeType := getSnowflakeType(srcColumnMeta.Type) + switch snowflakeType { case fixedType: // Snowflake data types that are fixed-point numbers will fall into this category // e.g. NUMBER, DECIMAL/NUMERIC, INT/INTEGER @@ -528,69 +602,11 @@ func arrowToValue( } } return err - case timestampNtzType: - if srcValue.DataType().ID() == arrow.STRUCT { - structData := srcValue.(*array.Struct) - epoch := structData.Field(0).(*array.Int64).Int64Values() - fraction := structData.Field(1).(*array.Int32).Int32Values() - for i := range destcol { - if !srcValue.IsNull(i) { - destcol[i] = time.Unix(epoch[i], int64(fraction[i])).UTC() - } - } - } else { - for i, t := range srcValue.(*array.Int64).Int64Values() { - if !srcValue.IsNull(i) { - scale := int(srcColumnMeta.Scale) - epoch := t / int64(math.Pow10(scale)) - fraction := (t % int64(math.Pow10(scale))) * int64(math.Pow10(9-scale)) - destcol[i] = time.Unix(epoch, fraction).UTC() - } - } - } - return err - case timestampLtzType: - if srcValue.DataType().ID() == arrow.STRUCT { - structData := srcValue.(*array.Struct) - epoch := structData.Field(0).(*array.Int64).Int64Values() - fraction := structData.Field(1).(*array.Int32).Int32Values() - for i := range destcol { - if !srcValue.IsNull(i) { - destcol[i] = time.Unix(epoch[i], int64(fraction[i])).In(loc) - } - } - } else { - for i, t := range srcValue.(*array.Int64).Int64Values() { - if !srcValue.IsNull(i) { - q := t / int64(math.Pow10(int(srcColumnMeta.Scale))) - r := t % int64(math.Pow10(int(srcColumnMeta.Scale))) - destcol[i] = time.Unix(q, r).In(loc) - } - } - } - return err - case timestampTzType: - structData := srcValue.(*array.Struct) - if structData.NumField() == 2 { - epoch := structData.Field(0).(*array.Int64).Int64Values() - timezone := structData.Field(1).(*array.Int32).Int32Values() - for i := range destcol { - if !srcValue.IsNull(i) { - loc := Location(int(timezone[i]) - 1440) - tt := time.Unix(epoch[i], 0) - destcol[i] = tt.In(loc) - } - } - } else { - epoch := structData.Field(0).(*array.Int64).Int64Values() - fraction := structData.Field(1).(*array.Int32).Int32Values() - timezone := structData.Field(2).(*array.Int32).Int32Values() - for i := range destcol { - if !srcValue.IsNull(i) { - loc := Location(int(timezone[i]) - 1440) - tt := time.Unix(epoch[i], int64(fraction[i])) - destcol[i] = tt.In(loc) - } + case timestampNtzType, timestampLtzType, timestampTzType: + for i := range destcol { + var ts = arrowSnowflakeTimestampToTime(srcValue, snowflakeType, int(srcColumnMeta.Scale), i, loc) + if ts != nil { + destcol[i] = *ts } } return err @@ -952,22 +968,34 @@ func higherPrecisionEnabled(ctx context.Context) bool { return ok && d } -func arrowToRecord(record arrow.Record, pool memory.Allocator, rowType []execResponseRowType, loc *time.Location) (arrow.Record, error) { - s, err := recordToSchema(record.Schema(), rowType, loc) +func originalTimestampEnabled(ctx context.Context) bool { + v := ctx.Value(enableOriginalTimestamp) + if v == nil { + return false + } + d, ok := v.(bool) + return ok && d +} + +func arrowToRecord(ctx context.Context, record arrow.Record, pool memory.Allocator, rowType []execResponseRowType, loc *time.Location) (arrow.Record, error) { + useOriginalTimestamp := originalTimestampEnabled(ctx) + + s, err := recordToSchema(record.Schema(), rowType, loc, useOriginalTimestamp) if err != nil { return nil, err } var cols []arrow.Array numRows := record.NumRows() - ctx := compute.WithAllocator(context.Background(), pool) + ctxAlloc := compute.WithAllocator(ctx, pool) for i, col := range record.Columns() { srcColumnMeta := rowType[i] // TODO: confirm that it is okay to be using higher precision logic for conversions newCol := col - switch getSnowflakeType(strings.ToUpper(srcColumnMeta.Type)) { + snowflakeType := getSnowflakeType(srcColumnMeta.Type) + switch snowflakeType { case fixedType: var toType arrow.DataType if col.DataType().ID() == arrow.DECIMAL || col.DataType().ID() == arrow.DECIMAL256 { @@ -978,13 +1006,13 @@ func arrowToRecord(record arrow.Record, pool memory.Allocator, rowType []execRes } // we're fine truncating so no error for data loss here. // so we use UnsafeCastOptions. - newCol, err = compute.CastArray(ctx, col, compute.UnsafeCastOptions(toType)) + newCol, err = compute.CastArray(ctxAlloc, col, compute.UnsafeCastOptions(toType)) if err != nil { return nil, err } defer newCol.Release() } else if srcColumnMeta.Scale != 0 { - result, err := compute.Divide(ctx, compute.ArithmeticOptions{NoCheckOverflow: true}, + result, err := compute.Divide(ctxAlloc, compute.ArithmeticOptions{NoCheckOverflow: true}, &compute.ArrayDatum{Value: newCol.Data()}, compute.NewDatum(math.Pow10(int(srcColumnMeta.Scale)))) if err != nil { @@ -995,108 +1023,51 @@ func arrowToRecord(record arrow.Record, pool memory.Allocator, rowType []execRes defer newCol.Release() } case timeType: - newCol, err = compute.CastArray(ctx, col, compute.SafeCastOptions(arrow.FixedWidthTypes.Time64ns)) + newCol, err = compute.CastArray(ctxAlloc, col, compute.SafeCastOptions(arrow.FixedWidthTypes.Time64ns)) if err != nil { return nil, err } defer newCol.Release() - case timestampNtzType: - tb := array.NewTimestampBuilder(pool, &arrow.TimestampType{Unit: arrow.Nanosecond}) - if col.DataType().ID() == arrow.STRUCT { - structData := col.(*array.Struct) - epoch := structData.Field(0).(*array.Int64).Int64Values() - fraction := structData.Field(1).(*array.Int32).Int32Values() - for i := 0; i < int(numRows); i++ { - if !col.IsNull(i) { - val := time.Unix(epoch[i], int64(fraction[i])) - tb.Append(arrow.Timestamp(val.UnixNano())) - } else { - tb.AppendNull() - } - } + case timestampNtzType, timestampLtzType, timestampTzType: + if useOriginalTimestamp { + // do nothing - return timestamp as is } else { - for i, t := range col.(*array.Timestamp).TimestampValues() { - if !col.IsNull(i) { - val := time.Unix(0, int64(t)*int64(math.Pow10(9-int(srcColumnMeta.Scale)))).UTC() - tb.Append(arrow.Timestamp(val.UnixNano())) - } else { - tb.AppendNull() - } - } - } - newCol = tb.NewArray() - defer newCol.Release() - tb.Release() - case timestampLtzType: - tb := array.NewTimestampBuilder(pool, &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: loc.String()}) - if col.DataType().ID() == arrow.STRUCT { - structData := col.(*array.Struct) - epoch := structData.Field(0).(*array.Int64).Int64Values() - fraction := structData.Field(1).(*array.Int32).Int32Values() - for i := 0; i < int(numRows); i++ { - if !col.IsNull(i) { - val := time.Unix(epoch[i], int64(fraction[i])) - tb.Append(arrow.Timestamp(val.UnixNano())) - } else { - tb.AppendNull() - } - } - } else { - for i, t := range col.(*array.Timestamp).TimestampValues() { - if !col.IsNull(i) { - q := int64(t) / int64(math.Pow10(int(srcColumnMeta.Scale))) - r := int64(t) % int64(math.Pow10(int(srcColumnMeta.Scale))) - val := time.Unix(q, r) - tb.Append(arrow.Timestamp(val.UnixNano())) - } else { - tb.AppendNull() - } - } - } - newCol = tb.NewArray() - defer newCol.Release() - tb.Release() - case timestampTzType: - tb := array.NewTimestampBuilder(pool, &arrow.TimestampType{Unit: arrow.Nanosecond}) - structData := col.(*array.Struct) - if structData.NumField() == 2 { - epoch := structData.Field(0).(*array.Int64).Int64Values() - timezone := structData.Field(1).(*array.Int32).Int32Values() - for i := 0; i < int(numRows); i++ { - if !col.IsNull(i) { - loc := Location(int(timezone[i]) - 1440) - tt := time.Unix(epoch[i], 0) - val := tt.In(loc) - tb.Append(arrow.Timestamp(val.UnixNano())) - } else { - tb.AppendNull() - } + var tb *array.TimestampBuilder + if snowflakeType == timestampLtzType { + tb = array.NewTimestampBuilder(pool, &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: loc.String()}) + } else { + tb = array.NewTimestampBuilder(pool, &arrow.TimestampType{Unit: arrow.Nanosecond}) } - } else { - epoch := structData.Field(0).(*array.Int64).Int64Values() - fraction := structData.Field(1).(*array.Int32).Int32Values() - timezone := structData.Field(2).(*array.Int32).Int32Values() + defer tb.Release() + for i := 0; i < int(numRows); i++ { - if !col.IsNull(i) { - loc := Location(int(timezone[i]) - 1440) - tt := time.Unix(epoch[i], int64(fraction[i])) - val := tt.In(loc) - tb.Append(arrow.Timestamp(val.UnixNano())) + ts := arrowSnowflakeTimestampToTime(col, snowflakeType, int(srcColumnMeta.Scale), i, loc) + if ts != nil { + ar := arrow.Timestamp(ts.UnixNano()) + // in case of overflow in arrow timestamp return error + if ts.Year() != ar.ToTime(arrow.Nanosecond).Year() { + return nil, &SnowflakeError{ + Number: ErrTooHighTimestampPrecision, + SQLState: SQLStateInvalidDataTimeFormat, + Message: fmt.Sprintf("Cannot convert timestamp %v in column %v to Arrow.Timestamp data type due to too high precision. Please use context with WithOriginalTimestamp.", ts.UTC(), srcColumnMeta.Name), + } + } + tb.Append(ar) } else { tb.AppendNull() } } + + newCol = tb.NewArray() + defer newCol.Release() } - newCol = tb.NewArray() - defer newCol.Release() - tb.Release() } cols = append(cols, newCol) } return array.NewRecord(s, cols, numRows), nil } -func recordToSchema(sc *arrow.Schema, rowType []execResponseRowType, loc *time.Location) (*arrow.Schema, error) { +func recordToSchema(sc *arrow.Schema, rowType []execResponseRowType, loc *time.Location, useOriginalTimestamp bool) (*arrow.Schema, error) { var fields []arrow.Field for i := 0; i < len(sc.Fields()); i++ { f := sc.Field(i) @@ -1104,7 +1075,7 @@ func recordToSchema(sc *arrow.Schema, rowType []execResponseRowType, loc *time.L converted := true var t arrow.DataType - switch getSnowflakeType(strings.ToUpper(srcColumnMeta.Type)) { + switch getSnowflakeType(srcColumnMeta.Type) { case fixedType: switch f.Type.ID() { case arrow.DECIMAL: @@ -1123,9 +1094,19 @@ func recordToSchema(sc *arrow.Schema, rowType []execResponseRowType, loc *time.L case timeType: t = &arrow.Time64Type{Unit: arrow.Nanosecond} case timestampNtzType, timestampTzType: - t = &arrow.TimestampType{Unit: arrow.Nanosecond} + if useOriginalTimestamp { + // do nothing - return timestamp as is + converted = false + } else { + t = &arrow.TimestampType{Unit: arrow.Nanosecond} + } case timestampLtzType: - t = &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: loc.String()} + if useOriginalTimestamp { + // do nothing - return timestamp as is + converted = false + } else { + t = &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: loc.String()} + } default: converted = false } diff --git a/converter_test.go b/converter_test.go index 6a43404f2..98dda7aa2 100644 --- a/converter_test.go +++ b/converter_test.go @@ -12,6 +12,7 @@ import ( "math/big" "math/cmplx" "reflect" + "strconv" "strings" "testing" "time" @@ -828,10 +829,15 @@ func TestArrowToRecord(t *testing.T) { var valids []bool // AppendValues() with an empty valid array adds every value by default localTime := time.Date(2019, 2, 6, 14, 17, 31, 123456789, time.FixedZone("-08:00", -8*3600)) + localTimeFarIntoFuture := time.Date(9000, 2, 6, 14, 17, 31, 123456789, time.FixedZone("-08:00", -8*3600)) - field1 := arrow.Field{Name: "epoch", Type: &arrow.Int64Type{}} - field2 := arrow.Field{Name: "timezone", Type: &arrow.Int32Type{}} - tzStruct := arrow.StructOf(field1, field2) + epochField := arrow.Field{Name: "epoch", Type: &arrow.Int64Type{}} + timezoneField := arrow.Field{Name: "timezone", Type: &arrow.Int32Type{}} + fractionField := arrow.Field{Name: "fraction", Type: &arrow.Int32Type{}} + timestampTzStructWithoutFraction := arrow.StructOf(epochField, timezoneField) + timestampTzStructWithFraction := arrow.StructOf(epochField, fractionField, timezoneField) + timestampNtzStruct := arrow.StructOf(epochField, fractionField) + timestampLtzStruct := arrow.StructOf(epochField, fractionField) type testObj struct { field1 int @@ -844,6 +850,8 @@ func TestArrowToRecord(t *testing.T) { sc *arrow.Schema rowType execResponseRowType values interface{} + error string + origTS bool nrows int builder array.Builder append func(b array.Builder, vs interface{}) @@ -1106,21 +1114,22 @@ func TestArrowToRecord(t *testing.T) { }, }, { - logical: "timestamp_ntz", - values: []time.Time{time.Now(), localTime}, - nrows: 2, - rowType: execResponseRowType{Scale: 9}, - sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.TimestampType{}}}, nil), - builder: array.NewTimestampBuilder(pool, &arrow.TimestampType{}), + logical: "timestamp_ntz", + physical: "int64", // timestamp_ntz with scale 0..3 -> int64 + values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond)}, // Millisecond for scale = 3 + nrows: 2, + rowType: execResponseRowType{Scale: 3}, + sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil), + builder: array.NewInt64Builder(pool), append: func(b array.Builder, vs interface{}) { for _, t := range vs.([]time.Time) { - b.(*array.TimestampBuilder).Append(arrow.Timestamp(t.UnixNano())) + b.(*array.Int64Builder).Append(t.UnixMilli()) // Millisecond for scale = 3 } }, compare: func(src interface{}, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { - if srcvs[i].UnixNano() != int64(t) { + if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) { return i } } @@ -1128,21 +1137,26 @@ func TestArrowToRecord(t *testing.T) { }, }, { - logical: "timestamp_ltz", - values: []time.Time{time.Now(), localTime}, - nrows: 2, - rowType: execResponseRowType{Scale: 9}, - sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.TimestampType{}}}, nil), - builder: array.NewTimestampBuilder(pool, &arrow.TimestampType{}), + logical: "timestamp_ntz", + physical: "struct", // timestamp_ntz with scale 4..9 -> int64 + int32 + values: []time.Time{time.Now(), localTime}, + nrows: 2, + rowType: execResponseRowType{Scale: 9}, + sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil), + builder: array.NewStructBuilder(pool, timestampNtzStruct), append: func(b array.Builder, vs interface{}) { + sb := b.(*array.StructBuilder) + valids = []bool{true, true} + sb.AppendValues(valids) for _, t := range vs.([]time.Time) { - b.(*array.TimestampBuilder).Append(arrow.Timestamp(t.UnixNano())) + sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) + sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) } }, compare: func(src interface{}, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { - if srcvs[i].UnixNano() != int64(t) { + if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) { return i } } @@ -1150,24 +1164,303 @@ func TestArrowToRecord(t *testing.T) { }, }, { - logical: "timestamp_tz", - values: []time.Time{time.Now(), localTime}, - nrows: 2, - sc: arrow.NewSchema([]arrow.Field{{Type: arrow.StructOf(field1, field2)}}, nil), - builder: array.NewStructBuilder(pool, tzStruct), + logical: "timestamp_ntz", + physical: "error", + values: []time.Time{localTimeFarIntoFuture}, + error: "Cannot convert timestamp", + nrows: 1, + rowType: execResponseRowType{Scale: 3}, + sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil), + builder: array.NewInt64Builder(pool), + append: func(b array.Builder, vs interface{}) { + for _, t := range vs.([]time.Time) { + b.(*array.Int64Builder).Append(t.UnixMilli()) + } + }, + compare: func(src interface{}, convertedRec arrow.Record) int { return 0 }, + }, + { + logical: "timestamp_ntz", + physical: "int64 with original timestamp", // timestamp_ntz with scale 0..3 -> int64 + values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond), localTimeFarIntoFuture.Truncate(time.Millisecond)}, + origTS: true, + nrows: 3, + rowType: execResponseRowType{Scale: 3}, + sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil), + builder: array.NewInt64Builder(pool), + append: func(b array.Builder, vs interface{}) { + for _, t := range vs.([]time.Time) { + b.(*array.Int64Builder).Append(t.UnixMilli()) // Millisecond for scale = 3 + } + }, + compare: func(src interface{}, convertedRec arrow.Record) int { + srcvs := src.([]time.Time) + for i := 0; i < convertedRec.Column(0).Len(); i++ { + ts := arrowSnowflakeTimestampToTime(convertedRec.Column(0), timestampNtzType, 3, i, nil) + if !srcvs[i].Equal(*ts) { + return i + } + } + return -1 + }, + }, + { + logical: "timestamp_ntz", + physical: "struct with original timestamp", // timestamp_ntz with scale 4..9 -> int64 + int32 + values: []time.Time{time.Now(), localTime, localTimeFarIntoFuture}, + origTS: true, + nrows: 3, + rowType: execResponseRowType{Scale: 9}, + sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil), + builder: array.NewStructBuilder(pool, timestampNtzStruct), + append: func(b array.Builder, vs interface{}) { + sb := b.(*array.StructBuilder) + valids = []bool{true, true, true} + sb.AppendValues(valids) + for _, t := range vs.([]time.Time) { + sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) + sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) + } + }, + compare: func(src interface{}, convertedRec arrow.Record) int { + srcvs := src.([]time.Time) + for i := 0; i < convertedRec.Column(0).Len(); i++ { + ts := arrowSnowflakeTimestampToTime(convertedRec.Column(0), timestampNtzType, 9, i, nil) + if !srcvs[i].Equal(*ts) { + return i + } + } + return -1 + }, + }, + { + logical: "timestamp_ltz", + physical: "int64", // timestamp_ntz with scale 0..3 -> int64 + values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond)}, + nrows: 2, + rowType: execResponseRowType{Scale: 3}, + sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil), + builder: array.NewInt64Builder(pool), + append: func(b array.Builder, vs interface{}) { + for _, t := range vs.([]time.Time) { + b.(*array.Int64Builder).Append(t.UnixMilli()) // Millisecond for scale = 3 + } + }, + compare: func(src interface{}, convertedRec arrow.Record) int { + srcvs := src.([]time.Time) + for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { + if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) { + return i + } + } + return -1 + }, + }, + { + logical: "timestamp_ltz", + physical: "struct", // timestamp_ntz with scale 4..9 -> int64 + int32 + values: []time.Time{time.Now(), localTime}, + nrows: 2, + rowType: execResponseRowType{Scale: 9}, + sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil), + builder: array.NewStructBuilder(pool, timestampNtzStruct), append: func(b array.Builder, vs interface{}) { sb := b.(*array.StructBuilder) valids = []bool{true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) - sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.UnixNano())) + sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) } }, compare: func(src interface{}, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { - if srcvs[i].Unix() != time.Unix(0, int64(t)).Unix() { + if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) { + return i + } + } + return -1 + }, + }, + { + logical: "timestamp_ltz", + physical: "error", + values: []time.Time{localTimeFarIntoFuture}, + error: "Cannot convert timestamp", + nrows: 1, + rowType: execResponseRowType{Scale: 3}, + sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil), + builder: array.NewInt64Builder(pool), + append: func(b array.Builder, vs interface{}) { + for _, t := range vs.([]time.Time) { + b.(*array.Int64Builder).Append(t.UnixMilli()) // Millisecond for scale = 3 + } + }, + compare: func(src interface{}, convertedRec arrow.Record) int { return 0 }, + }, + { + logical: "timestamp_ltz", + physical: "int64 with original timestamp", // timestamp_ntz with scale 0..3 -> int64 + values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond), localTimeFarIntoFuture.Truncate(time.Millisecond)}, + origTS: true, + nrows: 3, + rowType: execResponseRowType{Scale: 3}, + sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil), + builder: array.NewInt64Builder(pool), + append: func(b array.Builder, vs interface{}) { + for _, t := range vs.([]time.Time) { + b.(*array.Int64Builder).Append(t.UnixMilli()) // Millisecond for scale = 3 + } + }, + compare: func(src interface{}, convertedRec arrow.Record) int { + srcvs := src.([]time.Time) + for i := 0; i < convertedRec.Column(0).Len(); i++ { + ts := arrowSnowflakeTimestampToTime(convertedRec.Column(0), timestampLtzType, 3, i, localTime.Location()) + if !srcvs[i].Equal(*ts) { + return i + } + } + return -1 + }, + }, + { + logical: "timestamp_ltz", + physical: "struct with original timestamp", // timestamp_ntz with scale 4..9 -> int64 + int32 + values: []time.Time{time.Now(), localTime, localTimeFarIntoFuture}, + origTS: true, + nrows: 3, + rowType: execResponseRowType{Scale: 9}, + sc: arrow.NewSchema([]arrow.Field{{Type: timestampLtzStruct}}, nil), + builder: array.NewStructBuilder(pool, timestampLtzStruct), + append: func(b array.Builder, vs interface{}) { + sb := b.(*array.StructBuilder) + valids = []bool{true, true, true} + sb.AppendValues(valids) + for _, t := range vs.([]time.Time) { + sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) + sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) + } + }, + compare: func(src interface{}, convertedRec arrow.Record) int { + srcvs := src.([]time.Time) + for i := 0; i < convertedRec.Column(0).Len(); i++ { + ts := arrowSnowflakeTimestampToTime(convertedRec.Column(0), timestampLtzType, 9, i, localTime.Location()) + if !srcvs[i].Equal(*ts) { + return i + } + } + return -1 + }, + }, + { + logical: "timestamp_tz", + physical: "struct2", // timestamp_tz with scale 0..3 -> int64 + int32 + values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond)}, + nrows: 2, + rowType: execResponseRowType{Scale: 3}, + sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithoutFraction}}, nil), + builder: array.NewStructBuilder(pool, timestampTzStructWithoutFraction), + append: func(b array.Builder, vs interface{}) { + sb := b.(*array.StructBuilder) + valids = []bool{true, true} + sb.AppendValues(valids) + for _, t := range vs.([]time.Time) { + sb.FieldBuilder(0).(*array.Int64Builder).Append(t.UnixMilli()) // Millisecond for scale = 3 + sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(0)) // timezone index - not important in tests + } + }, + compare: func(src interface{}, convertedRec arrow.Record) int { + srcvs := src.([]time.Time) + for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { + if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) { + return i + } + } + return -1 + }, + }, + { + logical: "timestamp_tz", + physical: "struct3", // timestamp_tz with scale 4..9 -> int64 + int32 + int32 + values: []time.Time{time.Now(), localTime}, + nrows: 2, + rowType: execResponseRowType{Scale: 9}, + sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithFraction}}, nil), + builder: array.NewStructBuilder(pool, timestampTzStructWithFraction), + append: func(b array.Builder, vs interface{}) { + sb := b.(*array.StructBuilder) + valids = []bool{true, true} + sb.AppendValues(valids) + for _, t := range vs.([]time.Time) { + sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) + sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) + sb.FieldBuilder(2).(*array.Int32Builder).Append(int32(0)) // timezone index - not important in tests + } + }, + compare: func(src interface{}, convertedRec arrow.Record) int { + srcvs := src.([]time.Time) + for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { + if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) { + return i + } + } + return -1 + }, + }, + { + logical: "timestamp_tz", + physical: "struct2 with original timestamp", // timestamp_ntz with scale 0..3 -> int64 + int32 + values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond), localTimeFarIntoFuture.Truncate(time.Millisecond)}, + origTS: true, + nrows: 3, + rowType: execResponseRowType{Scale: 3}, + sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithoutFraction}}, nil), + builder: array.NewStructBuilder(pool, timestampTzStructWithoutFraction), + append: func(b array.Builder, vs interface{}) { + sb := b.(*array.StructBuilder) + valids = []bool{true, true, true} + sb.AppendValues(valids) + for _, t := range vs.([]time.Time) { + sb.FieldBuilder(0).(*array.Int64Builder).Append(t.UnixMilli()) // Millisecond for scale = 3 + sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(0)) // timezone index - not important in tests + } + }, + compare: func(src interface{}, convertedRec arrow.Record) int { + srcvs := src.([]time.Time) + for i := 0; i < convertedRec.Column(0).Len(); i++ { + ts := arrowSnowflakeTimestampToTime(convertedRec.Column(0), timestampTzType, 3, i, nil) + if !srcvs[i].Equal(*ts) { + return i + } + } + return -1 + }, + }, + { + logical: "timestamp_tz", + physical: "struct3 with original timestamp", // timestamp_ntz with scale 4..9 -> int64 + int32 + int32 + values: []time.Time{time.Now(), localTime, localTimeFarIntoFuture}, + origTS: true, + nrows: 3, + rowType: execResponseRowType{Scale: 9}, + sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithFraction}}, nil), + builder: array.NewStructBuilder(pool, timestampTzStructWithFraction), + append: func(b array.Builder, vs interface{}) { + sb := b.(*array.StructBuilder) + valids = []bool{true, true, true} + sb.AppendValues(valids) + for _, t := range vs.([]time.Time) { + sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) + sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) + sb.FieldBuilder(2).(*array.Int32Builder).Append(int32(0)) // timezone index - not important in tests + } + }, + compare: func(src interface{}, convertedRec arrow.Record) int { + srcvs := src.([]time.Time) + for i := 0; i < convertedRec.Column(0).Len(); i++ { + ts := arrowSnowflakeTimestampToTime(convertedRec.Column(0), timestampTzType, 9, i, nil) + if !srcvs[i].Equal(*ts) { return i } } @@ -1218,22 +1511,33 @@ func TestArrowToRecord(t *testing.T) { meta := tc.rowType meta.Type = tc.logical - transformedRec, err := arrowToRecord(rawRec, pool, []execResponseRowType{meta}, localTime.Location()) - if err != nil { - t.Fatalf("error: %s", err) + ctx := context.Background() + if tc.origTS { + ctx = WithOriginalTimestamp(ctx) } - defer transformedRec.Release() - if tc.compare != nil { - idx := tc.compare(tc.values, transformedRec) - if idx != -1 { - t.Fatalf("error: column array value mismatch at index %v", idx) + transformedRec, err := arrowToRecord(ctx, rawRec, pool, []execResponseRowType{meta}, localTime.Location()) + if err != nil { + if tc.error == "" || !strings.Contains(err.Error(), tc.error) { + t.Fatalf("error: %s", err) } } else { - for i, c := range transformedRec.Columns() { - rawCol := rawRec.Column(i) - if rawCol != c { - t.Fatalf("error: expected column %s, got column %s", rawCol, c) + defer transformedRec.Release() + if tc.error != "" { + t.Fatalf("expected error: %s", tc.error) + } + + if tc.compare != nil { + idx := tc.compare(tc.values, transformedRec) + if idx != -1 { + t.Fatalf("error: column array value mismatch at index %v", idx) + } + } else { + for i, c := range transformedRec.Columns() { + rawCol := rawRec.Column(i) + if rawCol != c { + t.Fatalf("error: expected column %s, got column %s", rawCol, c) + } } } } @@ -1299,30 +1603,135 @@ func TestSmallTimestampBinding(t *testing.T) { }) } -func TestLargeTimestampBinding(t *testing.T) { - runSnowflakeConnTest(t, func(sct *SCTest) { +func TestTimestampConversionWithoutArrowBatches(t *testing.T) { + timestamps := [3]string{ + "2000-10-10 10:10:10.123456789", // neutral + "9999-12-12 23:59:59.999999999", // max + "0001-01-01 00:00:00.000000000"} // min + types := [3]string{"TIMESTAMP_NTZ", "TIMESTAMP_LTZ", "TIMESTAMP_TZ"} + + runDBTest(t, func(sct *DBTest) { ctx := context.Background() - timeValue, err := time.Parse("2006-01-02 15:04:05", "9000-10-10 10:10:10") - if err != nil { - t.Fatalf("failed to parse time: %v", err) + + for _, tsStr := range timestamps { + ts, err := time.Parse("2006-01-02 15:04:05", tsStr) + if err != nil { + t.Fatalf("failed to parse time: %v", err) + } + for _, tp := range types { + for scale := 0; scale <= 9; scale++ { + t.Run(tp+"("+strconv.Itoa(scale)+")_"+tsStr, func(t *testing.T) { + query := fmt.Sprintf("SELECT '%s'::%s(%v)", tsStr, tp, scale) + rows := sct.mustQueryContext(ctx, query, nil) + defer rows.Close() + + if rows.Next() { + var act time.Time + rows.Scan(&act) + exp := ts.Truncate(time.Duration(math.Pow10(9 - scale))) + if !exp.Equal(act) { + t.Fatalf("unexpected result. expected: %v, got: %v", exp, act) + } + } else { + t.Fatalf("failed to run query: %v", query) + } + }) + } + } } - parameters := []driver.NamedValue{ - {Ordinal: 1, Value: DataTypeTimestampNtz}, - {Ordinal: 2, Value: timeValue}, + }) +} + +func TestTimestampConversionWithArrowBatchesFailsForDistantDates(t *testing.T) { + timestamps := [2]string{ + "9999-12-12 23:59:59.999999999", // max + "0001-01-01 00:00:00.000000000"} // min + types := [3]string{"TIMESTAMP_NTZ", "TIMESTAMP_LTZ", "TIMESTAMP_TZ"} + + expectedError := "Cannot convert timestamp" + + runSnowflakeConnTest(t, func(sct *SCTest) { + ctx := WithArrowBatches(sct.sc.ctx) + + pool := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer pool.AssertSize(t, 0) + ctx = WithArrowAllocator(ctx, pool) + + for _, tsStr := range timestamps { + for _, tp := range types { + for scale := 0; scale <= 9; scale++ { + t.Run(tp+"("+strconv.Itoa(scale)+")_"+tsStr, func(t *testing.T) { + + query := fmt.Sprintf("SELECT '%s'::%s(%v)", tsStr, tp, scale) + _, err := sct.sc.QueryContext(ctx, query, []driver.NamedValue{}) + if err != nil { + if !strings.Contains(err.Error(), expectedError) { + t.Fatalf("improper error, expected: %v, got: %v", expectedError, err.Error()) + } + } else { + t.Fatalf("no error, expected: %v ", expectedError) + + } + }) + } + } } + }) +} - rows := sct.mustQueryContext(ctx, "SELECT ?", parameters) - defer rows.Close() +func TestTimestampConversionWithArrowBatchesAndWithOriginalTimestamp(t *testing.T) { + timestamps := [3]string{ + "2000-10-10 10:10:10.123456789", // neutral + "9999-12-12 23:59:59.999999999", // max + "0001-01-01 00:00:00.000000000"} // min + types := [3]string{"TIMESTAMP_NTZ", "TIMESTAMP_LTZ", "TIMESTAMP_TZ"} - scanValues := make([]driver.Value, 1) - for { - if err := rows.Next(scanValues); err == io.EOF { - break - } else if err != nil { - t.Fatalf("failed to run query: %v", err) + runSnowflakeConnTest(t, func(sct *SCTest) { + ctx := WithOriginalTimestamp(WithArrowBatches(sct.sc.ctx)) + pool := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer pool.AssertSize(t, 0) + ctx = WithArrowAllocator(ctx, pool) + + for _, tsStr := range timestamps { + ts, err := time.Parse("2006-01-02 15:04:05", tsStr) + if err != nil { + t.Fatalf("failed to parse time: %v", err) } - if scanValues[0] != timeValue { - t.Fatalf("unexpected result. expected: %v, got: %v", timeValue, scanValues[0]) + for _, tp := range types { + for scale := 0; scale <= 9; scale++ { + t.Run(tp+"("+strconv.Itoa(scale)+")_"+tsStr, func(t *testing.T) { + + query := fmt.Sprintf("SELECT '%s'::%s(%v)", tsStr, tp, scale) + rows := sct.mustQueryContext(ctx, query, []driver.NamedValue{}) + defer rows.Close() + + // getting result batches + batches, err := rows.(*snowflakeRows).GetArrowBatches() + if err != nil { + t.Error(err) + } + + numBatches := len(batches) + if numBatches != 1 { + t.Errorf("incorrect number of batches, expected: 1, got: %v", numBatches) + } + + rec, err := batches[0].Fetch() + if err != nil { + t.Error(err) + } + exp := ts.Truncate(time.Duration(math.Pow10(9 - scale))) + for _, r := range *rec { + defer r.Release() + act := batches[0].ArrowSnowflakeTimestampToTime(r, 0, 0) + if act == nil { + t.Fatalf("unexpected result. expected: %v, got: nil", exp) + } else if !exp.Equal(*act) { + t.Fatalf("unexpected result. expected: %v, got: %v", exp, act) + } + } + }) + } } } }) diff --git a/datatype.go b/datatype.go index 73fb91499..db61a90e9 100644 --- a/datatype.go +++ b/datatype.go @@ -7,6 +7,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "strings" ) type snowflakeType int @@ -73,7 +74,7 @@ func (st snowflakeType) String() string { } func getSnowflakeType(typ string) snowflakeType { - return snowflakeToDriverType[typ] + return snowflakeToDriverType[strings.ToUpper(typ)] } var ( diff --git a/errors.go b/errors.go index a4104f66c..64a77ac5f 100644 --- a/errors.go +++ b/errors.go @@ -216,6 +216,8 @@ const ( ErrInvalidOffsetStr = 268001 // ErrInvalidBinaryHexForm is an error code for the case where a binary data in hex form is invalid. ErrInvalidBinaryHexForm = 268002 + // ErrTooHighTimestampPrecision is an error code for the case where cannot convert Snowflake timestamp to arrow.Timestamp + ErrTooHighTimestampPrecision = 268003 /* OCSP */ diff --git a/rows.go b/rows.go index 83f49ba94..3d3fcbb0f 100644 --- a/rows.go +++ b/rows.go @@ -148,7 +148,7 @@ func (rows *snowflakeRows) ColumnTypeScanType(index int) reflect.Type { return nil } return snowflakeTypeToGo( - getSnowflakeType(strings.ToUpper(rows.ChunkDownloader.getRowType()[index].Type)), + getSnowflakeType(rows.ChunkDownloader.getRowType()[index].Type), rows.ChunkDownloader.getRowType()[index].Scale) } diff --git a/util.go b/util.go index ade109364..96f2e9bb5 100644 --- a/util.go +++ b/util.go @@ -18,16 +18,17 @@ import ( type contextKey string const ( - multiStatementCount contextKey = "MULTI_STATEMENT_COUNT" - asyncMode contextKey = "ASYNC_MODE_QUERY" - queryIDChannel contextKey = "QUERY_ID_CHANNEL" - snowflakeRequestIDKey contextKey = "SNOWFLAKE_REQUEST_ID" - fetchResultByID contextKey = "SF_FETCH_RESULT_BY_ID" - fileStreamFile contextKey = "STREAMING_PUT_FILE" - fileTransferOptions contextKey = "FILE_TRANSFER_OPTIONS" - enableHigherPrecision contextKey = "ENABLE_HIGHER_PRECISION" - arrowBatches contextKey = "ARROW_BATCHES" - arrowAlloc contextKey = "ARROW_ALLOC" + multiStatementCount contextKey = "MULTI_STATEMENT_COUNT" + asyncMode contextKey = "ASYNC_MODE_QUERY" + queryIDChannel contextKey = "QUERY_ID_CHANNEL" + snowflakeRequestIDKey contextKey = "SNOWFLAKE_REQUEST_ID" + fetchResultByID contextKey = "SF_FETCH_RESULT_BY_ID" + fileStreamFile contextKey = "STREAMING_PUT_FILE" + fileTransferOptions contextKey = "FILE_TRANSFER_OPTIONS" + enableHigherPrecision contextKey = "ENABLE_HIGHER_PRECISION" + arrowBatches contextKey = "ARROW_BATCHES" + arrowAlloc contextKey = "ARROW_ALLOC" + enableOriginalTimestamp contextKey = "ENABLE_ORIGINAL_TIMESTAMP" ) const ( @@ -105,6 +106,13 @@ func WithArrowAllocator(ctx context.Context, pool memory.Allocator) context.Cont return context.WithValue(ctx, arrowAlloc, pool) } +// WithOriginalTimestamp in combination with WithArrowBatches returns a context +// that allows users to retrieve arrow.Record with original timestamp struct returned by Snowflake. +// It can be used in case arrow.Timestamp cannot fit original timestamp values. +func WithOriginalTimestamp(ctx context.Context) context.Context { + return context.WithValue(ctx, enableOriginalTimestamp, true) +} + // Get the request ID from the context if specified, otherwise generate one func getOrGenerateRequestIDFromContext(ctx context.Context) UUID { requestID, ok := ctx.Value(snowflakeRequestIDKey).(UUID)