diff --git a/cmd/policy-generation.go b/cmd/policy-generation.go deleted file mode 100644 index 671723d..0000000 --- a/cmd/policy-generation.go +++ /dev/null @@ -1,185 +0,0 @@ -package cmd - -import ( - "bytes" - "fmt" - "path/filepath" - "sync" - "text/template" -) - -const PathSeparator = "/" -const policySuffix = ".json.tmpl" - - -type LocalPolicyRetriever struct{ - rolePolicyPath string -} - -func NewLocalPolicyRetriever(stsRolePolicyPath string) *LocalPolicyRetriever { - return &LocalPolicyRetriever{ - rolePolicyPath: stsRolePolicyPath, - } -} - -func (r *LocalPolicyRetriever) getPolicyPathPrefix() (string) { - return fmt.Sprintf("%s%s", r.rolePolicyPath, PathSeparator) -} - -func (r *LocalPolicyRetriever) getPolicyPath(arn string) (string) { - safeRoleArn := b32(arn) - return fmt.Sprintf("%s%s%s", r.getPolicyPathPrefix(), safeRoleArn, policySuffix) -} - -func (r LocalPolicyRetriever) retrieveAllIdentifiers() ([]string, error) { - prefix := r.getPolicyPathPrefix() - suffix := policySuffix - matches, err := filepath.Glob(fmt.Sprintf("%s*%s", prefix , suffix)) - if err != nil { - return nil, err - } - cleanedMatches := make([]string, len(matches)) - for i, match := range matches { - safePolicyName := match[len(prefix):len(match) - len(suffix)] - cleanedMatches[i], err = b32_decode(safePolicyName) - if err != nil { - return nil, err - } - } - return cleanedMatches, err -} - -func (r *LocalPolicyRetriever) retrievePolicyStr(arn string) (string, error) { - c, err := readFileFull(r.getPolicyPath(arn)) - if err != nil { - return "", err - } - return string(c), err -} - -type PolicyRetriever interface { - //Retrieve the policy content based out of an identifier which can be an AWS ARN - retrievePolicyStr(string) (string, error) - - //Get all policy identifiers - retrieveAllIdentifiers() ([]string, error) -} - -type PolicyManager struct { - retriever PolicyRetriever - templates map[string]*template.Template - //Mutex for local template access - tMux *sync.RWMutex -} - -//Check if a policy manager can get a policy corresponding to an ARN -func (m *PolicyManager) DoesPolicyExist(arn string) bool { - - _, err := m.getPolicyTemplate(arn) - return err == nil -} - -//Check if a policy manager can get a policy corresponding to an ARN -func (m *PolicyManager) PreWarm() error { - ids, err := m.retriever.retrieveAllIdentifiers() - if err != nil { - return err - } - for _, policyId := range ids { - _, err := m.getPolicyTemplate(policyId) - if err != nil{ - return err - } - } - return nil -} - -//Get template from local cache and nil if it does not exist -func (m *PolicyManager) getPolicyTemplateFromCache(arn string) (tmpl *template.Template) { - m.tMux.RLock() - defer m.tMux.RUnlock() - tmpl, exists := m.templates[arn] - if !exists { - return nil - } - return tmpl -} - -func (m *PolicyManager) getPolicyTemplate(arn string) (tmpl *template.Template, err error) { - tmpl = m.getPolicyTemplateFromCache(arn) - if tmpl != nil { - return - } - policy, err := m.retriever.retrievePolicyStr(arn) - if err != nil { - return nil, err - - } - funcMap := template.FuncMap{ - "YYYYmmdd": YYYYmmdd, - "Now": Now, - "Add1Day": Add1Day, - "SHA1": sha1sum, - "YYYYmmddSlashed": YYYYmmddSlashed, - } - tmpl, err = template.New(arn).Funcs(funcMap).Parse(policy) - if err == nil { - m.tMux.Lock() - defer m.tMux.Unlock() - m.templates[arn] = tmpl - } else { - return nil, err - } - return -} - - -type PolicySessionClaims struct { - Subject string - Issuer string -} - - -//This is the structure that will be made available during templating and -//thus is available to be used in policies. -type PolicySessionData struct { - Claims PolicySessionClaims - Tags AWSSessionTags - RequestedRegion string -} - -func GetPolicySessionDataFromClaims(claims *SessionClaims) *PolicySessionData { - issuer := claims.IIssuer - if issuer == "" { - issuer = claims.Issuer - } - return &PolicySessionData{ - Claims: PolicySessionClaims{ - Subject: claims.Subject, - Issuer: issuer, - }, - Tags: claims.Tags, - } -} - - -func (m *PolicyManager) GetPolicy(arn string, data *PolicySessionData) (string, error) { - tmpl, err := m.getPolicyTemplate(arn) - if err != nil { - return "", err - } - buf := new(bytes.Buffer) - err = tmpl.Execute(buf, data) - if err != nil { - return "", err - } - return buf.String(), nil -} - -func NewPolicyManager(r PolicyRetriever) *PolicyManager{ - return &PolicyManager{ - retriever: r, - templates: map[string]*template.Template{}, - tMux: &sync.RWMutex{}, - } -} \ No newline at end of file diff --git a/cmd/policy-retrieval.go b/cmd/policy-retrieval.go new file mode 100644 index 0000000..3dd2482 --- /dev/null +++ b/cmd/policy-retrieval.go @@ -0,0 +1,334 @@ +package cmd + +import ( + "bytes" + "fmt" + "log/slog" + "path/filepath" + "sync" + "text/template" + + "github.com/fsnotify/fsnotify" +) + +const PathSeparator = "/" +const policySuffix = ".json.tmpl" + + +type LocalPolicyRetriever struct{ + rolePolicyPath string + + //To communicate cache invalidation. + pm *PolicyManager + + //To monitor file system changes + watcher *fsnotify.Watcher +} + +func NewLocalPolicyRetriever(stsRolePolicyPath string) *LocalPolicyRetriever { + var lp *LocalPolicyRetriever + + var fileDeleted fileCallback = func(fileName string) { + if lp.pm == nil { + slog.Warn("There was no Policy Manager for local retriever to handle file deletion", "retriever", lp) + } else { + arn, err := lp.getPolicyArn(fileName) + if err != nil { + slog.Error("Could not get arn", "filename", fileName) + } + slog.Info("Remove policy", "arn", arn) + lp.pm.deletePolicyCacheEntry(arn) + } + } + + var fileUpdated fileCallback = func(fileName string) { + if lp.pm == nil { + slog.Warn("There was no Policy Manager for local retriever to handle file update", "retriever", lp) + } else { + arn, err := lp.getPolicyArn(fileName) + if err != nil { + slog.Error("Could not get arn", "filename", fileName) + } + slog.Info("Reload policy", "arn", arn) + lp.pm.deletePolicyCacheEntry(arn) + _, err = lp.pm.getPolicyTemplate(arn) + if err != nil { + slog.Warn("Could not get policy", "policyArn", arn) + } + } + } + + watcher := createFileWatcherAndStartWatching(fileUpdated, fileDeleted) + lp = &LocalPolicyRetriever{ + rolePolicyPath: stsRolePolicyPath, + watcher: watcher, + } + + return lp +} + +func (r *LocalPolicyRetriever) getPolicyPathPrefix() (string) { + return fmt.Sprintf("%s%s", r.rolePolicyPath, PathSeparator) +} + +func (r *LocalPolicyRetriever) getPolicyPath(arn string) (string) { + safeRoleArn := b32(arn) + return fmt.Sprintf("%s%s%s", r.getPolicyPathPrefix(), safeRoleArn, policySuffix) +} + +func (r *LocalPolicyRetriever) getPolicyArn(filePath string) (string, error) { + prefix := r.getPolicyPathPrefix() + suffix := policySuffix + + if len(suffix) > len(filePath) || len(prefix) > len(filePath) - len(suffix) { + slog.Warn("Invalid file path for policy", "filepath", filePath) + } + + safePolicyName := filePath[len(prefix):len(filePath) - len(suffix)] + policyArn, err := b32_decode(safePolicyName) + if err != nil { + return "", err + } + return policyArn, nil +} + +func (r LocalPolicyRetriever) retrieveAllIdentifiers() ([]string, error) { + prefix := r.getPolicyPathPrefix() + suffix := policySuffix + matches, err := filepath.Glob(fmt.Sprintf("%s*%s", prefix , suffix)) + if err != nil { + return nil, err + } + cleanedMatches := make([]string, len(matches)) + for i, match := range matches { + safePolicyName := match[len(prefix):len(match) - len(suffix)] + cleanedMatches[i], err = b32_decode(safePolicyName) + if err != nil { + return nil, err + } + } + return cleanedMatches, err +} + +func (r *LocalPolicyRetriever) retrievePolicyStr(arn string) (string, error) { + filePath := r.getPolicyPath(arn) + startWatching(r.watcher, filePath) // For cache invalidation + c, err := readFileFull(filePath) + if err != nil { + return "", err + } + return string(c), err +} + +func (r *LocalPolicyRetriever) registerPolicyManager(pm *PolicyManager) { + r.pm = pm +} + +type PolicyRetriever interface { + //Retrieve the policy content based out of an identifier which can be an AWS ARN + retrievePolicyStr(string) (string, error) + + //Get all policy identifiers + retrieveAllIdentifiers() ([]string, error) + + //Set PolicyManager + //Each policy retriever can be used by 1 policy Manager when the policy manager gets + //created with a policy retriever it will register itself using this method this allows + //The retriever to do calls to the policy manager for example to communicate policy changes + registerPolicyManager(pm *PolicyManager) +} + +type PolicyManager struct { + retriever PolicyRetriever + templates map[string]*template.Template + //Mutex for local template access + tMux *sync.RWMutex +} + +//Check if a policy manager can get a policy corresponding to an ARN +func (m *PolicyManager) DoesPolicyExist(arn string) bool { + + _, err := m.getPolicyTemplate(arn) + return err == nil +} + +//Check if a policy manager can get a policy corresponding to an ARN +func (m *PolicyManager) PreWarm() error { + ids, err := m.retriever.retrieveAllIdentifiers() + if err != nil { + return err + } + for _, policyId := range ids { + _, err := m.getPolicyTemplate(policyId) + if err != nil{ + return err + } + } + return nil +} + +//Get template from local cache and nil if it does not exist +func (m *PolicyManager) getPolicyTemplateFromCache(arn string) (tmpl *template.Template) { + m.tMux.RLock() + defer m.tMux.RUnlock() + tmpl, exists := m.templates[arn] + if !exists { + return nil + } + return tmpl +} + +func (m *PolicyManager) getPolicyTemplate(arn string) (tmpl *template.Template, err error) { + tmpl = m.getPolicyTemplateFromCache(arn) + if tmpl != nil { + return + } + policy, err := m.retriever.retrievePolicyStr(arn) + if err != nil { + return nil, err + + } + funcMap := template.FuncMap{ + "YYYYmmdd": YYYYmmdd, + "Now": Now, + "Add1Day": Add1Day, + "SHA1": sha1sum, + "YYYYmmddSlashed": YYYYmmddSlashed, + } + tmpl, err = template.New(arn).Funcs(funcMap).Parse(policy) + if err == nil { + m.tMux.Lock() + defer m.tMux.Unlock() + m.templates[arn] = tmpl + } else { + return nil, err + } + return +} + + +type PolicySessionClaims struct { + Subject string + Issuer string +} + + +//This is the structure that will be made available during templating and +//thus is available to be used in policies. +type PolicySessionData struct { + Claims PolicySessionClaims + Tags AWSSessionTags + RequestedRegion string +} + +func GetPolicySessionDataFromClaims(claims *SessionClaims) *PolicySessionData { + issuer := claims.IIssuer + if issuer == "" { + issuer = claims.Issuer + } + return &PolicySessionData{ + Claims: PolicySessionClaims{ + Subject: claims.Subject, + Issuer: issuer, + }, + Tags: claims.Tags, + } +} + + +func (m *PolicyManager) GetPolicy(arn string, data *PolicySessionData) (string, error) { + tmpl, err := m.getPolicyTemplate(arn) + if err != nil { + return "", err + } + buf := new(bytes.Buffer) + err = tmpl.Execute(buf, data) + if err != nil { + return "", err + } + return buf.String(), nil +} + +func (m *PolicyManager) deletePolicyCacheEntry(arn string) { + m.tMux.Lock() + defer m.tMux.Unlock() + _, exists := m.templates[arn] + if !exists { + return + } else { + delete(m.templates, arn) + } +} + +func NewPolicyManager(r PolicyRetriever) *PolicyManager{ + pm := &PolicyManager{ + retriever: r, + templates: map[string]*template.Template{}, + tMux: &sync.RWMutex{}, + } + r.registerPolicyManager(pm) + return pm +} + +//A callback function that takes a filepath to action a change to a file. +type fileCallback func(string) () + + +//Start a watcher to keep an eye on files +// +//This will start watching later on +func createFileWatcherAndStartWatching(fileChanged, fileDeleted fileCallback) (*fsnotify.Watcher) { + //See https://github.com/fsnotify/fsnotify + watcher, err := fsnotify.NewWatcher() + if err != nil { + slog.Error("Could not create new watcher", "error", err) + } + + // Start listening for events. + go func() { + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + slog.Debug("Config watcher event", "event", event) + if event.Has(fsnotify.Write) { + slog.Debug("Write notification", "event", event) + fileChanged(event.Name) + } + if event.Has(fsnotify.Remove) { + slog.Debug("Deletion notification", "event", event) + fileDeleted(event.Name) + // See https://ahmet.im/blog/kubernetes-inotify/ + restartWatching(watcher, event.Name) + } + case err, ok := <-watcher.Errors: + if !ok { + return + } + slog.Warn("error with file watcher", "error", err) + } + } + }() + return watcher +} + +func startWatching(watcher *fsnotify.Watcher, fileName string) { + err := watcher.Add(fileName) + if err != nil { + slog.Error("Could not add watcher", "filename", fileName, "error", err) + } else { + slog.Debug("Started watching file", "filename", fileName) + } +} + +func restartWatching(watcher *fsnotify.Watcher, fileName string) { + err := watcher.Remove(fileName) + if err != nil { + slog.Debug("Wanted to stop watching file but watcher was gone", "filename", fileName) + } else { + slog.Debug("Stopped watching file", "filename", fileName) + } + startWatching(watcher, fileName) +} \ No newline at end of file diff --git a/cmd/policy-generation_test.go b/cmd/policy-retrieval_test.go similarity index 50% rename from cmd/policy-generation_test.go rename to cmd/policy-retrieval_test.go index f872a05..6c0cfff 100644 --- a/cmd/policy-generation_test.go +++ b/cmd/policy-retrieval_test.go @@ -2,6 +2,7 @@ package cmd import ( "fmt" + "os" "strings" "testing" "time" @@ -56,6 +57,10 @@ func (r TestPolicyRetriever) retrievePolicyStr(arn string) (string, error) { return policy, nil } +func (r TestPolicyRetriever) registerPolicyManager(pm *PolicyManager) { + //Cache invalidation is not a thing for testpolicy retriever so no need to keep PolicyManager +} + func (r TestPolicyRetriever) retrieveAllIdentifiers() ([]string, error) { keys := make([]string, len(r.testPolicies)) @@ -172,4 +177,103 @@ func TestPolicyManagerPrewarm(t *testing.T) { if !pm.DoesPolicyExist(expectedPolicy) { t.Errorf("Missing policy %s", expectedPolicy) } -} \ No newline at end of file +} + + +func createTestPolicyFileForLocalPolicyRetriever(policyArn, policyContent string, pr *LocalPolicyRetriever, t *testing.T) { + policyFileName := pr.getPolicyPath(policyArn) + f, err := os.Create(policyFileName) + checkErrorTestDependency(err, t, fmt.Sprintf("Could Not create policy file %s", policyFileName)) + + _, err = f.Write([]byte(policyContent)) + checkErrorTestDependency(err, t, fmt.Sprintf("Could not write policy content while creating test policy %s: %s", policyArn, policyContent)) + + defer f.Close() +} + +func deleteTestPolicyFileForLocalPolicyRetriever(policyArn string, pr *LocalPolicyRetriever, t *testing.T) { + policyFileName := pr.getPolicyPath(policyArn) + err := os.Remove(policyFileName) + checkErrorTestDependency(err, t, fmt.Sprintf("Could not delete policy file %s", policyFileName)) +} + + +func TestCacheInvalidationLocalPolicyRetrieverIfPolicyIsRemoved(t *testing.T) { + //Given a policy manager that is backed by a local PolicyRetriever + pr := NewLocalPolicyRetriever(t.TempDir()) + pm := NewPolicyManager(pr) + //Given a test Arn + testArn := "arn:aws:iam::000000000000:role/cache-invalidation" + + //WHEN the policy file exists in the expected place + createTestPolicyFileForLocalPolicyRetriever(testArn, testPolicyAllowAll, pr, t) + //THEN it must exist as per the Policy Manager + if !pm.DoesPolicyExist(testArn) { + t.Errorf("Policy %s should have existed but it did not", testArn) + t.FailNow() + } + + //WHEN the policyFile gets deleted + deleteTestPolicyFileForLocalPolicyRetriever(testArn, pr, t) + deletionTime := time.Now() + + var policyManagerKnowsPolicyDoesNotExist predicateFunction = func () bool{ + return !pm.DoesPolicyExist(testArn) + } + + //THEN in due time it should no longer exist + if !isTrueWithinDueTime(policyManagerKnowsPolicyDoesNotExist) { + t.Errorf("Policy %s was removed at %s and now %s policy manager still thinks it exists", testArn, deletionTime, time.Now()) + t.FailNow() + } +} + + +func TestCacheInvalidationLocalPolicyRetrieverIfPolicyIsChanged(t *testing.T) { + //Given a policy manager that is backed by a local PolicyRetriever + pr := NewLocalPolicyRetriever(t.TempDir()) + pm := NewPolicyManager(pr) + //Given 2 test Arn + testArn1 := "arn:aws:iam::000000000000:role/cache-invalidation1" + testArn2 := "arn:aws:iam::000000000000:role/cache-invalidation2" + + //WHEN the policy files exists in the expected place and are policies without time conditions + createTestPolicyFileForLocalPolicyRetriever(testArn1, testPolicyAllowAll, pr, t) + createTestPolicyFileForLocalPolicyRetriever(testArn2, testPolicyAllowAllInRegion1, pr, t) + + //THEN the templates actually differ + pol1, err1 := pm.GetPolicy(testArn1, &PolicySessionData{}) + checkErrorTestDependency(err1, t, "Policy1 should have been retrievable") + pol2, err2 := pm.GetPolicy(testArn2, &PolicySessionData{}) + checkErrorTestDependency(err2, t, "Policy2 should have been retrievable") + + if pol1 == pol2 { + t.Errorf("Policies should have been different but both gave: %s", pol1) + t.FailNow() + } + + //WHEN the 2nd policy gets updated such that it has the same content as the first. + deleteTestPolicyFileForLocalPolicyRetriever(testArn2, pr, t) + createTestPolicyFileForLocalPolicyRetriever(testArn2, testPolicyAllowAll, pr, t) + + updateTime := time.Now() + + var policyManagerSeesUpdate predicateFunction = func () bool{ + pol1, err1 := pm.GetPolicy(testArn1, &PolicySessionData{}) + checkErrorTestDependency(err1, t, "Policy1 should have been retrievable") + pol2, err2 := pm.GetPolicy(testArn2, &PolicySessionData{}) + checkErrorTestDependency(err2, t, "Policy2 should have been retrievable") + + return pol1 == pol2 + } + + //THEN in due time it should no longer exist + if !isTrueWithinDueTime(policyManagerSeesUpdate) { + polText, err := pm.GetPolicy(testArn2, &PolicySessionData{}) + if err != nil { + polText = err.Error() + } + t.Errorf("Policy %s was updated at %s and now %s policy manager still sees %s", testArn2, updateTime, time.Now(), polText) + t.FailNow() + } +} diff --git a/cmd/test-utils.go b/cmd/test-utils.go index add48af..eb94cd2 100644 --- a/cmd/test-utils.go +++ b/cmd/test-utils.go @@ -3,7 +3,9 @@ package cmd import ( "encoding/json" "os" + "strings" "testing" + "time" ) @@ -26,4 +28,48 @@ func skipIfNoTestingBackends(t *testing.T) { if os.Getenv("NO_TESTING_BACKENDS") != "" { t.Skip("Skipping this test because no testing backends and that is a dependency for thist test.") } +} + +//checkErrorTestDependency check for errors to pracitce safe programming but where you do not really +//expect problems (but cannot guarantee them not happening e.g. because of the execution environment). +//This is only to be used in test cases and will fail the test you can use msg to pass extra context info +func checkErrorTestDependency(err error, t *testing.T, msg ...string) { + var strMsg string + if len(msg) > 0 { + strMsg = strings.Join(msg, ", ") + } + if err != nil { + t.Errorf("Encountered error %s which should not occure. %s", err, strMsg) + t.FailNow() + } +} + +type predicateFunction func() bool + +//isTrueWithinDueTime takes a function that takes no arguments but returns a boolean +//and will await for maximum the first waitTime (which defaults to 5 seconds) and will +//check every second waitTime (defaults to 10 milliseconds) +func isTrueWithinDueTime(callable predicateFunction, waitTimes ...time.Duration) bool { + var maxWaitTime time.Duration = 5 * time.Second + var waitTimeBetweenChecks time.Duration = 10 * time.Millisecond + + if len(waitTimes) > 0 { + maxWaitTime = waitTimes[0] + } + giveUpTime := time.Now().Add(maxWaitTime) + + if len(waitTimes) > 1 { + waitTimeBetweenChecks = waitTimes[1] + } + + for { //infinite loop + if callable() { + return true + } + if time.Now().After(giveUpTime) { + return false // time to give up + } + time.Sleep(waitTimeBetweenChecks) + } + } \ No newline at end of file