Skip to content

Commit

Permalink
refactor to add documentation, clarify variable names, add test cases…
Browse files Browse the repository at this point in the history
…, and better encapsulate behaviors (among other things to simplify testing)
  • Loading branch information
claytondaley committed Apr 29, 2019
1 parent 4cfb11f commit b0f599d
Show file tree
Hide file tree
Showing 5 changed files with 302 additions and 90 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
build
dist
*.egg-info
.idea
47 changes: 24 additions & 23 deletions ssm-diff
Original file line number Diff line number Diff line change
@@ -1,52 +1,53 @@
#!/usr/bin/env python
from __future__ import print_function
from states import *
import states.helpers as helpers

import argparse
import os

from states import states
from states.helpers import DiffResolver


def configure_endpoints(args):
# pre-configure resolver, but still accept remote and local at runtime
diff_resolver = DiffResolver.configure(force=args.force)
return states.RemoteState(args.profile, diff_resolver, paths=args.path), states.LocalState(args.filename, paths=args.path)


def init(args):
r, l = RemoteState(args.profile), LocalState(args.filename)
l.save(r.get(flat=False, paths=args.path))
"""Create a local YAML file from the SSM Parameter Store (per configs in args)"""
remote, local = configure_endpoints(args)
local.save(remote.clone())


def pull(args):
dictfilter = lambda x, y: dict([ (i,x[i]) for i in x if i in set(y) ])
r, l = RemoteState(args.profile), LocalState(args.filename)
diff = helpers.FlatDictDiffer(r.get(paths=args.path), l.get(paths=args.path))
if args.force:
ref_set = diff.changed().union(diff.removed()).union(diff.unchanged())
target_set = diff.added()
else:
ref_set = diff.unchanged().union(diff.removed())
target_set = diff.added().union(diff.changed())
state = dictfilter(diff.ref, ref_set)
state.update(dictfilter(diff.target, target_set))
l.save(helpers.unflatten(state))
"""Update local YAML file with changes in the SSM Parameter Store (per configs in args)"""
remote, local = configure_endpoints(args)
local.save(remote.pull(local.get()))


def apply(args):
r, _, diff = plan(args)

"""Apply local changes to the SSM Parameter Store"""
remote, local = configure_endpoints(args)
print("\nApplying changes...")
try:
r.apply(diff)
remote.push(local.get())
except Exception as e:
print("Failed to apply changes to remote:", e)
print("Done.")


def plan(args):
r, l = RemoteState(args.profile), LocalState(args.filename)
diff = helpers.FlatDictDiffer(r.get(paths=args.path), l.get(paths=args.path))
"""Print a representation of the changes that would be applied to SSM Parameter Store if applied (per config in args)"""
remote, local = configure_endpoints(args)
diff = remote.dry_run(local.get())

if diff.differ:
diff.print_state()
print(diff.describe_diff())
else:
print("Remote state is up to date.")

return r, l, diff
return remote, local, diff


if __name__ == "__main__":
Expand Down
116 changes: 76 additions & 40 deletions states/helpers.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,101 @@
from termcolor import colored
from copy import deepcopy
import collections
from copy import deepcopy
from functools import partial

from termcolor import colored

class FlatDictDiffer(object):
def __init__(self, ref, target):
self.ref, self.target = ref, target
self.ref_set, self.target_set = set(ref.keys()), set(target.keys())
self.isect = self.ref_set.intersection(self.target_set)

class DiffResolver(object):
"""Determines diffs between two dicts, where the remote copy is considered the baseline"""
def __init__(self, remote, local, force=False):
self.remote_flat, self.local_flat = self._flatten(remote), self._flatten(local)
self.remote_set, self.local_set = set(self.remote_flat.keys()), set(self.local_flat.keys())
self.intersection = self.remote_set.intersection(self.local_set)
self.force = force

if self.added() or self.removed() or self.changed():
self.differ = True
else:
self.differ = False

@classmethod
def configure(cls, *args, **kwargs):
return partial(cls, *args, **kwargs)

def added(self):
return self.target_set - self.isect
"""Returns a (flattened) dict of added leaves i.e. {"full/path": value, ...}"""
return self.local_set - self.intersection

def removed(self):
return self.ref_set - self.isect
"""Returns a (flattened) dict of removed leaves i.e. {"full/path": value, ...}"""
return self.remote_set - self.intersection

