diff --git a/.gitignore b/.gitignore index e2aa8a6..eeff2a0 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ .vscode .idea test_config.json +.DS_Store diff --git a/server/server.go b/server/server.go index 501b0d6..84fb1ed 100644 --- a/server/server.go +++ b/server/server.go @@ -8,11 +8,14 @@ import ( "io" "io/ioutil" "log" + "math" "mime/multipart" "net/http" "net/url" + "os" "regexp" "strings" + "time" ) const ( @@ -40,6 +43,11 @@ type Server struct { Configuration } +type TokenCache struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` +} + // New returns an initialized Secrets object func New(config Configuration) (*Server, error) { if config.ServerURL == "" && config.Tenant == "" || config.ServerURL != "" && config.Tenant != "" { @@ -158,7 +166,13 @@ func (s Server) accessResource(method, resource, path string, input interface{}) log.Printf("[DEBUG] calling %s %s", method, req.URL.String()) - data, _, err := handleResponse((&http.Client{}).Do(req)) + data, statusCode, err := handleResponse((&http.Client{}).Do(req)) + + // Check for unauthorized or access denied + if statusCode.StatusCode == http.StatusUnauthorized || statusCode.StatusCode == http.StatusForbidden { + s.clearTokenCache() + log.Printf("[ERROR] Token cache cleared due to unauthorized or access denied response.") + } return data, err } @@ -252,17 +266,69 @@ func (s Server) uploadFile(secretId int, fileField SecretField) error { return err } +func (s *Server) setCacheAccessToken(value string, expiresIn int, baseURL string) error { + cache := TokenCache{} + cache.AccessToken = value + cache.ExpiresIn = (int(time.Now().Unix()) + expiresIn) - int(math.Floor(float64(expiresIn)*0.9)) + + data, _ := json.Marshal(cache) + os.Setenv("SS_AT_"+url.QueryEscape(baseURL), string(data)) + return nil +} + +func (s *Server) getCacheAccessToken(baseURL string) (string, bool) { + data, ok := os.LookupEnv("SS_AT_" + url.QueryEscape(baseURL)) + if !ok { + s.clearTokenCache() + return "", ok + } + cache := TokenCache{} + if err := json.Unmarshal([]byte(data), &cache); err != nil { + return "", false + } + if time.Now().Unix() < int64(cache.ExpiresIn) { + return cache.AccessToken, true + } + return "", false +} + +func (s *Server) clearTokenCache() { + var baseURL string + + if s.ServerURL == "" { + baseURL = fmt.Sprintf(cloudBaseURLTemplate, s.Tenant, s.TLD) + } else { + baseURL = s.ServerURL + } + + os.Setenv("SS_AT_"+url.QueryEscape(baseURL), "") +} + // getAccessToken gets an OAuth2 Access Grant and returns the token // endpoint and get an accessGrant. func (s *Server) getAccessToken() (string, error) { if s.Credentials.Token != "" { return s.Credentials.Token, nil } - response, err := s.checkPlatformDetails() + var baseURL string + + if s.ServerURL == "" { + baseURL = fmt.Sprintf(cloudBaseURLTemplate, s.Tenant, s.TLD) + } else { + baseURL = s.ServerURL + } + + response, err := s.checkPlatformDetails(baseURL) if err != nil { log.Print("Error while checking server details:", err) return "", err } else if err == nil && response == "" { + + accessToken, found := s.getCacheAccessToken(baseURL) + if found { + return accessToken, nil + } + values := url.Values{ "username": {s.Credentials.Username}, "password": {s.Credentials.Password}, @@ -292,21 +358,17 @@ func (s *Server) getAccessToken() (string, error) { log.Print("[ERROR] parsing grant response:", err) return "", err } + if err = s.setCacheAccessToken(grant.AccessToken, grant.ExpiresIn, baseURL); err != nil { + log.Print("[ERROR] caching access token:", err) + return "", err + } return grant.AccessToken, nil } else { return response, nil } } -func (s *Server) checkPlatformDetails() (string, error) { - var baseURL string - - if s.ServerURL == "" { - baseURL = fmt.Sprintf(cloudBaseURLTemplate, s.Tenant, s.TLD) - } else { - baseURL = s.ServerURL - } - +func (s *Server) checkPlatformDetails(baseURL string) (string, error) { platformHelthCheckUrl := fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "health") ssHealthCheckUrl := fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "healthcheck.aspx") @@ -316,40 +378,50 @@ func (s *Server) checkPlatformDetails() (string, error) { } else { isHealthy := checkJSONResponse(platformHelthCheckUrl) if isHealthy { - requestData := url.Values{} - requestData.Set("grant_type", "client_credentials") - requestData.Set("client_id", s.Credentials.Username) - requestData.Set("client_secret", s.Credentials.Password) - requestData.Set("scope", "xpmheadless") - req, err := http.NewRequest("POST", fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "identity/api/oauth2/token/xpmplatform"), bytes.NewBufferString(requestData.Encode())) - if err != nil { - log.Print("Error creating HTTP request:", err) - return "", err - } + accessToken, found := s.getCacheAccessToken(baseURL) + if !found { + requestData := url.Values{} + requestData.Set("grant_type", "client_credentials") + requestData.Set("client_id", s.Credentials.Username) + requestData.Set("client_secret", s.Credentials.Password) + requestData.Set("scope", "xpmheadless") + + req, err := http.NewRequest("POST", fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "identity/api/oauth2/token/xpmplatform"), bytes.NewBufferString(requestData.Encode())) + if err != nil { + log.Print("Error creating HTTP request:", err) + return "", err + } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - data, _, err := handleResponse((&http.Client{}).Do(req)) - if err != nil { - log.Print("[ERROR] get token response error:", err) - return "", err - } + data, _, err := handleResponse((&http.Client{}).Do(req)) + if err != nil { + log.Print("[ERROR] get token response error:", err) + return "", err + } - var tokenjsonResponse OAuthTokens - if err = json.Unmarshal(data, &tokenjsonResponse); err != nil { - log.Print("[ERROR] parsing get token response:", err) - return "", err + var tokenjsonResponse OAuthTokens + if err = json.Unmarshal(data, &tokenjsonResponse); err != nil { + log.Print("[ERROR] parsing get token response:", err) + return "", err + } + accessToken = tokenjsonResponse.AccessToken + + if err = s.setCacheAccessToken(tokenjsonResponse.AccessToken, tokenjsonResponse.ExpiresIn, baseURL); err != nil { + log.Print("[ERROR] caching access token:", err) + return "", err + } } - req, err = http.NewRequest("GET", fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "vaultbroker/api/vaults"), bytes.NewBuffer([]byte{})) + req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "vaultbroker/api/vaults"), bytes.NewBuffer([]byte{})) if err != nil { log.Print("Error creating HTTP request:", err) return "", err } - req.Header.Add("Authorization", "Bearer "+tokenjsonResponse.AccessToken) + req.Header.Add("Authorization", "Bearer "+accessToken) - data, _, err = handleResponse((&http.Client{}).Do(req)) + data, _, err := handleResponse((&http.Client{}).Do(req)) if err != nil { log.Print("[ERROR] get vaults response error:", err) return "", err @@ -374,7 +446,7 @@ func (s *Server) checkPlatformDetails() (string, error) { return "", fmt.Errorf("no configured vault found") } - return tokenjsonResponse.AccessToken, nil + return accessToken, nil } } return "", fmt.Errorf("invalid URL")