Skip to content

Commit

Permalink
feature: localPolicyRetriever cache invalidation [#15]
Browse files Browse the repository at this point in the history
  • Loading branch information
Peter Van Bouwel committed Dec 17, 2024
1 parent 618c3ff commit fd4b2b6
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 3 deletions.
155 changes: 152 additions & 3 deletions cmd/policy-retrieval.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package cmd
import (
"bytes"
"fmt"
"log/slog"
"path/filepath"
"sync"
"text/template"

"github.com/fsnotify/fsnotify"
)

const PathSeparator = "/"
Expand All @@ -14,12 +17,54 @@ const policySuffix = ".json.tmpl"

type LocalPolicyRetriever struct{
rolePolicyPath string

//To communicate cache invalidation.
pm *PolicyManager

//To monitor file system changes
watcher *fsnotify.Watcher
}

func NewLocalPolicyRetriever(stsRolePolicyPath string) *LocalPolicyRetriever {
return &LocalPolicyRetriever{
var lp *LocalPolicyRetriever

var fileDeleted fileCallback = func(fileName string) {
if lp.pm == nil {
slog.Warn("There was no Policy Manager for local retriever to handle file deletion", "retriever", lp)
} else {
arn, err := lp.getPolicyArn(fileName)
if err != nil {
slog.Error("Could not get arn", "filename", fileName)
}
slog.Info("Remove policy", "arn", arn)
lp.pm.deletePolicyCacheEntry(arn)
}
}

var fileUpdated fileCallback = func(fileName string) {
if lp.pm == nil {
slog.Warn("There was no Policy Manager for local retriever to handle file update", "retriever", lp)
} else {
arn, err := lp.getPolicyArn(fileName)
if err != nil {
slog.Error("Could not get arn", "filename", fileName)
}
slog.Info("Reload policy", "arn", arn)
lp.pm.deletePolicyCacheEntry(arn)
_, err = lp.pm.getPolicyTemplate(arn)
if err != nil {
slog.Warn("Could not get policy", "policyArn", arn)
}
}
}

watcher := createFileWatcherAndStartWatching(fileUpdated, fileDeleted)
lp = &LocalPolicyRetriever{
rolePolicyPath: stsRolePolicyPath,
watcher: watcher,
}

return lp
}

func (r *LocalPolicyRetriever) getPolicyPathPrefix() (string) {
Expand All @@ -31,6 +76,22 @@ func (r *LocalPolicyRetriever) getPolicyPath(arn string) (string) {
return fmt.Sprintf("%s%s%s", r.getPolicyPathPrefix(), safeRoleArn, policySuffix)
}

func (r *LocalPolicyRetriever) getPolicyArn(filePath string) (string, error) {
prefix := r.getPolicyPathPrefix()
suffix := policySuffix

if len(suffix) > len(filePath) || len(prefix) > len(filePath) - len(suffix) {
slog.Warn("Invalid file path for policy", "filepath", filePath)
}

safePolicyName := filePath[len(prefix):len(filePath) - len(suffix)]
policyArn, err := b32_decode(safePolicyName)
if err != nil {
return "", err
}
return policyArn, nil
}

func (r LocalPolicyRetriever) retrieveAllIdentifiers() ([]string, error) {
prefix := r.getPolicyPathPrefix()
suffix := policySuffix
Expand All @@ -50,19 +111,31 @@ func (r LocalPolicyRetriever) retrieveAllIdentifiers() ([]string, error) {
}

func (r *LocalPolicyRetriever) retrievePolicyStr(arn string) (string, error) {
c, err := readFileFull(r.getPolicyPath(arn))
filePath := r.getPolicyPath(arn)
startWatching(r.watcher, filePath) // For cache invalidation
c, err := readFileFull(filePath)
if err != nil {
return "", err
}
return string(c), err
}

func (r *LocalPolicyRetriever) registerPolicyManager(pm *PolicyManager) {
r.pm = pm
}

type PolicyRetriever interface {
//Retrieve the policy content based out of an identifier which can be an AWS ARN
retrievePolicyStr(string) (string, error)

//Get all policy identifiers
retrieveAllIdentifiers() ([]string, error)

//Set PolicyManager
//Each policy retriever can be used by 1 policy Manager when the policy manager gets
//created with a policy retriever it will register itself using this method this allows
//The retriever to do calls to the policy manager for example to communicate policy changes
registerPolicyManager(pm *PolicyManager)
}

type PolicyManager struct {
Expand Down Expand Up @@ -176,10 +249,86 @@ func (m *PolicyManager) GetPolicy(arn string, data *PolicySessionData) (string,
return buf.String(), nil
}

func (m *PolicyManager) deletePolicyCacheEntry(arn string) {
m.tMux.Lock()
defer m.tMux.Unlock()
_, exists := m.templates[arn]
if !exists {
return
} else {
delete(m.templates, arn)
}
}

func NewPolicyManager(r PolicyRetriever) *PolicyManager{
return &PolicyManager{
pm := &PolicyManager{
retriever: r,
templates: map[string]*template.Template{},
tMux: &sync.RWMutex{},
}
r.registerPolicyManager(pm)
return pm
}

//A callback function that takes a filepath to action a change to a file.
type fileCallback func(string) ()


//Start a watcher to keep an eye on files
//
//This will start watching later on
func createFileWatcherAndStartWatching(fileChanged, fileDeleted fileCallback) (*fsnotify.Watcher) {
//See https://github.com/fsnotify/fsnotify
watcher, err := fsnotify.NewWatcher()
if err != nil {
slog.Error("Could not create new watcher", "error", err)
}

// Start listening for events.
go func() {
for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
}
slog.Debug("Config watcher event", "event", event)
if event.Has(fsnotify.Write) {
slog.Debug("Write notification", "event", event)
fileChanged(event.Name)
}
if event.Has(fsnotify.Remove) {
slog.Debug("Deletion notification", "event", event)
fileDeleted(event.Name)
// See https://ahmet.im/blog/kubernetes-inotify/
restartWatching(watcher, event.Name)
}
case err, ok := <-watcher.Errors:
if !ok {
return
}
slog.Warn("error with file watcher", "error", err)
}
}
}()
return watcher
}

func startWatching(watcher *fsnotify.Watcher, fileName string) {
err := watcher.Add(fileName)
if err != nil {
slog.Error("Could not add watcher", "filename", fileName, "error", err)
} else {
slog.Debug("Started watching file", "filename", fileName)
}
}

func restartWatching(watcher *fsnotify.Watcher, fileName string) {
err := watcher.Remove(fileName)
if err != nil {
slog.Debug("Wanted to stop watching file but watcher was gone", "filename", fileName)
} else {
slog.Debug("Stopped watching file", "filename", fileName)
}
startWatching(watcher, fileName)
}
4 changes: 4 additions & 0 deletions cmd/policy-retrieval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ func (r TestPolicyRetriever) retrievePolicyStr(arn string) (string, error) {
return policy, nil
}

func (r TestPolicyRetriever) registerPolicyManager(pm *PolicyManager) {
//Cache invalidation is not a thing for testpolicy retriever so no need to keep PolicyManager
}

func (r TestPolicyRetriever) retrieveAllIdentifiers() ([]string, error) {
keys := make([]string, len(r.testPolicies))

Expand Down

0 comments on commit fd4b2b6

Please sign in to comment.