diff --git a/arrow_chunk.go b/arrow_chunk.go
index 4d842e721..15851a80f 100644
--- a/arrow_chunk.go
+++ b/arrow_chunk.go
@@ -57,7 +57,7 @@ func (arc *arrowResultChunk) decodeArrowBatch(scd *snowflakeChunkDownloader) (*[
 	for arc.reader.Next() {
 		rawRecord := arc.reader.Record()
 
-		record, err := arrowToRecord(rawRecord, arc.allocator, scd.RowSet.RowType, arc.loc, scd.ctx)
+		record, err := arrowToRecord(scd.ctx, rawRecord, arc.allocator, scd.RowSet.RowType, arc.loc)
 		if err != nil {
 			return nil, err
 		}
diff --git a/converter.go b/converter.go
index 4a3bab19f..a1c0f05c7 100644
--- a/converter.go
+++ b/converter.go
@@ -980,7 +980,7 @@ func originalTimestampEnabled(ctx context.Context) bool {
 	return ok && d
 }
 
-func arrowToRecord(record arrow.Record, pool memory.Allocator, rowType []execResponseRowType, loc *time.Location, ctx context.Context) (arrow.Record, error) {
+func arrowToRecord(ctx context.Context, record arrow.Record, pool memory.Allocator, rowType []execResponseRowType, loc *time.Location) (arrow.Record, error) {
 	useOriginalTimestamp := originalTimestampEnabled(ctx)
 
 	s, err := recordToSchema(record.Schema(), rowType, loc, useOriginalTimestamp)
@@ -990,7 +990,7 @@ func arrowToRecord(record arrow.Record, pool memory.Allocator, rowType []execRes
 	var cols []arrow.Array
 	numRows := record.NumRows()
 
-	ctxAlloc := compute.WithAllocator(context.Background(), pool)
+	ctxAlloc := compute.WithAllocator(ctx, pool)
 
 	for i, col := range record.Columns() {
 		srcColumnMeta := rowType[i]
diff --git a/converter_test.go b/converter_test.go
index d10a64de2..98dda7aa2 100644
--- a/converter_test.go
+++ b/converter_test.go
@@ -1516,7 +1516,7 @@ func TestArrowToRecord(t *testing.T) {
 			ctx = WithOriginalTimestamp(ctx)
 		}
 
-		transformedRec, err := arrowToRecord(rawRec, pool, []execResponseRowType{meta}, localTime.Location(), ctx)
+		transformedRec, err := arrowToRecord(ctx, rawRec, pool, []execResponseRowType{meta}, localTime.Location())
 		if err != nil {
 			if tc.error == "" || !strings.Contains(err.Error(), tc.error) {
 				t.Fatalf("error: %s", err)
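
Note (not part of the diff): a minimal sketch of what the converter.go change buys. Deriving the compute context from the caller's ctx, instead of context.Background(), lets deadlines and cancellation on the query context reach Arrow compute kernels while still attaching the allocator. The arrow module version in the import path and the main wrapper are assumptions for illustration; compute.WithAllocator and memory.NewGoAllocator are real Arrow Go APIs.

package main

import (
	"context"
	"time"

	"github.com/apache/arrow/go/v12/arrow/compute" // version path assumed for illustration
	"github.com/apache/arrow/go/v12/arrow/memory"
)

func main() {
	pool := memory.NewGoAllocator()

	// A query-scoped context: the caller may set a deadline or cancel it.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// Deriving the compute context from ctx (rather than context.Background())
	// keeps that deadline/cancellation visible to Arrow compute calls while
	// also carrying the allocator, as in the refactored arrowToRecord.
	ctxAlloc := compute.WithAllocator(ctx, pool)
	_ = ctxAlloc
}

Moving ctx to the first parameter of arrowToRecord also follows the standard Go convention that context.Context comes first in a function's signature.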