def changed(self):
return set(k for k in self.isect if self.ref[k] != self.target[k])
"""Returns a (flattened) dict of changed leaves i.e. {"full/path": value, ...}"""
return set(k for k in self.intersection if self.remote_flat[k] != self.local_flat[k])

def unchanged(self):
return set(k for k in self.isect if self.ref[k] == self.target[k])
"""Returns a (flattened) dict of unchanged leaves i.e. {"full/path": value, ...}"""
return set(k for k in self.intersection if self.remote_flat[k] == self.local_flat[k])

def print_state(self):
def describe_diff(self):
"""Return a (multi-line) string describing all differences"""
description = ""
for k in self.added():
print(colored("+", 'green'), "{} = {}".format(k, self.target[k]))
description += colored("+", 'green'), "{} = {}".format(k, self.local_flat[k]) + '\n'

for k in self.removed():
print(colored("-", 'red'), k)
description += colored("-", 'red'), k + '\n'

for k in self.changed():
print(colored("~", 'yellow'), "{}:\n\t< {}\n\t> {}".format(k, self.ref[k], self.target[k]))


def flatten(d, pkey='', sep='/'):
items = []
for k in d:
new = pkey + sep + k if pkey else k
if isinstance(d[k], collections.MutableMapping):
items.extend(flatten(d[k], new, sep=sep).items())
description += colored("~", 'yellow'), "{}:\n\t< {}\n\t> {}".format(k, self.remote_flat[k], self.local_flat[k]) + '\n'

return description

def _flatten(self, d, current_path='', sep='/'):
"""Convert a nested dict structure into a "flattened" dict i.e. {"full/path": "value", ...}"""
items = []
for k in d:
new = current_path + sep + k if current_path else k
if isinstance(d[k], collections.MutableMapping):
items.extend(self._flatten(d[k], new, sep=sep).items())
else:
items.append((sep + new, d[k]))
return dict(items)

def _unflatten(self, d, sep='/'):
"""Converts a "flattened" dict i.e. {"full/path": "value", ...} into a nested dict structure"""
output = {}
for k in d:
add(
obj=output,
path=k,
value=d[k],
sep=sep,
)
return output

def merge(self):
"""Generate a merge of the local and remote dicts, following configurations set during __init__"""
dictfilter = lambda original, keep_keys: dict([(i, original[i]) for i in original if i in set(keep_keys)])
if self.force:
# Overwrite local changes (i.e. only preserve added keys)
# NOTE: Currently the system cannot tell the difference between a remote delete and a local add
prior_set = self.changed().union(self.removed()).union(self.unchanged())
current_set = self.added()
else:
items.append((sep + new, d[k]))
return dict(items)


def add(obj, path, value):
parts = path.strip("/").split("/")
# Preserve added keys and changed keys
# NOTE: Currently the system cannot tell the difference between a remote delete and a local add
prior_set = self.unchanged().union(self.removed())
current_set = self.added().union(self.changed())
state = dictfilter(original=self.remote_flat, keep_keys=prior_set)
state.update(dictfilter(original=self.local_flat, keep_keys=current_set))
return self._unflatten(state)


def add(obj, path, value, sep='/'):
"""Add value to the `obj` dict at the specified path"""
parts = path.strip(sep).split(sep)
last = len(parts) - 1
for index, part in enumerate(parts):
if index == last:
Expand All @@ -61,7 +107,7 @@ def add(obj, path, value):
def search(state, path):
result = state
for p in path.strip("/").split("/"):
if result.get(p):
if result.clone(p):
result = result[p]
else:
result = {}
Expand All @@ -71,16 +117,6 @@ def search(state, path):
return output


def unflatten(d):
output = {}
for k in d:
add(
obj=output,
path=k,
value=d[k])
return output


def merge(a, b):
if not isinstance(b, dict):
return b
Expand Down
78 changes: 51 additions & 27 deletions states/states.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import print_function
from botocore.exceptions import ClientError, NoCredentialsError
from .helpers import flatten, merge, add, search

import sys
import os
import yaml

import boto3
import termcolor
import yaml
from botocore.exceptions import ClientError, NoCredentialsError

from .helpers import merge, add, search


def str_presenter(dumper, data):
if len(data.splitlines()) == 1 and data[-1] == '\n':
Expand All @@ -17,8 +20,10 @@ def str_presenter(dumper, data):
return dumper.represent_scalar(
'tag:yaml.org,2002:str', data.strip())


