Skip to content

Commit

Permalink
Parse timestamps using timezone (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
aldld authored Jul 21, 2022
1 parent 6312f8d commit 11dd2ed
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 11 deletions.
12 changes: 12 additions & 0 deletions databricks.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"io"
"io/ioutil"
"time"
)

func init() {
Expand All @@ -18,10 +19,21 @@ type Options struct {
HTTPPath string
MaxRows int64
Timeout int
Loc *time.Location

LogOut io.Writer
}

func (o *Options) Equal(o2 *Options) bool {
return o.Host == o2.Host &&
o.Port == o2.Port &&
o.Token == o2.Token &&
o.HTTPPath == o2.HTTPPath &&
o.MaxRows == o2.MaxRows &&
o.Timeout == o2.Timeout &&
o.Loc.String() == o2.Loc.String()
}

var (
// DefaultOptions for the driver
DefaultOptions = Options{Port: "443", MaxRows: 10000, LogOut: ioutil.Discard}
Expand Down
12 changes: 10 additions & 2 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ func (d *Driver) Open(uri string) (driver.Conn, error) {
return nil, err
}

log.Printf("opts: %v", opts)
// (eric) Don't log opts because it contains sensitive information.
// log.Printf("opts: %v", opts)

conn, err := connect(opts)
if err != nil {
Expand Down Expand Up @@ -101,6 +102,13 @@ func parseURI(uri string) (*Options, error) {
}
}

opts.Loc = time.UTC
if tz, ok := query["tz"]; ok {
if loc, err := time.LoadLocation(tz[0]); err == nil {
opts.Loc = loc
}
}

return &opts, nil
}

Expand Down Expand Up @@ -166,7 +174,7 @@ func connect(opts *Options) (*Conn, error) {
protocolFactory := thrift.NewTBinaryProtocolFactoryDefault()
tclient := thrift.NewTStandardClient(protocolFactory.GetProtocol(transport), protocolFactory.GetProtocol(transport))

client := hive.NewClient(tclient, logger, &hive.Options{MaxRows: opts.MaxRows})
client := hive.NewClient(tclient, logger, &hive.Options{MaxRows: opts.MaxRows, Loc: opts.Loc})

return &Conn{client: client, t: transport, log: logger}, nil
}
22 changes: 16 additions & 6 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,42 @@ package dbsql
import (
"io/ioutil"
"testing"
"time"
)

func TestParseURI(t *testing.T) {
americaLosAngelesTz, err := time.LoadLocation("America/Los_Angeles")
if err != nil {
t.Fatal(err)
}

tests := []struct {
in string
out Options
}{
{
"databricks://token:[email protected]/sql/1.0/endpoints/12346a5b5b0e123a",
Options{Host: "example.cloud.databricks.com", Port: "443", Token: "supersecret", HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", MaxRows: 10000, Timeout: 0, LogOut: ioutil.Discard},
Options{Host: "example.cloud.databricks.com", Port: "443", Token: "supersecret", HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", MaxRows: 10000, Timeout: 0, LogOut: ioutil.Discard, Loc: time.UTC},
},
{
"databricks://token:[email protected]/sql/1.0/endpoints/12346a5b5b0e123a?timeout=123",
Options{Host: "example.cloud.databricks.com", Port: "443", Token: "supersecret", HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", Timeout: 123, MaxRows: 10000, LogOut: ioutil.Discard},
Options{Host: "example.cloud.databricks.com", Port: "443", Token: "supersecret", HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", Timeout: 123, MaxRows: 10000, LogOut: ioutil.Discard, Loc: time.UTC},
},
{
"databricks://token:[email protected]/sql/1.0/endpoints/12346a5b5b0e123a?timeout=123&maxRows=1000",
Options{Host: "example.cloud.databricks.com", Port: "443", Token: "supersecret", HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", Timeout: 123, MaxRows: 1000, LogOut: ioutil.Discard},
Options{Host: "example.cloud.databricks.com", Port: "443", Token: "supersecret", HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", Timeout: 123, MaxRows: 1000, LogOut: ioutil.Discard, Loc: time.UTC},
},
{
"databricks://token:[email protected]/sql/1.0/endpoints/12346a5b5b0e123a?maxRows=1000",
Options{Host: "example.cloud.databricks.com", Port: "443", Token: "supersecret", HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", MaxRows: 1000, Timeout: 0, LogOut: ioutil.Discard},
Options{Host: "example.cloud.databricks.com", Port: "443", Token: "supersecret", HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", MaxRows: 1000, Timeout: 0, LogOut: ioutil.Discard, Loc: time.UTC},
},
{
"databricks://:[email protected]/sql/1.0/endpoints/12346a5b5b0e123a?maxRows=1000",
Options{Host: "example.cloud.databricks.com", Port: "443", Token: "supersecret", HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", MaxRows: 1000, Timeout: 0, LogOut: ioutil.Discard},
Options{Host: "example.cloud.databricks.com", Port: "443", Token: "supersecret", HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", MaxRows: 1000, Timeout: 0, LogOut: ioutil.Discard, Loc: time.UTC},
},
{
"databricks://:[email protected]/sql/1.0/endpoints/12346a5b5b0e123a?maxRows=1000&tz=America%2FLos_Angeles",
Options{Host: "example.cloud.databricks.com", Port: "443", Token: "supersecret", HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", MaxRows: 1000, Timeout: 0, LogOut: ioutil.Discard, Loc: americaLosAngelesTz},
},
}

Expand All @@ -39,7 +49,7 @@ func TestParseURI(t *testing.T) {
t.Error(err)
return
}
if *opts != tt.out {
if !opts.Equal(&tt.out) {
t.Errorf("got: %v, want: %v", opts, tt.out)
}
})
Expand Down
2 changes: 2 additions & 0 deletions hive/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package hive
import (
"context"
"log"
"time"

"github.com/apache/thrift/lib/go/thrift"
"github.com/databricks/databricks-sql-go/cli_service"
Expand All @@ -18,6 +19,7 @@ type Client struct {
// Options for Hive Client
type Options struct {
MaxRows int64
Loc *time.Location
}

// NewClient creates Hive Client
Expand Down
1 change: 1 addition & 0 deletions hive/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ func (op *Operation) FetchResults(ctx context.Context, schema *TableSchema) (*Re
result: resp.Results,
more: resp.GetHasMoreRows(),
schema: schema,
loc: op.hive.opts.Loc,
fetchfn: func() (*cli_service.TFetchResultsResp, error) { return fetch(ctx, op, schema) },
}

Expand Down
11 changes: 8 additions & 3 deletions hive/result_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type ResultSet struct {
length int
fetchfn func() (*cli_service.TFetchResultsResp, error)
schema *TableSchema
loc *time.Location

operation *Operation
result *cli_service.TRowSet
Expand Down Expand Up @@ -42,7 +43,7 @@ func (rs *ResultSet) Next(dest []driver.Value) error {
}

for i := range dest {
val, err := value(rs.result.Columns[i], rs.schema.Columns[i], rs.idx)
val, err := value(rs.result.Columns[i], rs.schema.Columns[i], rs.idx, rs.loc)
if err != nil {
return err
}
Expand All @@ -61,7 +62,11 @@ func isNull(nulls []byte, position int) bool {
return false
}

func value(col *cli_service.TColumn, cd *ColDesc, i int) (interface{}, error) {
func value(col *cli_service.TColumn, cd *ColDesc, i int, loc *time.Location) (interface{}, error) {
if loc == nil {
loc = time.UTC
}

switch cd.DatabaseTypeName {
case "STRING", "CHAR", "VARCHAR":
if isNull(col.StringVal.Nulls, i) {
Expand Down Expand Up @@ -102,7 +107,7 @@ func value(col *cli_service.TColumn, cd *ColDesc, i int) (interface{}, error) {
if isNull(col.StringVal.Nulls, i) {
return nil, nil
}
t, err := time.Parse(TimestampFormat, col.StringVal.Values[i])
t, err := time.ParseInLocation(TimestampFormat, col.StringVal.Values[i], loc)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 11dd2ed

Please sign in to comment.