From 1360d7c68bf1018e8bea7593ba204a3a16d419fb Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Wed, 4 Oct 2023 11:14:57 -0700 Subject: [PATCH 01/19] fix partial coverage --- arrow_chunk.go | 3 ++- async.go | 3 ++- authexternalbrowser.go | 3 ++- authokta.go | 6 ++++-- azure_storage_client.go | 3 ++- bind_uploader.go | 3 ++- chunk.go | 3 ++- connection.go | 23 +++++++++++++------- connection_util.go | 11 ++++++---- converter.go | 3 ++- dsn.go | 6 ++++-- encrypt_util.go | 6 ++++-- errors.go | 3 ++- file_transfer_agent.go | 45 ++++++++++++++++++++++++++------------- gcs_storage_client.go | 9 +++++--- local_storage_client.go | 9 +++++--- monitoring.go | 9 +++++--- multistatement.go | 3 ++- ocsp.go | 12 +++++++---- restful.go | 21 ++++++++++++------ result.go | 6 ++++-- rows.go | 35 +++++++++++++++++++----------- s3_storage_client.go | 5 +++-- secure_storage_manager.go | 18 ++++++++++------ storage_client.go | 12 +++++++---- telemetry.go | 6 ++++-- 26 files changed, 176 insertions(+), 90 deletions(-) diff --git a/arrow_chunk.go b/arrow_chunk.go index 344774af8..f98c41e07 100644 --- a/arrow_chunk.go +++ b/arrow_chunk.go @@ -36,7 +36,8 @@ func (arc *arrowResultChunk) decodeArrowChunk(rowType []execResponseRowType, hig for colIdx, col := range columns { values := make([]snowflakeValue, numRows) - if err := arrowToValue(values, rowType[colIdx], col, arc.loc, highPrec); err != nil { + err := arrowToValue(values, rowType[colIdx], col, arc.loc, highPrec) + if err != nil { return nil, err } diff --git a/async.go b/async.go index 47da1b7b4..d2db50ce8 100644 --- a/async.go +++ b/async.go @@ -132,7 +132,8 @@ func (sr *snowflakeRestful) getAsync( rows.sc = sc rows.queryID = respd.Data.QueryID if isMultiStmt(&respd.Data) { - if err = sc.handleMultiQuery(ctx, respd.Data, rows); err != nil { + err = sc.handleMultiQuery(ctx, respd.Data, rows) + if err != nil { rows.errChannel <- err return err } diff --git a/authexternalbrowser.go b/authexternalbrowser.go index a8d966cef..56df0feeb 100644 --- a/authexternalbrowser.go +++ b/authexternalbrowser.go @@ -226,7 +226,8 @@ func doAuthenticateByExternalBrowser( return authenticateByExternalBrowserResult{nil, nil, err} } - if err = openBrowser(idpURL); err != nil { + err = openBrowser(idpURL) + if err != nil { return authenticateByExternalBrowserResult{nil, nil, err} } diff --git a/authokta.go b/authokta.go index 994b51c13..35bb2a60a 100644 --- a/authokta.go +++ b/authokta.go @@ -108,10 +108,12 @@ func authenticateBySAML( logger.WithContext(ctx).Info("step 2: validate Token and SSO URL has the same prefix as oktaURL") var tokenURL *url.URL var ssoURL *url.URL - if tokenURL, err = url.Parse(respd.Data.TokenURL); err != nil { + tokenURL, err = url.Parse(respd.Data.TokenURL) + if err != nil { return nil, fmt.Errorf("failed to parse token URL. %v", respd.Data.TokenURL) } - if ssoURL, err = url.Parse(respd.Data.TokenURL); err != nil { + ssoURL, err = url.Parse(respd.Data.TokenURL) + if err != nil { return nil, fmt.Errorf("failed to parse ssoURL URL. %v", respd.Data.SSOURL) } if !isPrefixEqual(oktaURL, ssoURL) || !isPrefixEqual(oktaURL, tokenURL) { diff --git a/azure_storage_client.go b/azure_storage_client.go index 0db3a4cb3..bc2091415 100644 --- a/azure_storage_client.go +++ b/azure_storage_client.go @@ -106,7 +106,8 @@ func (util *snowflakeAzureClient) getFileHeader(meta *fileMetadata, filename str _, ok = metadata["Encryptiondata"] if ok { - if err = json.Unmarshal([]byte(*metadata["Encryptiondata"]), &encData); err != nil { + err = json.Unmarshal([]byte(*metadata["Encryptiondata"]), &encData) + if err != nil { return nil, err } } diff --git a/bind_uploader.go b/bind_uploader.go index 414bbb83f..eb183737f 100644 --- a/bind_uploader.go +++ b/bind_uploader.go @@ -66,7 +66,8 @@ func (bu *bindUploader) uploadStreamInternal( dstFileName int, compressData bool) ( *execResponse, error) { - if err := bu.createStageIfNeeded(); err != nil { + err := bu.createStageIfNeeded() + if err != nil { return nil, err } stageName := bu.stagePath diff --git a/chunk.go b/chunk.go index 4708f6282..531f15254 100644 --- a/chunk.go +++ b/chunk.go @@ -146,7 +146,8 @@ func (lcd *largeChunkDecoder) decodeString() (string, error) { if c == '"' { break } else if c == '\\' { - if err := lcd.decodeEscaped(); err != nil { + err := lcd.decodeEscaped() + if err != nil { return "", err } } else if c < ' ' { diff --git a/connection.go b/connection.go index 769253d46..6315ecb91 100644 --- a/connection.go +++ b/connection.go @@ -108,7 +108,8 @@ func (sc *snowflakeConn) exec( // handle bindings, if required requestID := getOrGenerateRequestIDFromContext(ctx) if len(bindings) > 0 { - if err = sc.processBindings(ctx, bindings, describeOnly, requestID, &req); err != nil { + err = sc.processBindings(ctx, bindings, describeOnly, requestID, &req) + if err != nil { return nil, err } } @@ -236,8 +237,9 @@ func (sc *snowflakeConn) BeginTx( return nil, driver.ErrBadConn } isDesc := isDescribeOnly(ctx) - if _, err := sc.exec(ctx, "BEGIN", false, /* noResult */ - false /* isInternal */, isDesc, nil); err != nil { + _, err := sc.exec(ctx, "BEGIN", false, /* noResult */ + false /* isInternal */, isDesc, nil) + if err != nil { return nil, err } return &snowflakeTx{sc, ctx}, nil @@ -259,7 +261,8 @@ func (sc *snowflakeConn) Close() (err error) { defer sc.cleanup() if sc.cfg != nil && !sc.cfg.KeepSessionAlive { - if err = sc.rest.FuncCloseSession(sc.ctx, sc.rest, sc.rest.RequestTimeout); err != nil { + err = sc.rest.FuncCloseSession(sc.ctx, sc.rest, sc.rest.RequestTimeout) + if err != nil { logger.Error(err) } } @@ -402,7 +405,8 @@ func (sc *snowflakeConn) queryContextInternal( if isMultiStmt(&data.Data) { // handleMultiQuery is responsible to fill rows with childResults - if err = sc.handleMultiQuery(ctx, data.Data, rows); err != nil { + err = sc.handleMultiQuery(ctx, data.Data, rows) + if err != nil { return nil, err } } else { @@ -537,7 +541,8 @@ type wrapReader struct { func (w *wrapReader) Close() error { if cl, ok := w.Reader.(io.ReadCloser); ok { - if err := cl.Close(); err != nil { + err := cl.Close() + if err != nil { return err } } @@ -613,7 +618,8 @@ func (asb *ArrowStreamBatch) downloadChunkStreamHelper(ctx context.Context) erro // to ensure no leaked memory. func (asb *ArrowStreamBatch) GetStream(ctx context.Context) (io.ReadCloser, error) { if asb.rr == nil { - if err := asb.downloadChunkStreamHelper(ctx); err != nil { + err := asb.downloadChunkStreamHelper(ctx) + if err != nil { return nil, err } } @@ -750,7 +756,8 @@ func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, err st = sc.cfg.Transporter } if strings.HasSuffix(sc.cfg.Host, privateLinkSuffix) { - if err := sc.setupOCSPPrivatelink(sc.cfg.Application, sc.cfg.Host); err != nil { + err := sc.setupOCSPPrivatelink(sc.cfg.Application, sc.cfg.Host) + if err != nil { return nil, err } } else { diff --git a/connection_util.go b/connection_util.go index 4d37dea28..c15c32b42 100644 --- a/connection_util.go +++ b/connection_util.go @@ -103,10 +103,11 @@ func (sc *snowflakeConn) processFileTransfer( if sfa.options.MultiPartThreshold == 0 { sfa.options.MultiPartThreshold = dataSizeThreshold } - if err := sfa.execute(); err != nil { + err := sfa.execute() + if err != nil { return nil, err } - data, err := sfa.result() + data, err = sfa.result() if err != nil { return nil, err } @@ -282,12 +283,14 @@ func populateChunkDownloader( func (sc *snowflakeConn) setupOCSPPrivatelink(app string, host string) error { ocspCacheServer := fmt.Sprintf("http://ocsp.%v/ocsp_response_cache.json", host) logger.Debugf("OCSP Cache Server for Privatelink: %v\n", ocspCacheServer) - if err := os.Setenv(cacheServerURLEnv, ocspCacheServer); err != nil { + err := os.Setenv(cacheServerURLEnv, ocspCacheServer) + if err != nil { return err } ocspRetryHostTemplate := fmt.Sprintf("http://ocsp.%v/retry/", host) + "%v/%v" logger.Debugf("OCSP Retry URL for Privatelink: %v\n", ocspRetryHostTemplate) - if err := os.Setenv(ocspRetryURLEnv, ocspRetryHostTemplate); err != nil { + err = os.Setenv(ocspRetryURLEnv, ocspRetryHostTemplate) + if err != nil { return err } return nil diff --git a/converter.go b/converter.go index 44c0afbfa..6fc4fb702 100644 --- a/converter.go +++ b/converter.go @@ -77,7 +77,8 @@ func goTypeToSnowflake(v driver.Value, tsmode snowflakeType) snowflakeType { if len(t) != 1 { return unSupportedType } - if _, err := dataTypeMode(t); err != nil { + _, err := dataTypeMode(t) + if err != nil { return unSupportedType } return changeType diff --git a/dsn.go b/dsn.go index d5fa2ab73..d96510f6d 100644 --- a/dsn.go +++ b/dsn.go @@ -105,7 +105,8 @@ type Config struct { // A driver client may call it manually, but it is also called during opening first connection. func (c *Config) Validate() error { if c.TmpDirPath != "" { - if _, err := os.Stat(c.TmpDirPath); err != nil { + _, err := os.Stat(c.TmpDirPath) + if err != nil { return err } } @@ -551,7 +552,8 @@ func parseUserPassword(posAt int, dsn string) (user, password string) { func parseParams(cfg *Config, posQuestion int, dsn string) (err error) { for j := posQuestion + 1; j < len(dsn); j++ { if dsn[j] == '?' { - if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { + err = parseDSNParams(cfg, dsn[j+1:]) + if err != nil { return } break diff --git a/encrypt_util.go b/encrypt_util.go index 435ecbafc..17111c3fa 100644 --- a/encrypt_util.go +++ b/encrypt_util.go @@ -104,7 +104,8 @@ func encryptStream( // encrypt key with ECB fileKey = padBytesLength(fileKey, block.BlockSize()) encryptedFileKey := make([]byte, len(fileKey)) - if err = encryptECB(encryptedFileKey, fileKey, decodedKey); err != nil { + err = encryptECB(encryptedFileKey, fileKey, decodedKey) + if err != nil { return nil, err } @@ -212,7 +213,8 @@ func decryptFile( // decrypt file key decryptedKey := make([]byte, len(keyBytes)) - if err = decryptECB(decryptedKey, keyBytes, decodedKey); err != nil { + err = decryptECB(decryptedKey, keyBytes, decodedKey) + if err != nil { return "", err } decryptedKey, err = paddingTrim(decryptedKey) diff --git a/errors.go b/errors.go index a4104f66c..656ae0f1b 100644 --- a/errors.go +++ b/errors.go @@ -74,7 +74,8 @@ func (se *SnowflakeError) sendExceptionTelemetry(sc *snowflakeConn, data *teleme func (se *SnowflakeError) exceptionTelemetry(sc *snowflakeConn) *SnowflakeError { data := se.generateTelemetryExceptionData() - if err := se.sendExceptionTelemetry(sc, data); err != nil { + err := se.sendExceptionTelemetry(sc, data) + if err != nil { logger.Debugf("failed to log to telemetry: %v", data) } return se diff --git a/file_transfer_agent.go b/file_transfer_agent.go index 9e3a64697..79a878c6a 100644 --- a/file_transfer_agent.go +++ b/file_transfer_agent.go @@ -129,26 +129,31 @@ type snowflakeFileTransferAgent struct { func (sfa *snowflakeFileTransferAgent) execute() error { var err error - if err = sfa.parseCommand(); err != nil { + err = sfa.parseCommand() + if err != nil { return err } - if err = sfa.initFileMetadata(); err != nil { + err = sfa.initFileMetadata() + if err != nil { return err } if sfa.commandType == uploadCommand { - if err = sfa.processFileCompressionType(); err != nil { + err = sfa.processFileCompressionType() + if err != nil { return err } } - if err = sfa.transferAccelerateConfig(); err != nil { + err = sfa.transferAccelerateConfig() + if err != nil { return err } if sfa.commandType == downloadCommand { if _, err = os.Stat(sfa.localLocation); os.IsNotExist(err) { - if err = os.MkdirAll(sfa.localLocation, os.ModePerm); err != nil { + err = os.MkdirAll(sfa.localLocation, os.ModePerm) + if err != nil { return err } } @@ -156,13 +161,15 @@ func (sfa *snowflakeFileTransferAgent) execute() error { if sfa.stageLocationType == local { if _, err = os.Stat(sfa.stageInfo.Location); os.IsNotExist(err) { - if err = os.MkdirAll(sfa.stageInfo.Location, os.ModePerm); err != nil { + err = os.MkdirAll(sfa.stageInfo.Location, os.ModePerm) + if err != nil { return err } } } - if err = sfa.updateFileMetadataWithPresignedURL(); err != nil { + err = sfa.updateFileMetadataWithPresignedURL() + if err != nil { return err } @@ -190,11 +197,13 @@ func (sfa *snowflakeFileTransferAgent) execute() error { } if sfa.commandType == uploadCommand { - if err = sfa.upload(largeFileMetas, smallFileMetas); err != nil { + err = sfa.upload(largeFileMetas, smallFileMetas) + if err != nil { return err } } else { - if err = sfa.download(smallFileMetas); err != nil { + err = sfa.download(smallFileMetas) + if err != nil { return err } } @@ -253,7 +262,8 @@ func (sfa *snowflakeFileTransferAgent) parseCommand() error { if err != nil { return err } - if fi, err := os.Stat(sfa.localLocation); err != nil || !fi.IsDir() { + fi, err := os.Stat(sfa.localLocation) + if err != nil || !fi.IsDir() { return (&SnowflakeError{ Number: ErrLocalPathNotDirectory, SQLState: sfa.data.SQLState, @@ -683,13 +693,15 @@ func (sfa *snowflakeFileTransferAgent) upload( if len(smallFileMetadata) > 0 { logger.Infof("uploading %v small files", len(smallFileMetadata)) - if err = sfa.uploadFilesParallel(smallFileMetadata); err != nil { + err = sfa.uploadFilesParallel(smallFileMetadata) + if err != nil { return err } } if len(largeFileMetadata) > 0 { logger.Infof("uploading %v large files", len(largeFileMetadata)) - if err = sfa.uploadFilesSequential(largeFileMetadata); err != nil { + err = sfa.uploadFilesSequential(largeFileMetadata) + if err != nil { return err } } @@ -708,7 +720,8 @@ func (sfa *snowflakeFileTransferAgent) download( } logger.WithContext(sfa.sc.ctx).Infof("downloading %v files", len(fileMetadata)) - if err = sfa.downloadFilesParallel(fileMetadata); err != nil { + err = sfa.downloadFilesParallel(fileMetadata) + if err != nil { return err } return nil @@ -869,7 +882,8 @@ func (sfa *snowflakeFileTransferAgent) uploadOneFile(meta *fileMetadata) (*fileM } client := sfa.getStorageClient(sfa.stageLocationType) - if err = client.uploadOneFileWithRetry(meta); err != nil { + err = client.uploadOneFileWithRetry(meta) + if err != nil { return meta, err } return meta, nil @@ -958,7 +972,8 @@ func (sfa *snowflakeFileTransferAgent) downloadOneFile(meta *fileMetadata) (*fil meta.tmpDir = tmpDir defer os.RemoveAll(tmpDir) // cleanup client := sfa.getStorageClient(sfa.stageLocationType) - if err = client.downloadOneFile(meta); err != nil { + err = client.downloadOneFile(meta) + if err != nil { meta.dstFileSize = -1 if !meta.resStatus.isSet() { meta.resStatus = errStatus diff --git a/gcs_storage_client.go b/gcs_storage_client.go index 5d9dac9a8..28b71804f 100644 --- a/gcs_storage_client.go +++ b/gcs_storage_client.go @@ -249,7 +249,8 @@ func (util *snowflakeGcsClient) uploadFile( meta.gcsFileHeaderDigest = gcsHeaders[gcsFileHeaderDigest] meta.gcsFileHeaderContentLength = meta.uploadSize - if err = json.Unmarshal([]byte(gcsHeaders[gcsMetadataEncryptionDataProp]), &encryptMeta); err != nil { + err = json.Unmarshal([]byte(gcsHeaders[gcsMetadataEncryptionDataProp]), &encryptMeta) + if err != nil { return err } meta.gcsFileHeaderEncryptionMeta = encryptMeta @@ -319,14 +320,16 @@ func (util *snowflakeGcsClient) nativeDownloadFile( return err } defer f.Close() - if _, err = io.Copy(f, resp.Body); err != nil { + _, err = io.Copy(f, resp.Body) + if err != nil { return err } var encryptMeta encryptMetadata if resp.Header.Get(gcsMetadataEncryptionDataProp) != "" { var encryptData *encryptionData - if err = json.Unmarshal([]byte(resp.Header.Get(gcsMetadataEncryptionDataProp)), &encryptData); err != nil { + err = json.Unmarshal([]byte(resp.Header.Get(gcsMetadataEncryptionDataProp)), &encryptData) + if err != nil { return err } if encryptData != nil { diff --git a/local_storage_client.go b/local_storage_client.go index 2ae072b63..1b6702221 100644 --- a/local_storage_client.go +++ b/local_storage_client.go @@ -62,7 +62,8 @@ func (util *localUtil) uploadOneFileWithRetry(meta *fileMetadata) error { break } - if _, err = output.Write(data); err != nil { + _, err = output.Write(data) + if err != nil { return err } } @@ -91,7 +92,8 @@ func (util *localUtil) downloadOneFile(meta *fileMetadata) error { return err } if _, err = os.Stat(baseDir); os.IsNotExist(err) { - if err = os.MkdirAll(baseDir, os.ModePerm); err != nil { + err = os.MkdirAll(baseDir, os.ModePerm) + if err != nil { return err } } @@ -100,7 +102,8 @@ func (util *localUtil) downloadOneFile(meta *fileMetadata) error { if err != nil { return err } - if err = os.WriteFile(fullDstFileName, data, readWriteFileMode); err != nil { + err = os.WriteFile(fullDstFileName, data, readWriteFileMode) + if err != nil { return err } fi, err := os.Stat(fullDstFileName) diff --git a/monitoring.go b/monitoring.go index 07b17c0aa..06535718e 100644 --- a/monitoring.go +++ b/monitoring.go @@ -146,7 +146,8 @@ func (sc *snowflakeConn) checkQueryStatus( } defer res.Body.Close() var statusResp = statusResponse{} - if err = json.NewDecoder(res.Body).Decode(&statusResp); err != nil { + err = json.NewDecoder(res.Body).Decode(&statusResp) + if err != nil { logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) return nil, err } @@ -221,7 +222,8 @@ func (sc *snowflakeConn) getQueryResultResp( } defer res.Body.Close() var respd *execResponse - if err = json.NewDecoder(res.Body).Decode(&respd); err != nil { + err = json.NewDecoder(res.Body).Decode(&respd) + if err != nil { logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) return nil, err } @@ -262,7 +264,8 @@ func (sc *snowflakeConn) buildRowsForRunningQuery( rows := new(snowflakeRows) rows.sc = sc rows.queryID = qid - if err := sc.rowsForRunningQuery(ctx, qid, rows); err != nil { + err := sc.rowsForRunningQuery(ctx, qid, rows) + if err != nil { return nil, err } rows.ChunkDownloader.start() diff --git a/multistatement.go b/multistatement.go index ce9d9910b..af579681c 100644 --- a/multistatement.go +++ b/multistatement.go @@ -98,7 +98,8 @@ func (sc *snowflakeConn) handleMultiQuery( } childResults := getChildResults(data.ResultIDs, data.ResultTypes) for _, child := range childResults { - if err := sc.rowsForRunningQuery(ctx, child.id, rows); err != nil { + err := sc.rowsForRunningQuery(ctx, child.id, rows) + if err != nil { return err } } diff --git a/ocsp.go b/ocsp.go index c809a24c6..33f75f562 100644 --- a/ocsp.go +++ b/ocsp.go @@ -920,11 +920,13 @@ func writeOCSPCacheFile() { logger.Debugf("other process locks the cache file. %v. ignored.\n", cacheLockFileName) return } - if err = os.Remove(cacheLockFileName); err != nil { + err = os.Remove(cacheLockFileName) + if err != nil { logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", cacheLockFileName, err) return } - if err = os.Mkdir(cacheLockFileName, 0600); err != nil { + err = os.Mkdir(cacheLockFileName, 0600) + if err != nil { logger.Debugf("failed to create lock file. file: %v, err: %v. ignored.\n", cacheLockFileName, err) return } @@ -947,7 +949,8 @@ func writeOCSPCacheFile() { logger.Debugf("failed to convert OCSP Response cache to JSON. ignored.") return } - if err = os.WriteFile(cacheFileName, j, 0644); err != nil { + err = os.WriteFile(cacheFileName, j, 0644) + if err != nil { logger.Debugf("failed to write OCSP Response cache. err: %v. ignored.\n", err) } } @@ -1006,7 +1009,8 @@ func createOCSPCacheDir() { } if _, err := os.Stat(cacheDir); os.IsNotExist(err) { - if err = os.MkdirAll(cacheDir, os.ModePerm); err != nil { + err = os.MkdirAll(cacheDir, os.ModePerm) + if err != nil { logger.Debugf("failed to create cache directory. %v, err: %v. ignored\n", cacheDir, err) } } diff --git a/restful.go b/restful.go index 34297c1d9..a95eb6466 100644 --- a/restful.go +++ b/restful.go @@ -216,7 +216,8 @@ func postRestfulQuery( return data, err } - if err = sr.FuncCancelQuery(context.TODO(), sr, requestID, timeout); err != nil { + err = sr.FuncCancelQuery(context.TODO(), sr, requestID, timeout) + if err != nil { return nil, err } return nil, ctx.Err() @@ -251,12 +252,14 @@ func postRestfulQueryHelper( if resp.StatusCode == http.StatusOK { logger.WithContext(ctx).Infof("postQuery: resp: %v", resp) var respd execResponse - if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil { + err = json.NewDecoder(resp.Body).Decode(&respd) + if err != nil { logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) return nil, err } if respd.Code == sessionExpiredCode { - if err = sr.renewExpiredSessionToken(ctx, timeout, token); err != nil { + err = sr.renewExpiredSessionToken(ctx, timeout, token) + if err != nil { return nil, err } return sr.FuncPostQuery(ctx, sr, params, headers, body, timeout, requestID, cfg) @@ -297,7 +300,8 @@ func postRestfulQueryHelper( return nil, err } if respd.Code == sessionExpiredCode { - if err = sr.renewExpiredSessionToken(ctx, timeout, token); err != nil { + err = sr.renewExpiredSessionToken(ctx, timeout, token) + if err != nil { return nil, err } isSessionRenewed = true @@ -341,7 +345,8 @@ func closeSession(ctx context.Context, sr *snowflakeRestful, timeout time.Durati defer resp.Body.Close() if resp.StatusCode == http.StatusOK { var respd renewSessionResponse - if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil { + err = json.NewDecoder(resp.Body).Decode(&respd) + if err != nil { logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) return err } @@ -472,13 +477,15 @@ func cancelQuery(ctx context.Context, sr *snowflakeRestful, requestID UUID, time defer resp.Body.Close() if resp.StatusCode == http.StatusOK { var respd cancelQueryResponse - if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil { + err = json.NewDecoder(resp.Body).Decode(&respd) + if err != nil { logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) return err } ctxRetry := getCancelRetry(ctx) if !respd.Success && respd.Code == sessionExpiredCode { - if err = sr.FuncRenewSession(ctx, sr, timeout); err != nil { + err = sr.FuncRenewSession(ctx, sr, timeout) + if err != nil { return err } return sr.FuncCancelQuery(ctx, sr, requestID, timeout) diff --git a/result.go b/result.go index e08f41902..707b4c026 100644 --- a/result.go +++ b/result.go @@ -30,14 +30,16 @@ type snowflakeResult struct { } func (res *snowflakeResult) LastInsertId() (int64, error) { - if err := res.waitForAsyncExecStatus(); err != nil { + err := res.waitForAsyncExecStatus() + if err != nil { return -1, err } return res.insertID, nil } func (res *snowflakeResult) RowsAffected() (int64, error) { - if err := res.waitForAsyncExecStatus(); err != nil { + err := res.waitForAsyncExecStatus() + if err != nil { return -1, err } return res.affectedRows, nil diff --git a/rows.go b/rows.go index 83f49ba94..587e7bb96 100644 --- a/rows.go +++ b/rows.go @@ -71,8 +71,9 @@ type chunkError struct { Error error } -func (rows *snowflakeRows) Close() (err error) { - if err := rows.waitForAsyncQueryStatus(); err != nil { +func (rows *snowflakeRows) Close() error { + err := rows.waitForAsyncQueryStatus() + if err != nil { return err } logger.WithContext(rows.sc.ctx).Debugln("Rows.Close") @@ -81,7 +82,8 @@ func (rows *snowflakeRows) Close() (err error) { // ColumnTypeDatabaseTypeName returns the database column name. func (rows *snowflakeRows) ColumnTypeDatabaseTypeName(index int) string { - if err := rows.waitForAsyncQueryStatus(); err != nil { + err := rows.waitForAsyncQueryStatus() + if err != nil { return err.Error() } return strings.ToUpper(rows.ChunkDownloader.getRowType()[index].Type) @@ -89,7 +91,8 @@ func (rows *snowflakeRows) ColumnTypeDatabaseTypeName(index int) string { // ColumnTypeLength returns the length of the column func (rows *snowflakeRows) ColumnTypeLength(index int) (length int64, ok bool) { - if err := rows.waitForAsyncQueryStatus(); err != nil { + err := rows.waitForAsyncQueryStatus() + if err != nil { return 0, false } if index < 0 || index > len(rows.ChunkDownloader.getRowType()) { @@ -103,7 +106,8 @@ func (rows *snowflakeRows) ColumnTypeLength(index int) (length int64, ok bool) { } func (rows *snowflakeRows) ColumnTypeNullable(index int) (nullable, ok bool) { - if err := rows.waitForAsyncQueryStatus(); err != nil { + err := rows.waitForAsyncQueryStatus() + if err != nil { return false, false } if index < 0 || index > len(rows.ChunkDownloader.getRowType()) { @@ -113,7 +117,8 @@ func (rows *snowflakeRows) ColumnTypeNullable(index int) (nullable, ok bool) { } func (rows *snowflakeRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { - if err := rows.waitForAsyncQueryStatus(); err != nil { + err := rows.waitForAsyncQueryStatus() + if err != nil { return 0, 0, false } rowType := rows.ChunkDownloader.getRowType() @@ -132,7 +137,8 @@ func (rows *snowflakeRows) ColumnTypePrecisionScale(index int) (precision, scale } func (rows *snowflakeRows) Columns() []string { - if err := rows.waitForAsyncQueryStatus(); err != nil { + err := rows.waitForAsyncQueryStatus() + if err != nil { return make([]string, 0) } logger.Debug("Rows.Columns") @@ -144,7 +150,8 @@ func (rows *snowflakeRows) Columns() []string { } func (rows *snowflakeRows) ColumnTypeScanType(index int) reflect.Type { - if err := rows.waitForAsyncQueryStatus(); err != nil { + err := rows.waitForAsyncQueryStatus() + if err != nil { return nil } return snowflakeTypeToGo( @@ -164,7 +171,8 @@ func (rows *snowflakeRows) GetStatus() queryStatus { func (rows *snowflakeRows) GetArrowBatches() ([]*ArrowBatch, error) { // Wait for all arrow batches before fetching. // Otherwise, a panic error "invalid memory address or nil pointer dereference" will be thrown. - if err := rows.waitForAsyncQueryStatus(); err != nil { + err := rows.waitForAsyncQueryStatus() + if err != nil { return nil, err } @@ -172,7 +180,8 @@ func (rows *snowflakeRows) GetArrowBatches() ([]*ArrowBatch, error) { } func (rows *snowflakeRows) Next(dest []driver.Value) (err error) { - if err = rows.waitForAsyncQueryStatus(); err != nil { + err = rows.waitForAsyncQueryStatus() + if err != nil { return err } row, err := rows.ChunkDownloader.next() @@ -202,14 +211,16 @@ func (rows *snowflakeRows) Next(dest []driver.Value) (err error) { } func (rows *snowflakeRows) HasNextResultSet() bool { - if err := rows.waitForAsyncQueryStatus(); err != nil { + err := rows.waitForAsyncQueryStatus() + if err != nil { return false } return rows.ChunkDownloader.hasNextResultSet() } func (rows *snowflakeRows) NextResultSet() error { - if err := rows.waitForAsyncQueryStatus(); err != nil { + err := rows.waitForAsyncQueryStatus() + if err != nil { return err } if len(rows.ChunkDownloader.getChunkMetas()) == 0 { diff --git a/s3_storage_client.go b/s3_storage_client.go index ed3bca59a..c655f3fdb 100644 --- a/s3_storage_client.go +++ b/s3_storage_client.go @@ -230,10 +230,11 @@ func (util *snowflakeS3Client) nativeDownloadFile( if meta.mockDownloader != nil { downloader = meta.mockDownloader } - if _, err = downloader.Download(context.Background(), f, &s3.GetObjectInput{ + _, err = downloader.Download(context.Background(), f, &s3.GetObjectInput{ Bucket: s3Obj.Bucket, Key: s3Obj.Key, - }); err != nil { + }) + if err != nil { var ae smithy.APIError if errors.As(err, &ae) { if ae.ErrorCode() == expiredToken { diff --git a/secure_storage_manager.go b/secure_storage_manager.go index 9b83a2be7..1c3146c0d 100644 --- a/secure_storage_manager.go +++ b/secure_storage_manager.go @@ -52,7 +52,8 @@ func createCredentialCacheDir() { } if _, err := os.Stat(credCacheDir); os.IsNotExist(err) { - if err = os.MkdirAll(credCacheDir, os.ModePerm); err != nil { + err = os.MkdirAll(credCacheDir, os.ModePerm) + if err != nil { logger.Debugf("Failed to create cache directory. %v, err: %v. ignored\n", credCacheDir, err) } } @@ -75,7 +76,8 @@ func setCredential(sc *snowflakeConn, credType, token string) { Key: target, Data: []byte(token), } - if err := ring.Set(item); err != nil { + err := ring.Set(item) + if err != nil { logger.Debugf("Failed to write to Windows credential manager. Err: %v", err) } } else if runtime.GOOS == "darwin" { @@ -88,7 +90,8 @@ func setCredential(sc *snowflakeConn, credType, token string) { Key: account, Data: []byte(token), } - if err := ring.Set(item); err != nil { + err := ring.Set(item) + if err != nil { logger.Debugf("Failed to write to keychain. Err: %v", err) } } else if runtime.GOOS == "linux" { @@ -256,18 +259,21 @@ func writeTemporaryCacheFile(input []byte) { logger.Debugf("other process locks the cache file. %v. ignored.\n", credCache) return } - if err = os.Remove(credCacheLockFileName); err != nil { + err = os.Remove(credCacheLockFileName) + if err != nil { logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) return } - if err = os.Mkdir(credCacheLockFileName, 0600); err != nil { + err = os.Mkdir(credCacheLockFileName, 0600) + if err != nil { logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) return } } defer os.RemoveAll(credCacheLockFileName) - if err = os.WriteFile(credCache, input, 0644); err != nil { + err = os.WriteFile(credCache, input, 0644) + if err != nil { logger.Debugf("Failed to write the cache file. File: %v err: %v.", credCache, err) } } diff --git a/storage_client.go b/storage_client.go index a7385c278..e489e23f7 100644 --- a/storage_client.go +++ b/storage_client.go @@ -133,7 +133,8 @@ func (rsu *remoteStorageUtil) uploadOneFileWithRetry(meta *fileMetadata) error { retryOuter := true for i := 0; i < 10; i++ { // retry - if err := rsu.uploadOneFile(meta); err != nil { + err := rsu.uploadOneFile(meta) + if err != nil { return err } retryInner := true @@ -184,7 +185,8 @@ func (rsu *remoteStorageUtil) downloadOneFile(meta *fileMetadata) error { return err } if _, err = os.Stat(baseDir); os.IsNotExist(err) { - if err = os.MkdirAll(baseDir, os.ModePerm); err != nil { + err = os.MkdirAll(baseDir, os.ModePerm) + if err != nil { return err } } @@ -202,7 +204,8 @@ func (rsu *remoteStorageUtil) downloadOneFile(meta *fileMetadata) error { var lastErr error maxRetry := defaultMaxRetry for retry := 0; retry < maxRetry; retry++ { - if err = utilClass.nativeDownloadFile(meta, fullDstFileName, maxConcurrency); err != nil { + err = utilClass.nativeDownloadFile(meta, fullDstFileName, maxConcurrency) + if err != nil { return err } if meta.resStatus == downloaded { @@ -218,7 +221,8 @@ func (rsu *remoteStorageUtil) downloadOneFile(meta *fileMetadata) error { if err != nil { return err } - if err = os.Rename(tmpDstFileName, fullDstFileName); err != nil { + err = os.Rename(tmpDstFileName, fullDstFileName) + if err != nil { return err } } diff --git a/telemetry.go b/telemetry.go index 1adb085a4..bd3706367 100644 --- a/telemetry.go +++ b/telemetry.go @@ -56,7 +56,8 @@ func (st *snowflakeTelemetry) addLog(data *telemetryData) error { st.logs = append(st.logs, data) st.mutex.Unlock() if len(st.logs) >= st.flushSize { - if err := st.sendBatch(); err != nil { + err := st.sendBatch() + if err != nil { return err } } @@ -111,7 +112,8 @@ func (st *snowflakeTelemetry) sendBatch() error { return err } var respd telemetryResponse - if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil { + err = json.NewDecoder(resp.Body).Decode(&respd) + if err != nil { logger.Info(err) st.enabled = false return err From 85f49900111b8c6b367595cfc8665322e3877b04 Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Wed, 4 Oct 2023 15:19:21 -0700 Subject: [PATCH 02/19] fix partial coverage --- driver.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/driver.go b/driver.go index af29d66e1..4fe2b93d2 100644 --- a/driver.go +++ b/driver.go @@ -28,7 +28,8 @@ func (d SnowflakeDriver) Open(dsn string) (driver.Conn, error) { // OpenWithConfig creates a new connection with the given Config. func (d SnowflakeDriver) OpenWithConfig(ctx context.Context, config Config) (driver.Conn, error) { - if err := config.Validate(); err != nil { + err := config.Validate() + if err != nil { return nil, err } if config.Tracing != "" { @@ -40,7 +41,8 @@ func (d SnowflakeDriver) OpenWithConfig(ctx context.Context, config Config) (dri return nil, err } - if err = authenticateWithConfig(sc); err != nil { + err = authenticateWithConfig(sc) + if err != nil { return nil, err } sc.connectionTelemetry(&config) From 69b6b09e692d359a588835be7593fafda4355257 Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Wed, 4 Oct 2023 15:19:48 -0700 Subject: [PATCH 03/19] increase coverage in connector.go --- connector_test.go | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/connector_test.go b/connector_test.go index 6c6b3caf4..fae283b74 100644 --- a/connector_test.go +++ b/connector_test.go @@ -47,3 +47,27 @@ func TestConnector(t *testing.T) { t.Fatalf("Missing driver") } } + +func TestConnectorWithMissingConfig(t *testing.T) { + conn := snowflakeConn{} + mock := noopTestDriver{conn: &conn} + config := Config{ + User: "u", + Password: "p", + Account: "", + } + expectedErr := errEmptyAccount() + + connector := NewConnector(&mock, config) + _, err := connector.Connect(context.Background()) + if err == nil { + t.Fatalf("should have failed") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("Snowflake error is expected. err: %v", err.Error()) + } + if driverErr.Number != expectedErr.Number || driverErr.Message != expectedErr.Message { + t.Fatalf("Snowflake error did not match. expected: %v, got: %v", expectedErr, driverErr) + } +} From a33bb507111f8517b7e47d36dc96d8a3411460d6 Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Wed, 4 Oct 2023 19:13:16 -0700 Subject: [PATCH 04/19] add test for heartbeat.stop() --- connection_util.go | 7 +++++++ heartbeat_test.go | 27 +++++++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/connection_util.go b/connection_util.go index c15c32b42..31e7d5a76 100644 --- a/connection_util.go +++ b/connection_util.go @@ -23,6 +23,13 @@ func (sc *snowflakeConn) isClientSessionKeepAliveEnabled() bool { return strings.Compare(*v, "true") == 0 } +func (sc *snowflakeConn) isHeartbeatNil() bool { + if sc.rest != nil { + return sc.rest.HeartBeat == nil + } + return true +} + func (sc *snowflakeConn) startHeartBeat() { if sc.cfg != nil && !sc.isClientSessionKeepAliveEnabled() { return diff --git a/heartbeat_test.go b/heartbeat_test.go index 17235f57b..c1023129d 100644 --- a/heartbeat_test.go +++ b/heartbeat_test.go @@ -3,6 +3,7 @@ package gosnowflake import ( + "context" "testing" ) @@ -43,3 +44,29 @@ func TestUnitPostHeartbeat(t *testing.T) { } }) } + +func TestHeartbeatStartAndStop(t *testing.T) { + createDSNWithClientSessionKeepAlive() + config, err := ParseDSN(dsn) + if err != nil { + t.Fatalf("failed to parse dsn. err: %v", err) + } + driver := SnowflakeDriver{} + db, err := driver.OpenWithConfig(context.Background(), *config) + if err != nil { + t.Fatalf("failed to open with config. config: %v, err: %v", config, err) + } + + conn, ok := db.(*snowflakeConn) + if ok && conn.isHeartbeatNil() { + t.Fatalf("heartbeat should not be nil") + } + + err = db.Close() + if err != nil { + t.Fatalf("should not cause error in Close. err: %v", err) + } + if ok && !conn.isHeartbeatNil() { + t.Fatalf("heartbeat should be nil") + } +} From c38e7738b21401e4a953b18739bbb76431da67fd Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Tue, 10 Oct 2023 10:27:46 -0700 Subject: [PATCH 05/19] add test to heartbeat_test.go --- heartbeat_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/heartbeat_test.go b/heartbeat_test.go index c1023129d..26aa862f4 100644 --- a/heartbeat_test.go +++ b/heartbeat_test.go @@ -24,6 +24,12 @@ func TestUnitPostHeartbeat(t *testing.T) { t.Fatalf("failed to heartbeat and renew session. err: %v", err) } + heartbeat.restful.FuncPost = postTestError + err = heartbeat.heartbeatMain() + if err == nil { + t.Fatal("should have failed") + } + heartbeat.restful.FuncPost = postTestSuccessButInvalidJSON err = heartbeat.heartbeatMain() if err == nil { From 082627085aa7680809aa0a53fdcd7647847bed24 Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Tue, 10 Oct 2023 10:29:29 -0700 Subject: [PATCH 06/19] add int16 test case to converter_test.go --- converter_test.go | 64 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/converter_test.go b/converter_test.go index 6a43404f2..c14f005fa 100644 --- a/converter_test.go +++ b/converter_test.go @@ -537,6 +537,70 @@ func TestArrowToValue(t *testing.T) { }, higherPrecision: true, }, + { + logical: "fixed", + physical: "int16", + values: []string{"1.2345", "2.3456"}, + rowType: execResponseRowType{Scale: 4}, + builder: array.NewInt16Builder(pool), + append: func(b array.Builder, vs interface{}) { + for _, s := range vs.([]string) { + num, ok := stringFloatToInt(s, 4) + if !ok { + t.Fatalf("failed to convert to int") + } + b.(*array.Int16Builder).Append(int16(num)) + } + }, + compare: func(src interface{}, dst []snowflakeValue) int { + srcvs := src.([]string) + for i := range srcvs { + num, ok := stringFloatToInt(srcvs[i], 4) + if !ok { + return i + } + srcDec := intToBigFloat(num, 4) + dstDec := dst[i].(*big.Float) + if srcDec.Cmp(dstDec) != 0 { + return i + } + } + return -1 + }, + higherPrecision: true, + }, + { + logical: "fixed", + physical: "int16", + values: []string{"1.2345", "2.3456"}, + rowType: execResponseRowType{Scale: 4}, + builder: array.NewInt16Builder(pool), + append: func(b array.Builder, vs interface{}) { + for _, s := range vs.([]string) { + num, ok := stringFloatToInt(s, 4) + if !ok { + t.Fatalf("failed to convert to int") + } + b.(*array.Int16Builder).Append(int16(num)) + } + }, + compare: func(src interface{}, dst []snowflakeValue) int { + srcvs := src.([]string) + for i := range srcvs { + num, ok := stringFloatToInt(srcvs[i], 4) + if !ok { + return i + } + srcDec := fmt.Sprintf("%.*f", 4, float64(num)/math.Pow10(int(4))) + dstDec := dst[i] + if srcDec != dstDec { + return i + } + } + return -1 + }, + higherPrecision: false, + }, { logical: "fixed", physical: "int32", From e5d60f7339108defdd1173e9d18aa87b213b7fa5 Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Tue, 10 Oct 2023 14:32:52 -0700 Subject: [PATCH 07/19] add error tests for transaction.go --- transaction_test.go | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/transaction_test.go b/transaction_test.go index 5027e47f0..0cb274be6 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -95,3 +95,36 @@ func withRetry(fn func(context.Context, *sql.Conn) error, numAttempts int, timeo return fmt.Errorf("context deadline exceeded, failed after [%d] attempts", numAttempts) } } + +func TestTransactionError(t *testing.T) { + var err error + var tx snowflakeTx + + sr := &snowflakeRestful{ + FuncPostQuery: postQueryFail, + } + tx.sc = &snowflakeConn{ + cfg: &Config{Params: map[string]*string{}}, + rest: sr, + } + tx.ctx = context.Background() + + // test for post query error when executing the txCommand + err = tx.execTxCommand(rollback) + if err == nil { + t.Fatal("should have failed to post query") + } + + // test for invalid txCommand + err = tx.execTxCommand(2) + if err == nil { + t.Fatal("should have failed to execute unsupported transaction command") + } + + // test for bad connection error when snowflakeConn is nil + tx.sc = nil + err = tx.execTxCommand(rollback) + if err == nil { + t.Fatal("should have failed to execute txCommand when connection is nil") + } +} From 0af41a8863e27e75f6fda117e32c40c8ba1e8414 Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Tue, 10 Oct 2023 23:16:22 -0700 Subject: [PATCH 08/19] Add tests for auth.go and authokta.go --- authokta_test.go | 39 ++++++++++++++++++++++++++++++++++++++- dsn_test.go | 4 ++++ errors.go | 8 ++++++++ 3 files changed, 50 insertions(+), 1 deletion(-) diff --git a/authokta_test.go b/authokta_test.go index dd1cf7af2..0d42cc452 100644 --- a/authokta_test.go +++ b/authokta_test.go @@ -7,6 +7,7 @@ import ( "errors" "net/http" "net/url" + "strconv" "testing" "time" ) @@ -135,6 +136,14 @@ func postAuthSAMLAuthFail(_ context.Context, _ *snowflakeRestful, _ map[string]s }, nil } +func postAuthSAMLAuthFailWithCode(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { + return &authResponse{ + Success: false, + Code: strconv.Itoa(ErrCodeIdpConnectionError), + Message: "SAML auth failed", + }, nil +} + func postAuthSAMLAuthSuccessButInvalidURL(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { return &authResponse{ Success: true, @@ -177,6 +186,10 @@ func getSSOSuccess(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[ return []byte(`
`), nil } +func getSSOSuccessButWrongPrefixURL(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ string, _ time.Duration) ([]byte, error) { + return []byte(``), nil +} + func TestUnitAuthenticateBySAML(t *testing.T) { authenticator := &url.URL{ Scheme: "https", @@ -203,7 +216,7 @@ func TestUnitAuthenticateBySAML(t *testing.T) { if err == nil { t.Fatal("should have failed.") } - sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidURL + sr.FuncPostAuthSAML = postAuthSAMLAuthFailWithCode _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) if err == nil { t.Fatal("should have failed.") @@ -215,6 +228,18 @@ func TestUnitAuthenticateBySAML(t *testing.T) { if driverErr.Number != ErrCodeIdpConnectionError { t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeIdpConnectionError, driverErr.Number) } + sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidURL + _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) + if err == nil { + t.Fatal("should have failed.") + } + driverErr, ok = err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrCodeIdpConnectionError { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeIdpConnectionError, driverErr.Number) + } sr.FuncPostAuthSAML = postAuthSAMLAuthSuccess sr.FuncPostAuthOKTA = postAuthOKTAError _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) @@ -237,4 +262,16 @@ func TestUnitAuthenticateBySAML(t *testing.T) { if err != nil { t.Fatalf("failed. err: %v", err) } + sr.FuncGetSSO = getSSOSuccessButWrongPrefixURL + _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) + if err == nil { + t.Fatal("should have failed.") + } + driverErr, ok = err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrCodeSSOURLNotMatch { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeSSOURLNotMatch, driverErr.Number) + } } diff --git a/dsn_test.go b/dsn_test.go index 0086e2cf5..b43675ed6 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -608,6 +608,10 @@ func TestParseDSN(t *testing.T) { ocspMode: ocspModeFailOpen, err: nil, }, + { + dsn: "u:p@a.snowflakecomputing.com:443?authenticator=http%3A%2F%2Fsc.okta.com&ocspFailOpen=true&validateDefaultParameters=true", + err: errFailedToParseAuthenticator(), + }, } for _, at := range []AuthType{AuthTypeExternalBrowser, AuthTypeOAuth} { diff --git a/errors.go b/errors.go index 656ae0f1b..b10f01a31 100644 --- a/errors.go +++ b/errors.go @@ -323,6 +323,14 @@ func errInvalidRegion() *SnowflakeError { } } +// Returned if a DSN includes an invalid authenticator. +func errFailedToParseAuthenticator() *SnowflakeError { + return &SnowflakeError{ + Number: ErrCodeFailedToParseAuthenticator, + Message: "failed to parse an authenticator", + } +} + // Returned if the server side returns an error without meaningful message. func errUnknownError() *SnowflakeError { return &SnowflakeError{ From 5c9497e56340743b14cfbe8020b1ec9b5575c23d Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Wed, 11 Oct 2023 00:35:11 -0700 Subject: [PATCH 09/19] add test TestUnitAuthenticateWithConfigOkta for auth.go authenticateWithConfig() --- auth_test.go | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/auth_test.go b/auth_test.go index 43123d760..efe6b1bb6 100644 --- a/auth_test.go +++ b/auth_test.go @@ -655,6 +655,39 @@ func TestUnitAuthenticateWithConfigMFA(t *testing.T) { } } +func TestUnitAuthenticateWithConfigOkta(t *testing.T) { + var err error + sr := &snowflakeRestful{ + Protocol: "https", + Host: "abc.com", + Port: 443, + FuncPostAuthSAML: postAuthSAMLAuthSuccess, + FuncPostAuthOKTA: postAuthOKTASuccess, + FuncGetSSO: getSSOSuccess, + FuncPostAuth: postAuthSuccess, + TokenAccessor: getSimpleTokenAccessor(), + } + sc := getDefaultSnowflakeConn() + sc.cfg.Authenticator = AuthTypeOkta + sc.cfg.OktaURL = &url.URL{ + Scheme: "https", + Host: "abc.com", + } + sc.rest = sr + sc.ctx = context.TODO() + + err = authenticateWithConfig(sc) + if err != nil { + t.Fatalf("failed to run. err: %v", err) + } + + sr.FuncPostAuthSAML = postAuthSAMLError + err = authenticateWithConfig(sc) + if err == nil { + t.Fatalf("should have failed.") + } +} + func TestUnitAuthenticateExternalBrowser(t *testing.T) { var err error sr := &snowflakeRestful{ From d0c4ae10e131ca9bac1959ae40a282802c36a5a3 Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Wed, 11 Oct 2023 14:09:39 -0700 Subject: [PATCH 10/19] Add test for url parse error --- authokta.go | 2 +- authokta_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/authokta.go b/authokta.go index 14786b158..0a4a0fcb7 100644 --- a/authokta.go +++ b/authokta.go @@ -112,7 +112,7 @@ func authenticateBySAML( if err != nil { return nil, fmt.Errorf("failed to parse token URL. %v", respd.Data.TokenURL) } - ssoURL, err = url.Parse(respd.Data.TokenURL) + ssoURL, err = url.Parse(respd.Data.SSOURL) if err != nil { return nil, fmt.Errorf("failed to parse ssoURL URL. %v", respd.Data.SSOURL) } diff --git a/authokta_test.go b/authokta_test.go index 0d42cc452..3c77c43e5 100644 --- a/authokta_test.go +++ b/authokta_test.go @@ -123,6 +123,10 @@ func TestUnitGetSSO(t *testing.T) { if err != nil { t.Fatalf("failed to get HTML content. err: %v", err) } + _, err = getSSO(context.TODO(), sr, &url.Values{}, make(map[string]string), "invalid!@url$%^", 0) + if err == nil { + t.Fatal("should have failed to parse URL.") + } } func postAuthSAMLError(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { @@ -155,6 +159,28 @@ func postAuthSAMLAuthSuccessButInvalidURL(_ context.Context, _ *snowflakeRestful }, nil } +func postAuthSAMLAuthSuccessButInvalidTokenURL(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { + return &authResponse{ + Success: true, + Message: "", + Data: authResponseMain{ + TokenURL: "invalid!@url$%^", + SSOURL: "https://abc.com/sso", + }, + }, nil +} + +func postAuthSAMLAuthSuccessButInvalidSSOURL(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { + return &authResponse{ + Success: true, + Message: "", + Data: authResponseMain{ + TokenURL: "https://abc.com/token", + SSOURL: "invalid!@url$%^", + }, + }, nil +} + func postAuthSAMLAuthSuccess(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { return &authResponse{ Success: true, @@ -240,6 +266,16 @@ func TestUnitAuthenticateBySAML(t *testing.T) { if driverErr.Number != ErrCodeIdpConnectionError { t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeIdpConnectionError, driverErr.Number) } + sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidTokenURL + _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) + if err == nil { + t.Fatal("should have failed.") + } + sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidSSOURL + _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) + if err == nil { + t.Fatal("should have failed.") + } sr.FuncPostAuthSAML = postAuthSAMLAuthSuccess sr.FuncPostAuthOKTA = postAuthOKTAError _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) From 2b0f5c49e3886e766de89e682e394222337bc8fb Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Wed, 11 Oct 2023 17:17:02 -0700 Subject: [PATCH 11/19] remove use role sysadmin for put get tests --- file_transfer_agent_test.go | 3 --- put_get_test.go | 6 ------ 2 files changed, 9 deletions(-) diff --git a/file_transfer_agent_test.go b/file_transfer_agent_test.go index 163fe3ce1..19b374b48 100644 --- a/file_transfer_agent_test.go +++ b/file_transfer_agent_test.go @@ -63,9 +63,6 @@ func TestUnitDownloadWithInvalidLocalPath(t *testing.T) { f.Close() runDBTest(t, func(dbt *DBTest) { - if _, err = dbt.exec("use role sysadmin"); err != nil { - t.Skip("snowflake admin account not accessible") - } dbt.mustExec("rm @~/test_get") sqlText := fmt.Sprintf("put file://%v @~/test_get", testData) sqlText = strings.ReplaceAll(sqlText, "\\", "\\\\") diff --git a/put_get_test.go b/put_get_test.go index 730efbcbb..d18ce9631 100644 --- a/put_get_test.go +++ b/put_get_test.go @@ -268,9 +268,6 @@ func TestPutWithAutoCompressFalse(t *testing.T) { defer f.Close() runDBTest(t, func(dbt *DBTest) { - if _, err = dbt.exec("use role sysadmin"); err != nil { - t.Skip("snowflake admin account not accessible") - } dbt.mustExec("rm @~/test_put_uncompress_file") sqlText := fmt.Sprintf("put file://%v @~/test_put_uncompress_file auto_compress=FALSE", testData) sqlText = strings.ReplaceAll(sqlText, "\\", "\\\\") @@ -308,9 +305,6 @@ func TestPutOverwrite(t *testing.T) { f.Close() runDBTest(t, func(dbt *DBTest) { - if _, err = dbt.exec("use role sysadmin"); err != nil { - t.Skip("snowflake admin account not accessible") - } dbt.mustExec("rm @~/test_put_overwrite") f, _ = os.Open(testData) From 456406e9a55c645ab096c03f3a3b5b86af8b2e92 Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Wed, 11 Oct 2023 21:48:00 -0700 Subject: [PATCH 12/19] add test for file_transfer_agent.go --- file_transfer_agent_test.go | 93 +++++++++++++++++++++++++++++++++++++ put_get_test.go | 3 ++ 2 files changed, 96 insertions(+) diff --git a/file_transfer_agent_test.go b/file_transfer_agent_test.go index 19b374b48..e608c9544 100644 --- a/file_transfer_agent_test.go +++ b/file_transfer_agent_test.go @@ -658,3 +658,96 @@ func TestReadonlyTmpDirPathShouldFail(t *testing.T) { t.Fatalf("should not upload file as temporary directory is not readable") } } + +func TestUploadDownloadOneFileRequireCompress(t *testing.T) { + testUploadDownloadOneFile(t, false) +} + +func TestUploadDownloadOneFileRequireCompressStream(t *testing.T) { + testUploadDownloadOneFile(t, true) +} + +func testUploadDownloadOneFile(t *testing.T, isStream bool) { + tmpDir, err := os.MkdirTemp("", "data") + if err != nil { + t.Fatalf("cannot create temp directory: %v", err) + } + defer os.RemoveAll(tmpDir) + uploadFile := filepath.Join(tmpDir, "data.txt") + f, err := os.Create(uploadFile) + if err != nil { + t.Error(err) + } + f.WriteString("test1,test2\ntest3,test4\n") + f.Close() + + uploadMeta := &fileMetadata{ + name: "data.txt.gz", + stageLocationType: "local", + noSleepingTime: true, + client: local, + sha256Digest: "123456789abcdef", + stageInfo: &execResponseStageInfo{ + Location: tmpDir, + LocationType: "local", + }, + dstFileName: "data.txt.gz", + srcFileName: uploadFile, + overwrite: true, + options: &SnowflakeFileTransferOptions{ + MultiPartThreshold: dataSizeThreshold, + }, + requireCompress: true, + } + + downloadFile := filepath.Join(tmpDir, "download.txt") + downloadMeta := &fileMetadata{ + name: "data.txt.gz", + stageLocationType: "local", + noSleepingTime: true, + client: local, + sha256Digest: "123456789abcdef", + stageInfo: &execResponseStageInfo{ + Location: tmpDir, + LocationType: "local", + }, + srcFileName: "data.txt.gz", + dstFileName: downloadFile, + overwrite: true, + options: &SnowflakeFileTransferOptions{ + MultiPartThreshold: dataSizeThreshold, + }, + } + + sfa := snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{ + TmpDirPath: tmpDir, + }, + }, + stageLocationType: local, + } + + if isStream { + fileStream, _ := os.Open(uploadFile) + ctx := WithFileStream(context.Background(), fileStream) + uploadMeta.srcStream = getFileStream(ctx) + } + + _, err = sfa.uploadOneFile(uploadMeta) + if err != nil { + t.Fatal(err) + } + if uploadMeta.resStatus != uploaded { + t.Fatalf("failed to upload file") + } + + _, err = sfa.downloadOneFile(downloadMeta) + if err != nil { + t.Fatal(err) + } + if downloadMeta.resStatus != downloaded { + t.Fatalf("failed to download file") + } + defer os.Remove("download.txt") +} diff --git a/put_get_test.go b/put_get_test.go index d18ce9631..6c4403630 100644 --- a/put_get_test.go +++ b/put_get_test.go @@ -305,6 +305,9 @@ func TestPutOverwrite(t *testing.T) { f.Close() runDBTest(t, func(dbt *DBTest) { + if _, err = dbt.exec("use role sysadmin"); err != nil { + t.Skip("snowflake admin account not accessible") + } dbt.mustExec("rm @~/test_put_overwrite") f, _ = os.Open(testData) From ad2e93c0fd16384d41ba5d2a63640d705e44ae05 Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Wed, 11 Oct 2023 22:25:49 -0700 Subject: [PATCH 13/19] revert use role sysadmin test changes --- file_transfer_agent_test.go | 3 +++ put_get_test.go | 3 +++ 2 files changed, 6 insertions(+) diff --git a/file_transfer_agent_test.go b/file_transfer_agent_test.go index e608c9544..7a591d414 100644 --- a/file_transfer_agent_test.go +++ b/file_transfer_agent_test.go @@ -63,6 +63,9 @@ func TestUnitDownloadWithInvalidLocalPath(t *testing.T) { f.Close() runDBTest(t, func(dbt *DBTest) { + if _, err = dbt.exec("use role sysadmin"); err != nil { + t.Skip("snowflake admin account not accessible") + } dbt.mustExec("rm @~/test_get") sqlText := fmt.Sprintf("put file://%v @~/test_get", testData) sqlText = strings.ReplaceAll(sqlText, "\\", "\\\\") diff --git a/put_get_test.go b/put_get_test.go index 6c4403630..730efbcbb 100644 --- a/put_get_test.go +++ b/put_get_test.go @@ -268,6 +268,9 @@ func TestPutWithAutoCompressFalse(t *testing.T) { defer f.Close() runDBTest(t, func(dbt *DBTest) { + if _, err = dbt.exec("use role sysadmin"); err != nil { + t.Skip("snowflake admin account not accessible") + } dbt.mustExec("rm @~/test_put_uncompress_file") sqlText := fmt.Sprintf("put file://%v @~/test_put_uncompress_file auto_compress=FALSE", testData) sqlText = strings.ReplaceAll(sqlText, "\\", "\\\\") From 6f4d85bd33b0c2edcd404e536471a39c05ffe1bb Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Fri, 20 Oct 2023 14:54:47 -0700 Subject: [PATCH 14/19] fix authokta_test.go to use context.Background --- authokta_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/authokta_test.go b/authokta_test.go index c5b6b0f04..747312e05 100644 --- a/authokta_test.go +++ b/authokta_test.go @@ -123,7 +123,7 @@ func TestUnitGetSSO(t *testing.T) { if err != nil { t.Fatalf("failed to get HTML content. err: %v", err) } - _, err = getSSO(context.TODO(), sr, &url.Values{}, make(map[string]string), "invalid!@url$%^", 0) + _, err = getSSO(context.Background(), sr, &url.Values{}, make(map[string]string), "invalid!@url$%^", 0) if err == nil { t.Fatal("should have failed to parse URL.") } @@ -255,7 +255,7 @@ func TestUnitAuthenticateBySAML(t *testing.T) { t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeIdpConnectionError, driverErr.Number) } sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidURL - _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) + _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) if err == nil { t.Fatal("should have failed.") } @@ -267,12 +267,12 @@ func TestUnitAuthenticateBySAML(t *testing.T) { t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeIdpConnectionError, driverErr.Number) } sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidTokenURL - _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) + _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) if err == nil { t.Fatal("should have failed.") } sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidSSOURL - _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) + _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) if err == nil { t.Fatal("should have failed.") } @@ -299,7 +299,7 @@ func TestUnitAuthenticateBySAML(t *testing.T) { t.Fatalf("failed. err: %v", err) } sr.FuncGetSSO = getSSOSuccessButWrongPrefixURL - _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) + _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) if err == nil { t.Fatal("should have failed.") } From c2f05b5e74bb43b5c08cec4faf5cd36885909651 Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Fri, 20 Oct 2023 14:56:08 -0700 Subject: [PATCH 15/19] revert the err checks from separate lines to single line --- arrow_chunk.go | 3 +-- async.go | 3 +-- authexternalbrowser.go | 3 +-- authokta.go | 6 ++---- azure_storage_client.go | 3 +-- bind_uploader.go | 3 +-- chunk.go | 3 +-- connection.go | 23 +++++++------------- connection_util.go | 11 ++++------ converter.go | 3 +-- driver.go | 6 ++---- dsn.go | 6 ++---- encrypt_util.go | 6 ++---- errors.go | 3 +-- file_transfer_agent.go | 45 +++++++++++++-------------------------- gcs_storage_client.go | 9 +++----- local_storage_client.go | 9 +++----- monitoring.go | 9 +++----- multistatement.go | 3 +-- ocsp.go | 12 ++++------- restful.go | 18 ++++++---------- result.go | 6 ++---- rows.go | 35 +++++++++++------------------- s3_storage_client.go | 5 ++--- secure_storage_manager.go | 18 ++++++---------- storage_client.go | 12 ++++------- telemetry.go | 6 ++---- 27 files changed, 91 insertions(+), 178 deletions(-) diff --git a/arrow_chunk.go b/arrow_chunk.go index 2570b2754..15851a80f 100644 --- a/arrow_chunk.go +++ b/arrow_chunk.go @@ -36,8 +36,7 @@ func (arc *arrowResultChunk) decodeArrowChunk(rowType []execResponseRowType, hig for colIdx, col := range columns { values := make([]snowflakeValue, numRows) - err := arrowToValue(values, rowType[colIdx], col, arc.loc, highPrec) - if err != nil { + if err := arrowToValue(values, rowType[colIdx], col, arc.loc, highPrec); err != nil { return nil, err } diff --git a/async.go b/async.go index d2db50ce8..47da1b7b4 100644 --- a/async.go +++ b/async.go @@ -132,8 +132,7 @@ func (sr *snowflakeRestful) getAsync( rows.sc = sc rows.queryID = respd.Data.QueryID if isMultiStmt(&respd.Data) { - err = sc.handleMultiQuery(ctx, respd.Data, rows) - if err != nil { + if err = sc.handleMultiQuery(ctx, respd.Data, rows); err != nil { rows.errChannel <- err return err } diff --git a/authexternalbrowser.go b/authexternalbrowser.go index 56df0feeb..a8d966cef 100644 --- a/authexternalbrowser.go +++ b/authexternalbrowser.go @@ -226,8 +226,7 @@ func doAuthenticateByExternalBrowser( return authenticateByExternalBrowserResult{nil, nil, err} } - err = openBrowser(idpURL) - if err != nil { + if err = openBrowser(idpURL); err != nil { return authenticateByExternalBrowserResult{nil, nil, err} } diff --git a/authokta.go b/authokta.go index 0a4a0fcb7..3e3f3c518 100644 --- a/authokta.go +++ b/authokta.go @@ -108,12 +108,10 @@ func authenticateBySAML( logger.WithContext(ctx).Info("step 2: validate Token and SSO URL has the same prefix as oktaURL") var tokenURL *url.URL var ssoURL *url.URL - tokenURL, err = url.Parse(respd.Data.TokenURL) - if err != nil { + if tokenURL, err = url.Parse(respd.Data.TokenURL); err != nil { return nil, fmt.Errorf("failed to parse token URL. %v", respd.Data.TokenURL) } - ssoURL, err = url.Parse(respd.Data.SSOURL) - if err != nil { + if ssoURL, err = url.Parse(respd.Data.TokenURL); err != nil { return nil, fmt.Errorf("failed to parse ssoURL URL. %v", respd.Data.SSOURL) } if !isPrefixEqual(oktaURL, ssoURL) || !isPrefixEqual(oktaURL, tokenURL) { diff --git a/azure_storage_client.go b/azure_storage_client.go index bc2091415..0db3a4cb3 100644 --- a/azure_storage_client.go +++ b/azure_storage_client.go @@ -106,8 +106,7 @@ func (util *snowflakeAzureClient) getFileHeader(meta *fileMetadata, filename str _, ok = metadata["Encryptiondata"] if ok { - err = json.Unmarshal([]byte(*metadata["Encryptiondata"]), &encData) - if err != nil { + if err = json.Unmarshal([]byte(*metadata["Encryptiondata"]), &encData); err != nil { return nil, err } } diff --git a/bind_uploader.go b/bind_uploader.go index eb183737f..414bbb83f 100644 --- a/bind_uploader.go +++ b/bind_uploader.go @@ -66,8 +66,7 @@ func (bu *bindUploader) uploadStreamInternal( dstFileName int, compressData bool) ( *execResponse, error) { - err := bu.createStageIfNeeded() - if err != nil { + if err := bu.createStageIfNeeded(); err != nil { return nil, err } stageName := bu.stagePath diff --git a/chunk.go b/chunk.go index 531f15254..4708f6282 100644 --- a/chunk.go +++ b/chunk.go @@ -146,8 +146,7 @@ func (lcd *largeChunkDecoder) decodeString() (string, error) { if c == '"' { break } else if c == '\\' { - err := lcd.decodeEscaped() - if err != nil { + if err := lcd.decodeEscaped(); err != nil { return "", err } } else if c < ' ' { diff --git a/connection.go b/connection.go index 6837a72c5..c9d760327 100644 --- a/connection.go +++ b/connection.go @@ -108,8 +108,7 @@ func (sc *snowflakeConn) exec( // handle bindings, if required requestID := getOrGenerateRequestIDFromContext(ctx) if len(bindings) > 0 { - err = sc.processBindings(ctx, bindings, describeOnly, requestID, &req) - if err != nil { + if err = sc.processBindings(ctx, bindings, describeOnly, requestID, &req); err != nil { return nil, err } } @@ -237,9 +236,8 @@ func (sc *snowflakeConn) BeginTx( return nil, driver.ErrBadConn } isDesc := isDescribeOnly(ctx) - _, err := sc.exec(ctx, "BEGIN", false, /* noResult */ - false /* isInternal */, isDesc, nil) - if err != nil { + if _, err := sc.exec(ctx, "BEGIN", false, /* noResult */ + false /* isInternal */, isDesc, nil); err != nil { return nil, err } return &snowflakeTx{sc, ctx}, nil @@ -261,8 +259,7 @@ func (sc *snowflakeConn) Close() (err error) { defer sc.cleanup() if sc.cfg != nil && !sc.cfg.KeepSessionAlive { - err = sc.rest.FuncCloseSession(sc.ctx, sc.rest, sc.rest.RequestTimeout) - if err != nil { + if err = sc.rest.FuncCloseSession(sc.ctx, sc.rest, sc.rest.RequestTimeout); err != nil { logger.Error(err) } } @@ -405,8 +402,7 @@ func (sc *snowflakeConn) queryContextInternal( if isMultiStmt(&data.Data) { // handleMultiQuery is responsible to fill rows with childResults - err = sc.handleMultiQuery(ctx, data.Data, rows) - if err != nil { + if err = sc.handleMultiQuery(ctx, data.Data, rows); err != nil { return nil, err } } else { @@ -541,8 +537,7 @@ type wrapReader struct { func (w *wrapReader) Close() error { if cl, ok := w.Reader.(io.ReadCloser); ok { - err := cl.Close() - if err != nil { + if err := cl.Close(); err != nil { return err } } @@ -618,8 +613,7 @@ func (asb *ArrowStreamBatch) downloadChunkStreamHelper(ctx context.Context) erro // to ensure no leaked memory. func (asb *ArrowStreamBatch) GetStream(ctx context.Context) (io.ReadCloser, error) { if asb.rr == nil { - err := asb.downloadChunkStreamHelper(ctx) - if err != nil { + if err := asb.downloadChunkStreamHelper(ctx); err != nil { return nil, err } } @@ -756,8 +750,7 @@ func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, err st = sc.cfg.Transporter } if strings.HasSuffix(sc.cfg.Host, privateLinkSuffix) { - err := sc.setupOCSPPrivatelink(sc.cfg.Application, sc.cfg.Host) - if err != nil { + if err := sc.setupOCSPPrivatelink(sc.cfg.Application, sc.cfg.Host); err != nil { return nil, err } } else { diff --git a/connection_util.go b/connection_util.go index 31e7d5a76..17ca1b975 100644 --- a/connection_util.go +++ b/connection_util.go @@ -110,11 +110,10 @@ func (sc *snowflakeConn) processFileTransfer( if sfa.options.MultiPartThreshold == 0 { sfa.options.MultiPartThreshold = dataSizeThreshold } - err := sfa.execute() - if err != nil { + if err := sfa.execute(); err != nil { return nil, err } - data, err = sfa.result() + data, err := sfa.result() if err != nil { return nil, err } @@ -290,14 +289,12 @@ func populateChunkDownloader( func (sc *snowflakeConn) setupOCSPPrivatelink(app string, host string) error { ocspCacheServer := fmt.Sprintf("http://ocsp.%v/ocsp_response_cache.json", host) logger.Debugf("OCSP Cache Server for Privatelink: %v\n", ocspCacheServer) - err := os.Setenv(cacheServerURLEnv, ocspCacheServer) - if err != nil { + if err := os.Setenv(cacheServerURLEnv, ocspCacheServer); err != nil { return err } ocspRetryHostTemplate := fmt.Sprintf("http://ocsp.%v/retry/", host) + "%v/%v" logger.Debugf("OCSP Retry URL for Privatelink: %v\n", ocspRetryHostTemplate) - err = os.Setenv(ocspRetryURLEnv, ocspRetryHostTemplate) - if err != nil { + if err := os.Setenv(ocspRetryURLEnv, ocspRetryHostTemplate); err != nil { return err } return nil diff --git a/converter.go b/converter.go index 2f3e8c374..88f64baa6 100644 --- a/converter.go +++ b/converter.go @@ -77,8 +77,7 @@ func goTypeToSnowflake(v driver.Value, tsmode snowflakeType) snowflakeType { if len(t) != 1 { return unSupportedType } - _, err := dataTypeMode(t) - if err != nil { + if _, err := dataTypeMode(t); err != nil { return unSupportedType } return changeType diff --git a/driver.go b/driver.go index 136d2ff2b..6a565be4e 100644 --- a/driver.go +++ b/driver.go @@ -28,8 +28,7 @@ func (d SnowflakeDriver) Open(dsn string) (driver.Conn, error) { // OpenWithConfig creates a new connection with the given Config. func (d SnowflakeDriver) OpenWithConfig(ctx context.Context, config Config) (driver.Conn, error) { - err := config.Validate() - if err != nil { + if err := config.Validate(); err != nil { return nil, err } if config.Tracing != "" { @@ -41,8 +40,7 @@ func (d SnowflakeDriver) OpenWithConfig(ctx context.Context, config Config) (dri return nil, err } - err = authenticateWithConfig(sc) - if err != nil { + if err = authenticateWithConfig(sc); err != nil { return nil, err } sc.connectionTelemetry(&config) diff --git a/dsn.go b/dsn.go index 66b70fe17..92fd472e0 100644 --- a/dsn.go +++ b/dsn.go @@ -107,8 +107,7 @@ type Config struct { // A driver client may call it manually, but it is also called during opening first connection. func (c *Config) Validate() error { if c.TmpDirPath != "" { - _, err := os.Stat(c.TmpDirPath) - if err != nil { + if _, err := os.Stat(c.TmpDirPath); err != nil { return err } } @@ -561,8 +560,7 @@ func parseUserPassword(posAt int, dsn string) (user, password string) { func parseParams(cfg *Config, posQuestion int, dsn string) (err error) { for j := posQuestion + 1; j < len(dsn); j++ { if dsn[j] == '?' { - err = parseDSNParams(cfg, dsn[j+1:]) - if err != nil { + if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { return } break diff --git a/encrypt_util.go b/encrypt_util.go index fce9049d5..08179891d 100644 --- a/encrypt_util.go +++ b/encrypt_util.go @@ -104,8 +104,7 @@ func encryptStream( // encrypt key with ECB fileKey = padBytesLength(fileKey, block.BlockSize()) encryptedFileKey := make([]byte, len(fileKey)) - err = encryptECB(encryptedFileKey, fileKey, decodedKey) - if err != nil { + if err = encryptECB(encryptedFileKey, fileKey, decodedKey); err != nil { return nil, err } @@ -216,8 +215,7 @@ func decryptFile( // decrypt file key decryptedKey := make([]byte, len(keyBytes)) - err = decryptECB(decryptedKey, keyBytes, decodedKey) - if err != nil { + if err = decryptECB(decryptedKey, keyBytes, decodedKey); err != nil { return "", err } decryptedKey, err = paddingTrim(decryptedKey) diff --git a/errors.go b/errors.go index 3c029f107..486565857 100644 --- a/errors.go +++ b/errors.go @@ -74,8 +74,7 @@ func (se *SnowflakeError) sendExceptionTelemetry(sc *snowflakeConn, data *teleme func (se *SnowflakeError) exceptionTelemetry(sc *snowflakeConn) *SnowflakeError { data := se.generateTelemetryExceptionData() - err := se.sendExceptionTelemetry(sc, data) - if err != nil { + if err := se.sendExceptionTelemetry(sc, data); err != nil { logger.Debugf("failed to log to telemetry: %v", data) } return se diff --git a/file_transfer_agent.go b/file_transfer_agent.go index 79a878c6a..9e3a64697 100644 --- a/file_transfer_agent.go +++ b/file_transfer_agent.go @@ -129,31 +129,26 @@ type snowflakeFileTransferAgent struct { func (sfa *snowflakeFileTransferAgent) execute() error { var err error - err = sfa.parseCommand() - if err != nil { + if err = sfa.parseCommand(); err != nil { return err } - err = sfa.initFileMetadata() - if err != nil { + if err = sfa.initFileMetadata(); err != nil { return err } if sfa.commandType == uploadCommand { - err = sfa.processFileCompressionType() - if err != nil { + if err = sfa.processFileCompressionType(); err != nil { return err } } - err = sfa.transferAccelerateConfig() - if err != nil { + if err = sfa.transferAccelerateConfig(); err != nil { return err } if sfa.commandType == downloadCommand { if _, err = os.Stat(sfa.localLocation); os.IsNotExist(err) { - err = os.MkdirAll(sfa.localLocation, os.ModePerm) - if err != nil { + if err = os.MkdirAll(sfa.localLocation, os.ModePerm); err != nil { return err } } @@ -161,15 +156,13 @@ func (sfa *snowflakeFileTransferAgent) execute() error { if sfa.stageLocationType == local { if _, err = os.Stat(sfa.stageInfo.Location); os.IsNotExist(err) { - err = os.MkdirAll(sfa.stageInfo.Location, os.ModePerm) - if err != nil { + if err = os.MkdirAll(sfa.stageInfo.Location, os.ModePerm); err != nil { return err } } } - err = sfa.updateFileMetadataWithPresignedURL() - if err != nil { + if err = sfa.updateFileMetadataWithPresignedURL(); err != nil { return err } @@ -197,13 +190,11 @@ func (sfa *snowflakeFileTransferAgent) execute() error { } if sfa.commandType == uploadCommand { - err = sfa.upload(largeFileMetas, smallFileMetas) - if err != nil { + if err = sfa.upload(largeFileMetas, smallFileMetas); err != nil { return err } } else { - err = sfa.download(smallFileMetas) - if err != nil { + if err = sfa.download(smallFileMetas); err != nil { return err } } @@ -262,8 +253,7 @@ func (sfa *snowflakeFileTransferAgent) parseCommand() error { if err != nil { return err } - fi, err := os.Stat(sfa.localLocation) - if err != nil || !fi.IsDir() { + if fi, err := os.Stat(sfa.localLocation); err != nil || !fi.IsDir() { return (&SnowflakeError{ Number: ErrLocalPathNotDirectory, SQLState: sfa.data.SQLState, @@ -693,15 +683,13 @@ func (sfa *snowflakeFileTransferAgent) upload( if len(smallFileMetadata) > 0 { logger.Infof("uploading %v small files", len(smallFileMetadata)) - err = sfa.uploadFilesParallel(smallFileMetadata) - if err != nil { + if err = sfa.uploadFilesParallel(smallFileMetadata); err != nil { return err } } if len(largeFileMetadata) > 0 { logger.Infof("uploading %v large files", len(largeFileMetadata)) - err = sfa.uploadFilesSequential(largeFileMetadata) - if err != nil { + if err = sfa.uploadFilesSequential(largeFileMetadata); err != nil { return err } } @@ -720,8 +708,7 @@ func (sfa *snowflakeFileTransferAgent) download( } logger.WithContext(sfa.sc.ctx).Infof("downloading %v files", len(fileMetadata)) - err = sfa.downloadFilesParallel(fileMetadata) - if err != nil { + if err = sfa.downloadFilesParallel(fileMetadata); err != nil { return err } return nil @@ -882,8 +869,7 @@ func (sfa *snowflakeFileTransferAgent) uploadOneFile(meta *fileMetadata) (*fileM } client := sfa.getStorageClient(sfa.stageLocationType) - err = client.uploadOneFileWithRetry(meta) - if err != nil { + if err = client.uploadOneFileWithRetry(meta); err != nil { return meta, err } return meta, nil @@ -972,8 +958,7 @@ func (sfa *snowflakeFileTransferAgent) downloadOneFile(meta *fileMetadata) (*fil meta.tmpDir = tmpDir defer os.RemoveAll(tmpDir) // cleanup client := sfa.getStorageClient(sfa.stageLocationType) - err = client.downloadOneFile(meta) - if err != nil { + if err = client.downloadOneFile(meta); err != nil { meta.dstFileSize = -1 if !meta.resStatus.isSet() { meta.resStatus = errStatus diff --git a/gcs_storage_client.go b/gcs_storage_client.go index 28b71804f..5d9dac9a8 100644 --- a/gcs_storage_client.go +++ b/gcs_storage_client.go @@ -249,8 +249,7 @@ func (util *snowflakeGcsClient) uploadFile( meta.gcsFileHeaderDigest = gcsHeaders[gcsFileHeaderDigest] meta.gcsFileHeaderContentLength = meta.uploadSize - err = json.Unmarshal([]byte(gcsHeaders[gcsMetadataEncryptionDataProp]), &encryptMeta) - if err != nil { + if err = json.Unmarshal([]byte(gcsHeaders[gcsMetadataEncryptionDataProp]), &encryptMeta); err != nil { return err } meta.gcsFileHeaderEncryptionMeta = encryptMeta @@ -320,16 +319,14 @@ func (util *snowflakeGcsClient) nativeDownloadFile( return err } defer f.Close() - _, err = io.Copy(f, resp.Body) - if err != nil { + if _, err = io.Copy(f, resp.Body); err != nil { return err } var encryptMeta encryptMetadata if resp.Header.Get(gcsMetadataEncryptionDataProp) != "" { var encryptData *encryptionData - err = json.Unmarshal([]byte(resp.Header.Get(gcsMetadataEncryptionDataProp)), &encryptData) - if err != nil { + if err = json.Unmarshal([]byte(resp.Header.Get(gcsMetadataEncryptionDataProp)), &encryptData); err != nil { return err } if encryptData != nil { diff --git a/local_storage_client.go b/local_storage_client.go index 1b6702221..2ae072b63 100644 --- a/local_storage_client.go +++ b/local_storage_client.go @@ -62,8 +62,7 @@ func (util *localUtil) uploadOneFileWithRetry(meta *fileMetadata) error { break } - _, err = output.Write(data) - if err != nil { + if _, err = output.Write(data); err != nil { return err } } @@ -92,8 +91,7 @@ func (util *localUtil) downloadOneFile(meta *fileMetadata) error { return err } if _, err = os.Stat(baseDir); os.IsNotExist(err) { - err = os.MkdirAll(baseDir, os.ModePerm) - if err != nil { + if err = os.MkdirAll(baseDir, os.ModePerm); err != nil { return err } } @@ -102,8 +100,7 @@ func (util *localUtil) downloadOneFile(meta *fileMetadata) error { if err != nil { return err } - err = os.WriteFile(fullDstFileName, data, readWriteFileMode) - if err != nil { + if err = os.WriteFile(fullDstFileName, data, readWriteFileMode); err != nil { return err } fi, err := os.Stat(fullDstFileName) diff --git a/monitoring.go b/monitoring.go index 06535718e..07b17c0aa 100644 --- a/monitoring.go +++ b/monitoring.go @@ -146,8 +146,7 @@ func (sc *snowflakeConn) checkQueryStatus( } defer res.Body.Close() var statusResp = statusResponse{} - err = json.NewDecoder(res.Body).Decode(&statusResp) - if err != nil { + if err = json.NewDecoder(res.Body).Decode(&statusResp); err != nil { logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) return nil, err } @@ -222,8 +221,7 @@ func (sc *snowflakeConn) getQueryResultResp( } defer res.Body.Close() var respd *execResponse - err = json.NewDecoder(res.Body).Decode(&respd) - if err != nil { + if err = json.NewDecoder(res.Body).Decode(&respd); err != nil { logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) return nil, err } @@ -264,8 +262,7 @@ func (sc *snowflakeConn) buildRowsForRunningQuery( rows := new(snowflakeRows) rows.sc = sc rows.queryID = qid - err := sc.rowsForRunningQuery(ctx, qid, rows) - if err != nil { + if err := sc.rowsForRunningQuery(ctx, qid, rows); err != nil { return nil, err } rows.ChunkDownloader.start() diff --git a/multistatement.go b/multistatement.go index af579681c..ce9d9910b 100644 --- a/multistatement.go +++ b/multistatement.go @@ -98,8 +98,7 @@ func (sc *snowflakeConn) handleMultiQuery( } childResults := getChildResults(data.ResultIDs, data.ResultTypes) for _, child := range childResults { - err := sc.rowsForRunningQuery(ctx, child.id, rows) - if err != nil { + if err := sc.rowsForRunningQuery(ctx, child.id, rows); err != nil { return err } } diff --git a/ocsp.go b/ocsp.go index b83502215..297a999db 100644 --- a/ocsp.go +++ b/ocsp.go @@ -920,13 +920,11 @@ func writeOCSPCacheFile() { logger.Debugf("other process locks the cache file. %v. ignored.\n", cacheLockFileName) return } - err = os.Remove(cacheLockFileName) - if err != nil { + if err = os.Remove(cacheLockFileName); err != nil { logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", cacheLockFileName, err) return } - err = os.Mkdir(cacheLockFileName, 0600) - if err != nil { + if err = os.Mkdir(cacheLockFileName, 0600); err != nil { logger.Debugf("failed to create lock file. file: %v, err: %v. ignored.\n", cacheLockFileName, err) return } @@ -949,8 +947,7 @@ func writeOCSPCacheFile() { logger.Debugf("failed to convert OCSP Response cache to JSON. ignored.") return } - err = os.WriteFile(cacheFileName, j, 0644) - if err != nil { + if err = os.WriteFile(cacheFileName, j, 0644); err != nil { logger.Debugf("failed to write OCSP Response cache. err: %v. ignored.\n", err) } } @@ -1009,8 +1006,7 @@ func createOCSPCacheDir() { } if _, err := os.Stat(cacheDir); os.IsNotExist(err) { - err = os.MkdirAll(cacheDir, os.ModePerm) - if err != nil { + if err = os.MkdirAll(cacheDir, os.ModePerm); err != nil { logger.Debugf("failed to create cache directory. %v, err: %v. ignored\n", cacheDir, err) } } diff --git a/restful.go b/restful.go index 1687c6041..6b10dd4b3 100644 --- a/restful.go +++ b/restful.go @@ -252,14 +252,12 @@ func postRestfulQueryHelper( if resp.StatusCode == http.StatusOK { logger.WithContext(ctx).Infof("postQuery: resp: %v", resp) var respd execResponse - err = json.NewDecoder(resp.Body).Decode(&respd) - if err != nil { + if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil { logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) return nil, err } if respd.Code == sessionExpiredCode { - err = sr.renewExpiredSessionToken(ctx, timeout, token) - if err != nil { + if err = sr.renewExpiredSessionToken(ctx, timeout, token); err != nil { return nil, err } return sr.FuncPostQuery(ctx, sr, params, headers, body, timeout, requestID, cfg) @@ -300,8 +298,7 @@ func postRestfulQueryHelper( return nil, err } if respd.Code == sessionExpiredCode { - err = sr.renewExpiredSessionToken(ctx, timeout, token) - if err != nil { + if err = sr.renewExpiredSessionToken(ctx, timeout, token); err != nil { return nil, err } isSessionRenewed = true @@ -345,8 +342,7 @@ func closeSession(ctx context.Context, sr *snowflakeRestful, timeout time.Durati defer resp.Body.Close() if resp.StatusCode == http.StatusOK { var respd renewSessionResponse - err = json.NewDecoder(resp.Body).Decode(&respd) - if err != nil { + if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil { logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) return err } @@ -477,15 +473,13 @@ func cancelQuery(ctx context.Context, sr *snowflakeRestful, requestID UUID, time defer resp.Body.Close() if resp.StatusCode == http.StatusOK { var respd cancelQueryResponse - err = json.NewDecoder(resp.Body).Decode(&respd) - if err != nil { + if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil { logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) return err } ctxRetry := getCancelRetry(ctx) if !respd.Success && respd.Code == sessionExpiredCode { - err = sr.FuncRenewSession(ctx, sr, timeout) - if err != nil { + if err = sr.FuncRenewSession(ctx, sr, timeout); err != nil { return err } return sr.FuncCancelQuery(ctx, sr, requestID, timeout) diff --git a/result.go b/result.go index 707b4c026..e08f41902 100644 --- a/result.go +++ b/result.go @@ -30,16 +30,14 @@ type snowflakeResult struct { } func (res *snowflakeResult) LastInsertId() (int64, error) { - err := res.waitForAsyncExecStatus() - if err != nil { + if err := res.waitForAsyncExecStatus(); err != nil { return -1, err } return res.insertID, nil } func (res *snowflakeResult) RowsAffected() (int64, error) { - err := res.waitForAsyncExecStatus() - if err != nil { + if err := res.waitForAsyncExecStatus(); err != nil { return -1, err } return res.affectedRows, nil diff --git a/rows.go b/rows.go index 43d27f93a..3d3fcbb0f 100644 --- a/rows.go +++ b/rows.go @@ -71,9 +71,8 @@ type chunkError struct { Error error } -func (rows *snowflakeRows) Close() error { - err := rows.waitForAsyncQueryStatus() - if err != nil { +func (rows *snowflakeRows) Close() (err error) { + if err := rows.waitForAsyncQueryStatus(); err != nil { return err } logger.WithContext(rows.sc.ctx).Debugln("Rows.Close") @@ -82,8 +81,7 @@ func (rows *snowflakeRows) Close() error { // ColumnTypeDatabaseTypeName returns the database column name. func (rows *snowflakeRows) ColumnTypeDatabaseTypeName(index int) string { - err := rows.waitForAsyncQueryStatus() - if err != nil { + if err := rows.waitForAsyncQueryStatus(); err != nil { return err.Error() } return strings.ToUpper(rows.ChunkDownloader.getRowType()[index].Type) @@ -91,8 +89,7 @@ func (rows *snowflakeRows) ColumnTypeDatabaseTypeName(index int) string { // ColumnTypeLength returns the length of the column func (rows *snowflakeRows) ColumnTypeLength(index int) (length int64, ok bool) { - err := rows.waitForAsyncQueryStatus() - if err != nil { + if err := rows.waitForAsyncQueryStatus(); err != nil { return 0, false } if index < 0 || index > len(rows.ChunkDownloader.getRowType()) { @@ -106,8 +103,7 @@ func (rows *snowflakeRows) ColumnTypeLength(index int) (length int64, ok bool) { } func (rows *snowflakeRows) ColumnTypeNullable(index int) (nullable, ok bool) { - err := rows.waitForAsyncQueryStatus() - if err != nil { + if err := rows.waitForAsyncQueryStatus(); err != nil { return false, false } if index < 0 || index > len(rows.ChunkDownloader.getRowType()) { @@ -117,8 +113,7 @@ func (rows *snowflakeRows) ColumnTypeNullable(index int) (nullable, ok bool) { } func (rows *snowflakeRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { - err := rows.waitForAsyncQueryStatus() - if err != nil { + if err := rows.waitForAsyncQueryStatus(); err != nil { return 0, 0, false } rowType := rows.ChunkDownloader.getRowType() @@ -137,8 +132,7 @@ func (rows *snowflakeRows) ColumnTypePrecisionScale(index int) (precision, scale } func (rows *snowflakeRows) Columns() []string { - err := rows.waitForAsyncQueryStatus() - if err != nil { + if err := rows.waitForAsyncQueryStatus(); err != nil { return make([]string, 0) } logger.Debug("Rows.Columns") @@ -150,8 +144,7 @@ func (rows *snowflakeRows) Columns() []string { } func (rows *snowflakeRows) ColumnTypeScanType(index int) reflect.Type { - err := rows.waitForAsyncQueryStatus() - if err != nil { + if err := rows.waitForAsyncQueryStatus(); err != nil { return nil } return snowflakeTypeToGo( @@ -171,8 +164,7 @@ func (rows *snowflakeRows) GetStatus() queryStatus { func (rows *snowflakeRows) GetArrowBatches() ([]*ArrowBatch, error) { // Wait for all arrow batches before fetching. // Otherwise, a panic error "invalid memory address or nil pointer dereference" will be thrown. - err := rows.waitForAsyncQueryStatus() - if err != nil { + if err := rows.waitForAsyncQueryStatus(); err != nil { return nil, err } @@ -180,8 +172,7 @@ func (rows *snowflakeRows) GetArrowBatches() ([]*ArrowBatch, error) { } func (rows *snowflakeRows) Next(dest []driver.Value) (err error) { - err = rows.waitForAsyncQueryStatus() - if err != nil { + if err = rows.waitForAsyncQueryStatus(); err != nil { return err } row, err := rows.ChunkDownloader.next() @@ -211,16 +202,14 @@ func (rows *snowflakeRows) Next(dest []driver.Value) (err error) { } func (rows *snowflakeRows) HasNextResultSet() bool { - err := rows.waitForAsyncQueryStatus() - if err != nil { + if err := rows.waitForAsyncQueryStatus(); err != nil { return false } return rows.ChunkDownloader.hasNextResultSet() } func (rows *snowflakeRows) NextResultSet() error { - err := rows.waitForAsyncQueryStatus() - if err != nil { + if err := rows.waitForAsyncQueryStatus(); err != nil { return err } if len(rows.ChunkDownloader.getChunkMetas()) == 0 { diff --git a/s3_storage_client.go b/s3_storage_client.go index c655f3fdb..ed3bca59a 100644 --- a/s3_storage_client.go +++ b/s3_storage_client.go @@ -230,11 +230,10 @@ func (util *snowflakeS3Client) nativeDownloadFile( if meta.mockDownloader != nil { downloader = meta.mockDownloader } - _, err = downloader.Download(context.Background(), f, &s3.GetObjectInput{ + if _, err = downloader.Download(context.Background(), f, &s3.GetObjectInput{ Bucket: s3Obj.Bucket, Key: s3Obj.Key, - }) - if err != nil { + }); err != nil { var ae smithy.APIError if errors.As(err, &ae) { if ae.ErrorCode() == expiredToken { diff --git a/secure_storage_manager.go b/secure_storage_manager.go index 1c3146c0d..9b83a2be7 100644 --- a/secure_storage_manager.go +++ b/secure_storage_manager.go @@ -52,8 +52,7 @@ func createCredentialCacheDir() { } if _, err := os.Stat(credCacheDir); os.IsNotExist(err) { - err = os.MkdirAll(credCacheDir, os.ModePerm) - if err != nil { + if err = os.MkdirAll(credCacheDir, os.ModePerm); err != nil { logger.Debugf("Failed to create cache directory. %v, err: %v. ignored\n", credCacheDir, err) } } @@ -76,8 +75,7 @@ func setCredential(sc *snowflakeConn, credType, token string) { Key: target, Data: []byte(token), } - err := ring.Set(item) - if err != nil { + if err := ring.Set(item); err != nil { logger.Debugf("Failed to write to Windows credential manager. Err: %v", err) } } else if runtime.GOOS == "darwin" { @@ -90,8 +88,7 @@ func setCredential(sc *snowflakeConn, credType, token string) { Key: account, Data: []byte(token), } - err := ring.Set(item) - if err != nil { + if err := ring.Set(item); err != nil { logger.Debugf("Failed to write to keychain. Err: %v", err) } } else if runtime.GOOS == "linux" { @@ -259,21 +256,18 @@ func writeTemporaryCacheFile(input []byte) { logger.Debugf("other process locks the cache file. %v. ignored.\n", credCache) return } - err = os.Remove(credCacheLockFileName) - if err != nil { + if err = os.Remove(credCacheLockFileName); err != nil { logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) return } - err = os.Mkdir(credCacheLockFileName, 0600) - if err != nil { + if err = os.Mkdir(credCacheLockFileName, 0600); err != nil { logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) return } } defer os.RemoveAll(credCacheLockFileName) - err = os.WriteFile(credCache, input, 0644) - if err != nil { + if err = os.WriteFile(credCache, input, 0644); err != nil { logger.Debugf("Failed to write the cache file. File: %v err: %v.", credCache, err) } } diff --git a/storage_client.go b/storage_client.go index e489e23f7..a7385c278 100644 --- a/storage_client.go +++ b/storage_client.go @@ -133,8 +133,7 @@ func (rsu *remoteStorageUtil) uploadOneFileWithRetry(meta *fileMetadata) error { retryOuter := true for i := 0; i < 10; i++ { // retry - err := rsu.uploadOneFile(meta) - if err != nil { + if err := rsu.uploadOneFile(meta); err != nil { return err } retryInner := true @@ -185,8 +184,7 @@ func (rsu *remoteStorageUtil) downloadOneFile(meta *fileMetadata) error { return err } if _, err = os.Stat(baseDir); os.IsNotExist(err) { - err = os.MkdirAll(baseDir, os.ModePerm) - if err != nil { + if err = os.MkdirAll(baseDir, os.ModePerm); err != nil { return err } } @@ -204,8 +202,7 @@ func (rsu *remoteStorageUtil) downloadOneFile(meta *fileMetadata) error { var lastErr error maxRetry := defaultMaxRetry for retry := 0; retry < maxRetry; retry++ { - err = utilClass.nativeDownloadFile(meta, fullDstFileName, maxConcurrency) - if err != nil { + if err = utilClass.nativeDownloadFile(meta, fullDstFileName, maxConcurrency); err != nil { return err } if meta.resStatus == downloaded { @@ -221,8 +218,7 @@ func (rsu *remoteStorageUtil) downloadOneFile(meta *fileMetadata) error { if err != nil { return err } - err = os.Rename(tmpDstFileName, fullDstFileName) - if err != nil { + if err = os.Rename(tmpDstFileName, fullDstFileName); err != nil { return err } } diff --git a/telemetry.go b/telemetry.go index 0ce60ab64..911aba229 100644 --- a/telemetry.go +++ b/telemetry.go @@ -56,8 +56,7 @@ func (st *snowflakeTelemetry) addLog(data *telemetryData) error { st.logs = append(st.logs, data) st.mutex.Unlock() if len(st.logs) >= st.flushSize { - err := st.sendBatch() - if err != nil { + if err := st.sendBatch(); err != nil { return err } } @@ -112,8 +111,7 @@ func (st *snowflakeTelemetry) sendBatch() error { return err } var respd telemetryResponse - err = json.NewDecoder(resp.Body).Decode(&respd) - if err != nil { + if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil { logger.Info(err) st.enabled = false return err From 91cbfa9d6cc1bcc954979705494862a6c7cc28a7 Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Sat, 21 Oct 2023 21:48:29 -0700 Subject: [PATCH 16/19] fix SSO URL parsing --- authokta.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/authokta.go b/authokta.go index 3e3f3c518..1422688f5 100644 --- a/authokta.go +++ b/authokta.go @@ -111,7 +111,7 @@ func authenticateBySAML( if tokenURL, err = url.Parse(respd.Data.TokenURL); err != nil { return nil, fmt.Errorf("failed to parse token URL. %v", respd.Data.TokenURL) } - if ssoURL, err = url.Parse(respd.Data.TokenURL); err != nil { + if ssoURL, err = url.Parse(respd.Data.SSOURL); err != nil { return nil, fmt.Errorf("failed to parse ssoURL URL. %v", respd.Data.SSOURL) } if !isPrefixEqual(oktaURL, ssoURL) || !isPrefixEqual(oktaURL, tokenURL) { From 171e2af9e213f9457089c5130279ccfb76a505e1 Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Thu, 26 Oct 2023 21:01:50 -0700 Subject: [PATCH 17/19] use assert functions in tests and make updates based on comments --- auth_test.go | 11 +++---- authokta.go | 2 +- authokta_test.go | 59 ++++++++++++++++--------------------- connection_util.go | 5 +--- connector_test.go | 4 +-- file_transfer_agent_test.go | 8 ++--- transaction_test.go | 31 +++++++++---------- 7 files changed, 49 insertions(+), 71 deletions(-) diff --git a/auth_test.go b/auth_test.go index 70688877c..007a829af 100644 --- a/auth_test.go +++ b/auth_test.go @@ -675,18 +675,15 @@ func TestUnitAuthenticateWithConfigOkta(t *testing.T) { Host: "abc.com", } sc.rest = sr - sc.ctx = context.TODO() + sc.ctx = context.Background() err = authenticateWithConfig(sc) - if err != nil { - t.Fatalf("failed to run. err: %v", err) - } + assertNilF(t, err, "expected to have no error.") sr.FuncPostAuthSAML = postAuthSAMLError err = authenticateWithConfig(sc) - if err == nil { - t.Fatalf("should have failed.") - } + assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") + assertEqualE(t, err.Error(), "failed to get SAML response") } func TestUnitAuthenticateExternalBrowser(t *testing.T) { diff --git a/authokta.go b/authokta.go index fbd88a542..818753af8 100644 --- a/authokta.go +++ b/authokta.go @@ -112,7 +112,7 @@ func authenticateBySAML( return nil, fmt.Errorf("failed to parse token URL. %v", respd.Data.TokenURL) } if ssoURL, err = url.Parse(respd.Data.SSOURL); err != nil { - return nil, fmt.Errorf("failed to parse ssoURL URL. %v", respd.Data.SSOURL) + return nil, fmt.Errorf("failed to parse SSO URL. %v", respd.Data.SSOURL) } if !isPrefixEqual(oktaURL, ssoURL) || !isPrefixEqual(oktaURL, tokenURL) { return nil, &SnowflakeError{ diff --git a/authokta_test.go b/authokta_test.go index 747312e05..8adb6466b 100644 --- a/authokta_test.go +++ b/authokta_test.go @@ -234,19 +234,17 @@ func TestUnitAuthenticateBySAML(t *testing.T) { } var err error _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err == nil { - t.Fatal("should have failed.") - } + assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") + assertEqualE(t, err.Error(), "failed to get SAML response") + sr.FuncPostAuthSAML = postAuthSAMLAuthFail _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err == nil { - t.Fatal("should have failed.") - } + assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") + assertEqualE(t, err.Error(), "strconv.Atoi: parsing \"\": invalid syntax") + sr.FuncPostAuthSAML = postAuthSAMLAuthFailWithCode _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err == nil { - t.Fatal("should have failed.") - } + assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") driverErr, ok := err.(*SnowflakeError) if !ok { t.Fatalf("should be snowflake error. err: %v", err) @@ -256,9 +254,7 @@ func TestUnitAuthenticateBySAML(t *testing.T) { } sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidURL _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err == nil { - t.Fatal("should have failed.") - } + assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") driverErr, ok = err.(*SnowflakeError) if !ok { t.Fatalf("should be snowflake error. err: %v", err) @@ -268,41 +264,38 @@ func TestUnitAuthenticateBySAML(t *testing.T) { } sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidTokenURL _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err == nil { - t.Fatal("should have failed.") - } + assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") + assertEqualE(t, err.Error(), "failed to parse token URL. invalid!@url$%^") + sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidSSOURL _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err == nil { - t.Fatal("should have failed.") - } + assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") + assertEqualE(t, err.Error(), "failed to parse SSO URL. invalid!@url$%^") + sr.FuncPostAuthSAML = postAuthSAMLAuthSuccess sr.FuncPostAuthOKTA = postAuthOKTAError _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err == nil { - t.Fatal("should have failed.") - } + assertNotNilF(t, err, "should have failed at FuncPostAuthOKTA.") + assertEqualE(t, err.Error(), "failed to get SAML response") + sr.FuncPostAuthOKTA = postAuthOKTASuccess sr.FuncGetSSO = getSSOError _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err == nil { - t.Fatal("should have failed.") - } + assertNotNilF(t, err, "should have failed at FuncGetSSO.") + assertEqualE(t, err.Error(), "failed to get SSO html") + sr.FuncGetSSO = getSSOSuccessButInvalidURL _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err == nil { - t.Fatal("should have failed.") - } + assertNotNilF(t, err, "should have failed at FuncGetSSO.") + assertHasPrefixE(t, err.Error(), "failed to find action field in HTML response") + sr.FuncGetSSO = getSSOSuccess _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err != nil { - t.Fatalf("failed. err: %v", err) - } + assertNilF(t, err, "should have succeeded at FuncGetSSO.") + sr.FuncGetSSO = getSSOSuccessButWrongPrefixURL _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err == nil { - t.Fatal("should have failed.") - } + assertNotNilF(t, err, "should have failed at FuncGetSSO.") driverErr, ok = err.(*SnowflakeError) if !ok { t.Fatalf("should be snowflake error. err: %v", err) diff --git a/connection_util.go b/connection_util.go index 17ca1b975..384bc113a 100644 --- a/connection_util.go +++ b/connection_util.go @@ -24,10 +24,7 @@ func (sc *snowflakeConn) isClientSessionKeepAliveEnabled() bool { } func (sc *snowflakeConn) isHeartbeatNil() bool { - if sc.rest != nil { - return sc.rest.HeartBeat == nil - } - return true + return sc.rest == nil || sc.rest.HeartBeat == nil } func (sc *snowflakeConn) startHeartBeat() { diff --git a/connector_test.go b/connector_test.go index fae283b74..1a795c767 100644 --- a/connector_test.go +++ b/connector_test.go @@ -60,9 +60,7 @@ func TestConnectorWithMissingConfig(t *testing.T) { connector := NewConnector(&mock, config) _, err := connector.Connect(context.Background()) - if err == nil { - t.Fatalf("should have failed") - } + assertNotNilF(t, err, "the connection should have failed due to empty account.") driverErr, ok := err.(*SnowflakeError) if !ok { t.Fatalf("Snowflake error is expected. err: %v", err.Error()) diff --git a/file_transfer_agent_test.go b/file_transfer_agent_test.go index 7a591d414..de8388e21 100644 --- a/file_transfer_agent_test.go +++ b/file_transfer_agent_test.go @@ -671,11 +671,7 @@ func TestUploadDownloadOneFileRequireCompressStream(t *testing.T) { } func testUploadDownloadOneFile(t *testing.T, isStream bool) { - tmpDir, err := os.MkdirTemp("", "data") - if err != nil { - t.Fatalf("cannot create temp directory: %v", err) - } - defer os.RemoveAll(tmpDir) + tmpDir := t.TempDir() uploadFile := filepath.Join(tmpDir, "data.txt") f, err := os.Create(uploadFile) if err != nil { @@ -749,8 +745,8 @@ func testUploadDownloadOneFile(t *testing.T, isStream bool) { if err != nil { t.Fatal(err) } + defer os.Remove("download.txt") if downloadMeta.resStatus != downloaded { t.Fatalf("failed to download file") } - defer os.Remove("download.txt") } diff --git a/transaction_test.go b/transaction_test.go index 0cb274be6..ccbf30fb3 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -97,34 +97,31 @@ func withRetry(fn func(context.Context, *sql.Conn) error, numAttempts int, timeo } func TestTransactionError(t *testing.T) { - var err error - var tx snowflakeTx - sr := &snowflakeRestful{ FuncPostQuery: postQueryFail, } - tx.sc = &snowflakeConn{ - cfg: &Config{Params: map[string]*string{}}, - rest: sr, + + tx := snowflakeTx{ + sc: &snowflakeConn{ + cfg: &Config{Params: map[string]*string{}}, + rest: sr, + }, + ctx: context.Background(), } - tx.ctx = context.Background() // test for post query error when executing the txCommand - err = tx.execTxCommand(rollback) - if err == nil { - t.Fatal("should have failed to post query") - } + err := tx.execTxCommand(rollback) + assertNotNilF(t, err, "") + assertEqualE(t, err.Error(), "failed to get query response") // test for invalid txCommand err = tx.execTxCommand(2) - if err == nil { - t.Fatal("should have failed to execute unsupported transaction command") - } + assertNotNilF(t, err, "") + assertEqualE(t, err.Error(), "unsupported transaction command") // test for bad connection error when snowflakeConn is nil tx.sc = nil err = tx.execTxCommand(rollback) - if err == nil { - t.Fatal("should have failed to execute txCommand when connection is nil") - } + assertNotNilF(t, err, "") + assertEqualE(t, err.Error(), "driver: bad connection") } From ae6189556a7aaa970b56f1b9a2128e7a0c5d64d5 Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Thu, 26 Oct 2023 23:32:48 -0700 Subject: [PATCH 18/19] change t.tempDir back to os.MkdirTemp --- file_transfer_agent_test.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/file_transfer_agent_test.go b/file_transfer_agent_test.go index de8388e21..335bcdd98 100644 --- a/file_transfer_agent_test.go +++ b/file_transfer_agent_test.go @@ -671,7 +671,11 @@ func TestUploadDownloadOneFileRequireCompressStream(t *testing.T) { } func testUploadDownloadOneFile(t *testing.T, isStream bool) { - tmpDir := t.TempDir() + tmpDir, err := os.MkdirTemp("", "data") + if err != nil { + t.Fatalf("cannot create temp directory: %v", err) + } + defer os.RemoveAll(tmpDir) uploadFile := filepath.Join(tmpDir, "data.txt") f, err := os.Create(uploadFile) if err != nil { From c157c4f622d2e3967a50119bf8cd8d45cd260f87 Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Fri, 27 Oct 2023 15:17:44 -0700 Subject: [PATCH 19/19] use assertions in tests --- assert_test.go | 4 ++++ auth_test.go | 2 +- authokta_test.go | 26 ++++++++------------------ connection_util.go | 4 ---- connector_test.go | 10 ++++------ dsn_test.go | 8 ++++---- heartbeat_test.go | 40 +++++++++++++--------------------------- 7 files changed, 34 insertions(+), 60 deletions(-) diff --git a/assert_test.go b/assert_test.go index 185e2547e..d25217bd7 100644 --- a/assert_test.go +++ b/assert_test.go @@ -10,6 +10,10 @@ import ( "testing" ) +func assertNilE(t *testing.T, actual any, descriptions ...string) { + errorOnNonEmpty(t, validateNil(actual, descriptions...)) +} + func assertNilF(t *testing.T, actual any, descriptions ...string) { fatalOnNonEmpty(t, validateNil(actual, descriptions...)) } diff --git a/auth_test.go b/auth_test.go index 007a829af..4a6fd0e9f 100644 --- a/auth_test.go +++ b/auth_test.go @@ -678,7 +678,7 @@ func TestUnitAuthenticateWithConfigOkta(t *testing.T) { sc.ctx = context.Background() err = authenticateWithConfig(sc) - assertNilF(t, err, "expected to have no error.") + assertNilE(t, err, "expected to have no error.") sr.FuncPostAuthSAML = postAuthSAMLError err = authenticateWithConfig(sc) diff --git a/authokta_test.go b/authokta_test.go index 8adb6466b..56e151215 100644 --- a/authokta_test.go +++ b/authokta_test.go @@ -246,22 +246,16 @@ func TestUnitAuthenticateBySAML(t *testing.T) { _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrCodeIdpConnectionError { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeIdpConnectionError, driverErr.Number) - } + assertTrueF(t, ok, "should be a SnowflakeError") + assertEqualE(t, driverErr.Number, ErrCodeIdpConnectionError) + sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidURL _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") driverErr, ok = err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrCodeIdpConnectionError { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeIdpConnectionError, driverErr.Number) - } + assertTrueF(t, ok, "should be a SnowflakeError") + assertEqualE(t, driverErr.Number, ErrCodeIdpConnectionError) + sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidTokenURL _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") @@ -297,10 +291,6 @@ func TestUnitAuthenticateBySAML(t *testing.T) { _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) assertNotNilF(t, err, "should have failed at FuncGetSSO.") driverErr, ok = err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrCodeSSOURLNotMatch { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeSSOURLNotMatch, driverErr.Number) - } + assertTrueF(t, ok, "should be a SnowflakeError") + assertEqualE(t, driverErr.Number, ErrCodeSSOURLNotMatch) } diff --git a/connection_util.go b/connection_util.go index 384bc113a..4d37dea28 100644 --- a/connection_util.go +++ b/connection_util.go @@ -23,10 +23,6 @@ func (sc *snowflakeConn) isClientSessionKeepAliveEnabled() bool { return strings.Compare(*v, "true") == 0 } -func (sc *snowflakeConn) isHeartbeatNil() bool { - return sc.rest == nil || sc.rest.HeartBeat == nil -} - func (sc *snowflakeConn) startHeartBeat() { if sc.cfg != nil && !sc.isClientSessionKeepAliveEnabled() { return diff --git a/connector_test.go b/connector_test.go index 1a795c767..76886199e 100644 --- a/connector_test.go +++ b/connector_test.go @@ -61,11 +61,9 @@ func TestConnectorWithMissingConfig(t *testing.T) { connector := NewConnector(&mock, config) _, err := connector.Connect(context.Background()) assertNotNilF(t, err, "the connection should have failed due to empty account.") + driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("Snowflake error is expected. err: %v", err.Error()) - } - if driverErr.Number != expectedErr.Number || driverErr.Message != expectedErr.Message { - t.Fatalf("Snowflake error did not match. expected: %v, got: %v", expectedErr, driverErr) - } + assertTrueF(t, ok, "should be a SnowflakeError") + assertEqualE(t, driverErr.Number, expectedErr.Number) + assertEqualE(t, driverErr.Message, expectedErr.Message) } diff --git a/dsn_test.go b/dsn_test.go index c92227201..9af66c90e 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -706,10 +706,10 @@ func TestParseDSN(t *testing.T) { ocspMode: ocspModeFailOpen, err: nil, }, - { - dsn: "u:p@a.snowflakecomputing.com:443?authenticator=http%3A%2F%2Fsc.okta.com&ocspFailOpen=true&validateDefaultParameters=true", - err: errFailedToParseAuthenticator(), - }, + { + dsn: "u:p@a.snowflakecomputing.com:443?authenticator=http%3A%2F%2Fsc.okta.com&ocspFailOpen=true&validateDefaultParameters=true", + err: errFailedToParseAuthenticator(), + }, } for _, at := range []AuthType{AuthTypeExternalBrowser, AuthTypeOAuth} { diff --git a/heartbeat_test.go b/heartbeat_test.go index 26aa862f4..291b8f847 100644 --- a/heartbeat_test.go +++ b/heartbeat_test.go @@ -20,34 +20,24 @@ func TestUnitPostHeartbeat(t *testing.T) { restful: sr, } err := heartbeat.heartbeatMain() - if err != nil { - t.Fatalf("failed to heartbeat and renew session. err: %v", err) - } + assertNilF(t, err, "failed to heartbeat and renew session") heartbeat.restful.FuncPost = postTestError err = heartbeat.heartbeatMain() - if err == nil { - t.Fatal("should have failed") - } + assertNotNilF(t, err, "should have failed to start heartbeat") + assertEqualE(t, err.Error(), "failed to run post method") heartbeat.restful.FuncPost = postTestSuccessButInvalidJSON err = heartbeat.heartbeatMain() - if err == nil { - t.Fatal("should have failed") - } + assertNotNilF(t, err, "should have failed to start heartbeat") + assertHasPrefixE(t, err.Error(), "invalid character") heartbeat.restful.FuncPost = postTestAppForbiddenError err = heartbeat.heartbeatMain() - if err == nil { - t.Fatal("should have failed") - } + assertNotNilF(t, err, "should have failed to start heartbeat") driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrFailedToHeartbeat { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToHeartbeat, driverErr.Number) - } + assertTrueF(t, ok, "connection should be snowflakeConn") + assertEqualE(t, driverErr.Number, ErrFailedToHeartbeat) }) } @@ -64,15 +54,11 @@ func TestHeartbeatStartAndStop(t *testing.T) { } conn, ok := db.(*snowflakeConn) - if ok && conn.isHeartbeatNil() { - t.Fatalf("heartbeat should not be nil") - } + assertTrueF(t, ok, "connection should be snowflakeConn") + assertNotNilF(t, conn.rest, "heartbeat should not be nil") + assertNotNilF(t, conn.rest.HeartBeat, "heartbeat should not be nil") err = db.Close() - if err != nil { - t.Fatalf("should not cause error in Close. err: %v", err) - } - if ok && !conn.isHeartbeatNil() { - t.Fatalf("heartbeat should be nil") - } + assertNilF(t, err, "should not cause error in Close") + assertNilF(t, conn.rest, "heartbeat should be nil") }