yaml.SafeDumper.add_representer(str, str_presenter)


class SecureTag(yaml.YAMLObject):
yaml_tag = u'!secure'

Expand All @@ -38,7 +43,7 @@ def __hash__(self):
return hash(self.secure)

def __ne__(self, other):
return (not self.__eq__(other))
return not self.__eq__(other)

@classmethod
def from_yaml(cls, loader, node):
Expand All @@ -50,25 +55,28 @@ def to_yaml(cls, dumper, data):
return dumper.represent_scalar(cls.yaml_tag, data.secure, style='|')
return dumper.represent_scalar(cls.yaml_tag, data.secure)


yaml.SafeLoader.add_constructor('!secure', SecureTag.from_yaml)
yaml.SafeDumper.add_multi_representer(SecureTag, SecureTag.to_yaml)


class LocalState(object):
def __init__(self, filename):
"""Encodes/decodes a dictionary to/from a YAML file"""
def __init__(self, filename, paths=('/',)):
self.filename = filename
self.paths = paths

def get(self, paths, flat=True):
def get(self):
try:
output = {}
with open(self.filename,'rb') as f:
l = yaml.safe_load(f.read())
for path in paths:
with open(self.filename, 'rb') as f:
local = yaml.safe_load(f.read())
for path in self.paths:
if path.strip('/'):
output = merge(output, search(l, path))
output = merge(output, search(local, path))
else:
return flatten(l) if flat else l
return flatten(output) if flat else output
return local
return output
except IOError as e:
print(e, file=sys.stderr)
if e.errno == 2:
Expand All @@ -90,53 +98,69 @@ def save(self, state):


class RemoteState(object):
def __init__(self, profile):
"""Encodes/decodes a dict to/from the SSM Parameter Store"""
def __init__(self, profile, diff_class, paths=('/',)):
if profile:
boto3.setup_default_session(profile_name=profile)
self.ssm = boto3.client('ssm')
self.diff_class = diff_class
self.paths = paths

def get(self, paths=['/'], flat=True):
def clone(self):
p = self.ssm.get_paginator('get_parameters_by_path')
output = {}
for path in paths:
for path in self.paths:
try:
for page in p.paginate(
Path=path,
Recursive=True,
WithDecryption=True):
Path=path,
Recursive=True,
WithDecryption=True):
for param in page['Parameters']:
add(obj=output,
path=param['Name'],
value=self._read_param(param['Value'], param['Type']))
except (ClientError, NoCredentialsError) as e:
print("Failed to fetch parameters from SSM!", e, file=sys.stderr)

return flatten(output) if flat else output
return output

# noinspection PyMethodMayBeStatic
def _read_param(self, value, ssm_type='String'):
return SecureTag(value) if ssm_type == 'SecureString' else str(value)

def apply(self, diff):
def pull(self, local):
diff = self.diff_class(
remote=self.clone(),
local=local,
force=self.force
)
return diff.merge()

def dry_run(self, local):
return self.diff_class(self.clone(), local, force=self.force)

def push(self, local):
diff = self.dry_run(local)

# diff.added|removed|changed return a "flattened" dict i.e. {"full/path": "value", ...}
for k in diff.added():
ssm_type = 'String'
if isinstance(diff.target[k], list):
if isinstance(diff.local[k], list):
ssm_type = 'StringList'
if isinstance(diff.target[k], SecureTag):
if isinstance(diff.local[k], SecureTag):
ssm_type = 'SecureString'
self.ssm.put_parameter(
Name=k,
Value=repr(diff.target[k]) if type(diff.target[k]) == SecureTag else str(diff.target[k]),
Value=repr(diff.local[k]) if type(diff.local[k]) == SecureTag else str(diff.local[k]),
Type=ssm_type)

for k in diff.removed():
self.ssm.delete_parameter(Name=k)

for k in diff.changed():
ssm_type = 'SecureString' if isinstance(diff.target[k], SecureTag) else 'String'

ssm_type = 'SecureString' if isinstance(diff.local[k], SecureTag) else 'String'
self.ssm.put_parameter(
Name=k,
Value=repr(diff.target[k]) if type(diff.target[k]) == SecureTag else str(diff.target[k]),
Value=repr(diff.local[k]) if type(diff.local[k]) == SecureTag else str(diff.local[k]),
Overwrite=True,
Type=ssm_type)
Loading

0 comments on commit b0f599d

Please sign in to comment.