diff --git a/arrow_test.go b/arrow_test.go index cca07603d..e4d103c78 100644 --- a/arrow_test.go +++ b/arrow_test.go @@ -13,7 +13,7 @@ import ( "time" ) -//A test just to show Snowflake version +// A test just to show Snowflake version func TestCheckVersion(t *testing.T) { conn := openConn(t) defer conn.Close() diff --git a/connection_util.go b/connection_util.go index 737a025dd..4d37dea28 100644 --- a/connection_util.go +++ b/connection_util.go @@ -281,11 +281,13 @@ func populateChunkDownloader( func (sc *snowflakeConn) setupOCSPPrivatelink(app string, host string) error { ocspCacheServer := fmt.Sprintf("http://ocsp.%v/ocsp_response_cache.json", host) + logger.Debugf("OCSP Cache Server for Privatelink: %v\n", ocspCacheServer) if err := os.Setenv(cacheServerURLEnv, ocspCacheServer); err != nil { return err } - ocspRetryHost := fmt.Sprintf("http://ocsp.%v/retry/", host) + "%v/%v" - if err := os.Setenv(ocspRetryURLEnv, ocspRetryHost); err != nil { + ocspRetryHostTemplate := fmt.Sprintf("http://ocsp.%v/retry/", host) + "%v/%v" + logger.Debugf("OCSP Retry URL for Privatelink: %v\n", ocspRetryHostTemplate) + if err := os.Setenv(ocspRetryURLEnv, ocspRetryHostTemplate); err != nil { return err } return nil diff --git a/data1.txt.gz b/data1.txt.gz new file mode 100644 index 000000000..e69de29bb diff --git a/driver_ocsp_test.go b/driver_ocsp_test.go index 9070c4c88..8ce109a3b 100644 --- a/driver_ocsp_test.go +++ b/driver_ocsp_test.go @@ -49,6 +49,7 @@ func cleanup() { unsetenv(ocspTestResponderTimeoutEnv) unsetenv(ocspTestResponderURLEnv) unsetenv(ocspTestNoOCSPURLEnv) + unsetenv(ocspRetryURLEnv) unsetenv(cacheDirEnv) } diff --git a/ocsp.go b/ocsp.go index 037f1034b..c809a24c6 100644 --- a/ocsp.go +++ b/ocsp.go @@ -360,14 +360,14 @@ func checkOCSPCacheServer( headers := make(map[string]string) res, err := newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout, defaultTimeProvider).execute() if err != nil { - logger.Errorf("failed to get OCSP cache from OCSP Cache Server. %v\n", err) + logger.Errorf("failed to get OCSP cache from OCSP Cache Server. %v", err) return nil, &ocspStatus{ code: ocspFailedSubmit, err: err, } } defer res.Body.Close() - logger.Debugf("StatusCode from OCSP Cache Server: %v\n", res.StatusCode) + logger.Debugf("StatusCode from OCSP Cache Server: %v", res.StatusCode) if res.StatusCode != http.StatusOK { return nil, &ocspStatus{ code: ocspFailedResponse, @@ -381,7 +381,7 @@ func checkOCSPCacheServer( if err := dec.Decode(&respd); err == io.EOF { break } else if err != nil { - logger.Errorf("failed to decode OCSP cache. %v\n", err) + logger.Errorf("failed to decode OCSP cache. %v", err) return nil, &ocspStatus{ code: ocspFailedExtractResponse, err: err, @@ -428,7 +428,6 @@ func retryOCSP( err: fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status), } } - logger.Debug("reading contents") ocspResBytes, err = io.ReadAll(res.Body) if err != nil { return ocspRes, ocspResBytes, &ocspStatus{ @@ -436,7 +435,59 @@ func retryOCSP( err: err, } } - logger.Debug("parsing OCSP response") + ocspRes, err = ocsp.ParseResponse(ocspResBytes, issuer) + if err != nil { + logger.Warnf("error when parsing ocsp response: %v", err) + logger.Warnf("performing GET fallback request to OCSP") + return fallbackRetryOCSPToGETRequest(ctx, client, req, ocspHost, headers, issuer, totalTimeout) + } + + logger.Debugf("OCSP Status from server: %v", printStatus(ocspRes)) + return ocspRes, ocspResBytes, &ocspStatus{ + code: ocspSuccess, + } +} + +// fallbackRetryOCSPToGETRequest is the third level of retry method. Some OCSP responders do not support POST requests +// and will return with a "malformed" request error. In that case we also try to perform a GET request +func fallbackRetryOCSPToGETRequest( + ctx context.Context, + client clientInterface, + req requestFunc, + ocspHost *url.URL, + headers map[string]string, + issuer *x509.Certificate, + totalTimeout time.Duration) ( + ocspRes *ocsp.Response, + ocspResBytes []byte, + ocspS *ocspStatus) { + multiplier := 1 + if atomic.LoadUint32((*uint32)(&ocspFailOpen)) == (uint32)(OCSPFailOpenFalse) { + multiplier = 3 // up to 3 times for Fail Close mode + } + res, err := newRetryHTTP(ctx, client, req, ocspHost, headers, + totalTimeout*time.Duration(multiplier), defaultTimeProvider).execute() + if err != nil { + return ocspRes, ocspResBytes, &ocspStatus{ + code: ocspFailedSubmit, + err: err, + } + } + defer res.Body.Close() + logger.Debugf("GET fallback StatusCode from OCSP Server: %v", res.StatusCode) + if res.StatusCode != http.StatusOK { + return ocspRes, ocspResBytes, &ocspStatus{ + code: ocspFailedResponse, + err: fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status), + } + } + ocspResBytes, err = io.ReadAll(res.Body) + if err != nil { + return ocspRes, ocspResBytes, &ocspStatus{ + code: ocspFailedExtractResponse, + err: err, + } + } ocspRes, err = ocsp.ParseResponse(ocspResBytes, issuer) if err != nil { return ocspRes, ocspResBytes, &ocspStatus{ @@ -445,14 +496,39 @@ func retryOCSP( } } + logger.Debugf("GET fallback OCSP Status from server: %v", printStatus(ocspRes)) return ocspRes, ocspResBytes, &ocspStatus{ code: ocspSuccess, } } +func printStatus(response *ocsp.Response) string { + switch response.Status { + case ocsp.Good: + return "Good" + case ocsp.Revoked: + return "Revoked" + case ocsp.Unknown: + return "Unknown" + default: + return fmt.Sprintf("%d", response.Status) + } +} + +func fullOCSPURL(url *url.URL) string { + fullURL := url.Hostname() + if url.Path != "" { + if !strings.HasPrefix(url.Path, "/") { + fullURL += "/" + } + fullURL += url.Path + } + return fullURL +} + // getRevocationStatus checks the certificate revocation status for subject using issuer certificate. func getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate) *ocspStatus { - logger.Infof("Subject: %v, Issuer: %v\n", subject.Subject, issuer.Subject) + logger.Infof("Subject: %v, Issuer: %v", subject.Subject, issuer.Subject) status, ocspReq, encodedCertID := validateWithCache(subject, issuer) if isValidOCSPStatus(status.code) { @@ -461,8 +537,8 @@ func getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate) if ocspReq == nil || encodedCertID == nil { return status } - logger.Infof("cache missed\n") - logger.Infof("OCSP Server: %v\n", subject.OCSPServer) + logger.Infof("cache missed") + logger.Infof("OCSP Server: %v", subject.OCSPServer) if len(subject.OCSPServer) == 0 || isTestNoOCSPURL() { return &ocspStatus{ code: ocspNoServer, @@ -484,9 +560,14 @@ func getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate) hostnameStr := os.Getenv(ocspTestResponderURLEnv) var hostname string if retryURL := os.Getenv(ocspRetryURLEnv); retryURL != "" { - hostname = fmt.Sprintf(retryURL, u.Hostname(), base64.StdEncoding.EncodeToString(ocspReq)) + hostname = fmt.Sprintf(retryURL, fullOCSPURL(u), base64.StdEncoding.EncodeToString(ocspReq)) + u0, err := url.Parse(hostname) + if err == nil { + hostname = u0.Hostname() + u = u0 + } } else { - hostname = u.Hostname() + hostname = fullOCSPURL(u) } if hostnameStr != "" { u0, err := url.Parse(hostnameStr) @@ -495,6 +576,10 @@ func getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate) u = u0 } } + + logger.Debugf("Fetching OCSP response from server: %v", u) + logger.Debugf("Host in headers: %v", hostname) + headers := make(map[string]string) headers[httpHeaderContentType] = "application/ocsp-request" headers[httpHeaderAccept] = "application/ocsp-response" diff --git a/ocsp_test.go b/ocsp_test.go index 257e87978..e2deb515f 100644 --- a/ocsp_test.go +++ b/ocsp_test.go @@ -340,6 +340,44 @@ func TestOCSPRetry(t *testing.T) { } } +func TestFullOCSPURL(t *testing.T) { + testcases := []tcFullOCSPURL{ + { + url: &url.URL{Host: "some-ocsp-url.com"}, + expectedURLString: "some-ocsp-url.com", + }, + { + url: &url.URL{ + Host: "some-ocsp-url.com", + Path: "/some-path", + }, + expectedURLString: "some-ocsp-url.com/some-path", + }, + { + url: &url.URL{ + Host: "some-ocsp-url.com", + Path: "some-path", + }, + expectedURLString: "some-ocsp-url.com/some-path", + }, + } + + for _, testcase := range testcases { + t.Run("", func(t *testing.T) { + returnedStringURL := fullOCSPURL(testcase.url) + if returnedStringURL != testcase.expectedURLString { + t.Fatalf("failed to match returned OCSP url string; expected: %v, got: %v", + testcase.expectedURLString, returnedStringURL) + } + }) + } +} + +type tcFullOCSPURL struct { + url *url.URL + expectedURLString string +} + func TestOCSPCacheServerRetry(t *testing.T) { dummyOCSPHost := &url.URL{ Scheme: "https",