diff --git a/chunk_downloader.go b/chunk_downloader.go index 6806e5e32..559395e7b 100644 --- a/chunk_downloader.go +++ b/chunk_downloader.go @@ -106,7 +106,7 @@ func (scd *snowflakeChunkDownloader) start() error { // if the rowsetbase64 retrieved from the server is empty, move on to downloading chunks var err error var loc *time.Location - if scd.sc != nil && scd.sc.cfg != nil { + if scd.sc != nil && scd.sc.getConnConfig() != nil { loc = getCurrentLocation(scd.sc.cfg.Params) } firstArrowChunk := buildFirstArrowChunk(scd.RowSet.RowSetBase64, loc, scd.pool) @@ -271,7 +271,7 @@ func (scd *snowflakeChunkDownloader) startArrowBatches() error { var err error chunkMetaLen := len(scd.ChunkMetas) var loc *time.Location - if scd.sc != nil && scd.sc.cfg != nil { + if scd.sc != nil && scd.sc.getConnConfig() != nil { loc = getCurrentLocation(scd.sc.cfg.Params) } firstArrowChunk := buildFirstArrowChunk(scd.RowSet.RowSetBase64, loc, scd.pool) @@ -432,7 +432,7 @@ func decodeChunk(scd *snowflakeChunkDownloader, idx int, bufStream *bufio.Reader return err } var loc *time.Location - if scd.sc != nil && scd.sc.cfg != nil { + if scd.sc != nil && scd.sc.getConnConfig() != nil { loc = getCurrentLocation(scd.sc.cfg.Params) } arc := arrowResultChunk{ diff --git a/connection.go b/connection.go index c9d760327..9cd4e132b 100644 --- a/connection.go +++ b/connection.go @@ -249,7 +249,9 @@ func (sc *snowflakeConn) cleanup() { sc.rest.Client.CloseIdleConnections() } sc.rest = nil + paramsMutex.Lock() sc.cfg = nil + paramsMutex.Unlock() } func (sc *snowflakeConn) Close() (err error) { @@ -258,7 +260,7 @@ func (sc *snowflakeConn) Close() (err error) { sc.stopHeartBeat() defer sc.cleanup() - if sc.cfg != nil && !sc.cfg.KeepSessionAlive { + if sc.getConnConfig() != nil && !sc.cfg.KeepSessionAlive { if err = sc.rest.FuncCloseSession(sc.ctx, sc.rest, sc.rest.RequestTimeout); err != nil { logger.Error(err) } diff --git a/connection_test.go b/connection_test.go index 0a23684d2..cfcffbeda 100644 --- a/connection_test.go +++ b/connection_test.go @@ -476,7 +476,6 @@ func TestExecWithServerSideError(t *testing.T) { } func TestConcurrentReadOnParams(t *testing.T) { - t.Skip("Fails randomly") config, err := ParseDSN(dsn) if err != nil { t.Fatal("Failed to parse dsn") diff --git a/connection_util.go b/connection_util.go index 4d37dea28..552b3bb84 100644 --- a/connection_util.go +++ b/connection_util.go @@ -13,6 +13,12 @@ import ( "time" ) +func (sc *snowflakeConn) getConnConfig() *Config { + paramsMutex.Lock() + defer paramsMutex.Unlock() + return sc.cfg +} + func (sc *snowflakeConn) isClientSessionKeepAliveEnabled() bool { paramsMutex.Lock() v, ok := sc.cfg.Params[sessionClientSessionKeepAlive] @@ -24,7 +30,7 @@ func (sc *snowflakeConn) isClientSessionKeepAliveEnabled() bool { } func (sc *snowflakeConn) startHeartBeat() { - if sc.cfg != nil && !sc.isClientSessionKeepAliveEnabled() { + if sc.getConnConfig() != nil && !sc.isClientSessionKeepAliveEnabled() { return } if sc.rest != nil { @@ -36,7 +42,7 @@ func (sc *snowflakeConn) startHeartBeat() { } func (sc *snowflakeConn) stopHeartBeat() { - if sc.cfg != nil && !sc.isClientSessionKeepAliveEnabled() { + if sc.getConnConfig() != nil && !sc.isClientSessionKeepAliveEnabled() { return } if sc.rest != nil && sc.rest.HeartBeat != nil { diff --git a/htap.go b/htap.go index 93806801f..09a12f282 100644 --- a/htap.go +++ b/htap.go @@ -79,7 +79,11 @@ func (qcc *queryContextCache) prune(size int) { } func (qcc *queryContextCache) getQueryContextCacheSize(sc *snowflakeConn) int { - if sizeStr, ok := sc.cfg.Params[queryContextCacheSizeParamName]; ok { + paramsMutex.Lock() + sizeStr, ok := sc.cfg.Params[queryContextCacheSizeParamName] + paramsMutex.Unlock() + + if ok { size, err := strconv.Atoi(*sizeStr) if err != nil { logger.Warnf("cannot parse %v as int as query context cache size: %v", sizeStr, err) diff --git a/rows.go b/rows.go index 3d3fcbb0f..6bd82acc7 100644 --- a/rows.go +++ b/rows.go @@ -47,7 +47,7 @@ type snowflakeRows struct { } func (rows *snowflakeRows) getLocation() *time.Location { - if rows.location == nil && rows.sc != nil && rows.sc.cfg != nil { + if rows.location == nil && rows.sc != nil && rows.sc.getConnConfig() != nil { rows.location = getCurrentLocation(rows.sc.cfg.Params) } return rows.location