From b0f599df4f031e497a0b654b97ba8acfb70547bc Mon Sep 17 00:00:00 2001 From: Clayton Daley Date: Mon, 29 Apr 2019 09:52:35 -0500 Subject: [PATCH] refactor to add documentation, clarify variable names, add test cases, and better encapsulate behaviors (among other things to simplify testing) --- .gitignore | 1 + ssm-diff | 47 ++++++++------- states/helpers.py | 116 ++++++++++++++++++++++------------- states/states.py | 78 +++++++++++++++--------- states/tests.py | 150 ++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 302 insertions(+), 90 deletions(-) create mode 100644 states/tests.py diff --git a/.gitignore b/.gitignore index fd133da..5d4635a 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ build dist *.egg-info +.idea \ No newline at end of file diff --git a/ssm-diff b/ssm-diff index 73ebca1..ffad441 100755 --- a/ssm-diff +++ b/ssm-diff @@ -1,52 +1,53 @@ #!/usr/bin/env python from __future__ import print_function -from states import * -import states.helpers as helpers + import argparse import os +from states import states +from states.helpers import DiffResolver + + +def configure_endpoints(args): + # pre-configure resolver, but still accept remote and local at runtime + diff_resolver = DiffResolver.configure(force=args.force) + return states.RemoteState(args.profile, diff_resolver, paths=args.path), states.LocalState(args.filename, paths=args.path) + def init(args): - r, l = RemoteState(args.profile), LocalState(args.filename) - l.save(r.get(flat=False, paths=args.path)) + """Create a local YAML file from the SSM Parameter Store (per configs in args)""" + remote, local = configure_endpoints(args) + local.save(remote.clone()) def pull(args): - dictfilter = lambda x, y: dict([ (i,x[i]) for i in x if i in set(y) ]) - r, l = RemoteState(args.profile), LocalState(args.filename) - diff = helpers.FlatDictDiffer(r.get(paths=args.path), l.get(paths=args.path)) - if args.force: - ref_set = diff.changed().union(diff.removed()).union(diff.unchanged()) - target_set = diff.added() - else: - ref_set = diff.unchanged().union(diff.removed()) - target_set = diff.added().union(diff.changed()) - state = dictfilter(diff.ref, ref_set) - state.update(dictfilter(diff.target, target_set)) - l.save(helpers.unflatten(state)) + """Update local YAML file with changes in the SSM Parameter Store (per configs in args)""" + remote, local = configure_endpoints(args) + local.save(remote.pull(local.get())) def apply(args): - r, _, diff = plan(args) - + """Apply local changes to the SSM Parameter Store""" + remote, local = configure_endpoints(args) print("\nApplying changes...") try: - r.apply(diff) + remote.push(local.get()) except Exception as e: print("Failed to apply changes to remote:", e) print("Done.") def plan(args): - r, l = RemoteState(args.profile), LocalState(args.filename) - diff = helpers.FlatDictDiffer(r.get(paths=args.path), l.get(paths=args.path)) + """Print a representation of the changes that would be applied to SSM Parameter Store if applied (per config in args)""" + remote, local = configure_endpoints(args) + diff = remote.dry_run(local.get()) if diff.differ: - diff.print_state() + print(diff.describe_diff()) else: print("Remote state is up to date.") - return r, l, diff + return remote, local, diff if __name__ == "__main__": diff --git a/states/helpers.py b/states/helpers.py index 08d313a..e893503 100644 --- a/states/helpers.py +++ b/states/helpers.py @@ -1,55 +1,101 @@ -from termcolor import colored -from copy import deepcopy import collections +from copy import deepcopy +from functools import partial +from termcolor import colored -class FlatDictDiffer(object): - def __init__(self, ref, target): - self.ref, self.target = ref, target - self.ref_set, self.target_set = set(ref.keys()), set(target.keys()) - self.isect = self.ref_set.intersection(self.target_set) + +class DiffResolver(object): + """Determines diffs between two dicts, where the remote copy is considered the baseline""" + def __init__(self, remote, local, force=False): + self.remote_flat, self.local_flat = self._flatten(remote), self._flatten(local) + self.remote_set, self.local_set = set(self.remote_flat.keys()), set(self.local_flat.keys()) + self.intersection = self.remote_set.intersection(self.local_set) + self.force = force if self.added() or self.removed() or self.changed(): self.differ = True else: self.differ = False + @classmethod + def configure(cls, *args, **kwargs): + return partial(cls, *args, **kwargs) + def added(self): - return self.target_set - self.isect + """Returns a (flattened) dict of added leaves i.e. {"full/path": value, ...}""" + return self.local_set - self.intersection def removed(self): - return self.ref_set - self.isect + """Returns a (flattened) dict of removed leaves i.e. {"full/path": value, ...}""" + return self.remote_set - self.intersection def changed(self): - return set(k for k in self.isect if self.ref[k] != self.target[k]) + """Returns a (flattened) dict of changed leaves i.e. {"full/path": value, ...}""" + return set(k for k in self.intersection if self.remote_flat[k] != self.local_flat[k]) def unchanged(self): - return set(k for k in self.isect if self.ref[k] == self.target[k]) + """Returns a (flattened) dict of unchanged leaves i.e. {"full/path": value, ...}""" + return set(k for k in self.intersection if self.remote_flat[k] == self.local_flat[k]) - def print_state(self): + def describe_diff(self): + """Return a (multi-line) string describing all differences""" + description = "" for k in self.added(): - print(colored("+", 'green'), "{} = {}".format(k, self.target[k])) + description += colored("+", 'green'), "{} = {}".format(k, self.local_flat[k]) + '\n' for k in self.removed(): - print(colored("-", 'red'), k) + description += colored("-", 'red'), k + '\n' for k in self.changed(): - print(colored("~", 'yellow'), "{}:\n\t< {}\n\t> {}".format(k, self.ref[k], self.target[k])) - - -def flatten(d, pkey='', sep='/'): - items = [] - for k in d: - new = pkey + sep + k if pkey else k - if isinstance(d[k], collections.MutableMapping): - items.extend(flatten(d[k], new, sep=sep).items()) + description += colored("~", 'yellow'), "{}:\n\t< {}\n\t> {}".format(k, self.remote_flat[k], self.local_flat[k]) + '\n' + + return description + + def _flatten(self, d, current_path='', sep='/'): + """Convert a nested dict structure into a "flattened" dict i.e. {"full/path": "value", ...}""" + items = [] + for k in d: + new = current_path + sep + k if current_path else k + if isinstance(d[k], collections.MutableMapping): + items.extend(self._flatten(d[k], new, sep=sep).items()) + else: + items.append((sep + new, d[k])) + return dict(items) + + def _unflatten(self, d, sep='/'): + """Converts a "flattened" dict i.e. {"full/path": "value", ...} into a nested dict structure""" + output = {} + for k in d: + add( + obj=output, + path=k, + value=d[k], + sep=sep, + ) + return output + + def merge(self): + """Generate a merge of the local and remote dicts, following configurations set during __init__""" + dictfilter = lambda original, keep_keys: dict([(i, original[i]) for i in original if i in set(keep_keys)]) + if self.force: + # Overwrite local changes (i.e. only preserve added keys) + # NOTE: Currently the system cannot tell the difference between a remote delete and a local add + prior_set = self.changed().union(self.removed()).union(self.unchanged()) + current_set = self.added() else: - items.append((sep + new, d[k])) - return dict(items) - - -def add(obj, path, value): - parts = path.strip("/").split("/") + # Preserve added keys and changed keys + # NOTE: Currently the system cannot tell the difference between a remote delete and a local add + prior_set = self.unchanged().union(self.removed()) + current_set = self.added().union(self.changed()) + state = dictfilter(original=self.remote_flat, keep_keys=prior_set) + state.update(dictfilter(original=self.local_flat, keep_keys=current_set)) + return self._unflatten(state) + + +def add(obj, path, value, sep='/'): + """Add value to the `obj` dict at the specified path""" + parts = path.strip(sep).split(sep) last = len(parts) - 1 for index, part in enumerate(parts): if index == last: @@ -61,7 +107,7 @@ def add(obj, path, value): def search(state, path): result = state for p in path.strip("/").split("/"): - if result.get(p): + if result.clone(p): result = result[p] else: result = {} @@ -71,16 +117,6 @@ def search(state, path): return output -def unflatten(d): - output = {} - for k in d: - add( - obj=output, - path=k, - value=d[k]) - return output - - def merge(a, b): if not isinstance(b, dict): return b diff --git a/states/states.py b/states/states.py index bb96897..e51a2c3 100644 --- a/states/states.py +++ b/states/states.py @@ -1,11 +1,14 @@ from __future__ import print_function -from botocore.exceptions import ClientError, NoCredentialsError -from .helpers import flatten, merge, add, search + import sys -import os -import yaml + import boto3 import termcolor +import yaml +from botocore.exceptions import ClientError, NoCredentialsError + +from .helpers import merge, add, search + def str_presenter(dumper, data): if len(data.splitlines()) == 1 and data[-1] == '\n': @@ -17,8 +20,10 @@ def str_presenter(dumper, data): return dumper.represent_scalar( 'tag:yaml.org,2002:str', data.strip()) + yaml.SafeDumper.add_representer(str, str_presenter) + class SecureTag(yaml.YAMLObject): yaml_tag = u'!secure' @@ -38,7 +43,7 @@ def __hash__(self): return hash(self.secure) def __ne__(self, other): - return (not self.__eq__(other)) + return not self.__eq__(other) @classmethod def from_yaml(cls, loader, node): @@ -50,25 +55,28 @@ def to_yaml(cls, dumper, data): return dumper.represent_scalar(cls.yaml_tag, data.secure, style='|') return dumper.represent_scalar(cls.yaml_tag, data.secure) + yaml.SafeLoader.add_constructor('!secure', SecureTag.from_yaml) yaml.SafeDumper.add_multi_representer(SecureTag, SecureTag.to_yaml) class LocalState(object): - def __init__(self, filename): + """Encodes/decodes a dictionary to/from a YAML file""" + def __init__(self, filename, paths=('/',)): self.filename = filename + self.paths = paths - def get(self, paths, flat=True): + def get(self): try: output = {} - with open(self.filename,'rb') as f: - l = yaml.safe_load(f.read()) - for path in paths: + with open(self.filename, 'rb') as f: + local = yaml.safe_load(f.read()) + for path in self.paths: if path.strip('/'): - output = merge(output, search(l, path)) + output = merge(output, search(local, path)) else: - return flatten(l) if flat else l - return flatten(output) if flat else output + return local + return output except IOError as e: print(e, file=sys.stderr) if e.errno == 2: @@ -90,20 +98,23 @@ def save(self, state): class RemoteState(object): - def __init__(self, profile): + """Encodes/decodes a dict to/from the SSM Parameter Store""" + def __init__(self, profile, diff_class, paths=('/',)): if profile: boto3.setup_default_session(profile_name=profile) self.ssm = boto3.client('ssm') + self.diff_class = diff_class + self.paths = paths - def get(self, paths=['/'], flat=True): + def clone(self): p = self.ssm.get_paginator('get_parameters_by_path') output = {} - for path in paths: + for path in self.paths: try: for page in p.paginate( - Path=path, - Recursive=True, - WithDecryption=True): + Path=path, + Recursive=True, + WithDecryption=True): for param in page['Parameters']: add(obj=output, path=param['Name'], @@ -111,32 +122,45 @@ def get(self, paths=['/'], flat=True): except (ClientError, NoCredentialsError) as e: print("Failed to fetch parameters from SSM!", e, file=sys.stderr) - return flatten(output) if flat else output + return output + # noinspection PyMethodMayBeStatic def _read_param(self, value, ssm_type='String'): return SecureTag(value) if ssm_type == 'SecureString' else str(value) - def apply(self, diff): + def pull(self, local): + diff = self.diff_class( + remote=self.clone(), + local=local, + force=self.force + ) + return diff.merge() + def dry_run(self, local): + return self.diff_class(self.clone(), local, force=self.force) + + def push(self, local): + diff = self.dry_run(local) + + # diff.added|removed|changed return a "flattened" dict i.e. {"full/path": "value", ...} for k in diff.added(): ssm_type = 'String' - if isinstance(diff.target[k], list): + if isinstance(diff.local[k], list): ssm_type = 'StringList' - if isinstance(diff.target[k], SecureTag): + if isinstance(diff.local[k], SecureTag): ssm_type = 'SecureString' self.ssm.put_parameter( Name=k, - Value=repr(diff.target[k]) if type(diff.target[k]) == SecureTag else str(diff.target[k]), + Value=repr(diff.local[k]) if type(diff.local[k]) == SecureTag else str(diff.local[k]), Type=ssm_type) for k in diff.removed(): self.ssm.delete_parameter(Name=k) for k in diff.changed(): - ssm_type = 'SecureString' if isinstance(diff.target[k], SecureTag) else 'String' - + ssm_type = 'SecureString' if isinstance(diff.local[k], SecureTag) else 'String' self.ssm.put_parameter( Name=k, - Value=repr(diff.target[k]) if type(diff.target[k]) == SecureTag else str(diff.target[k]), + Value=repr(diff.local[k]) if type(diff.local[k]) == SecureTag else str(diff.local[k]), Overwrite=True, Type=ssm_type) diff --git a/states/tests.py b/states/tests.py new file mode 100644 index 0000000..4108b4f --- /dev/null +++ b/states/tests.py @@ -0,0 +1,150 @@ +from unittest import TestCase + +from . import helpers + + +class FlatDictDiffer(TestCase): + + def setUp(self) -> None: + self.obj = helpers.DiffResolver({}, {}) + + def test_flatten_single(self): + nested = { + "key": "value" + } + flat = { + "/key": "value", + } + self.assertEqual( + flat, + self.obj._flatten(nested) + ) + self.assertEqual( + nested, + self.obj._unflatten(flat) + ) + + def test_flatten_nested(self): + nested = { + "key1": { + "key2": "value" + } + } + flat = { + "/key1/key2": "value", + } + self.assertEqual( + flat, + self.obj._flatten(nested) + ) + self.assertEqual( + nested, + self.obj._unflatten(flat) + ) + + def test_flatten_nested_sep(self): + nested = { + "key1": { + "key2": "value" + } + } + flat = { + "\\key1\\key2": "value", + } + self.assertEqual( + flat, + self.obj._flatten(nested, sep='\\') + ) + self.assertEqual( + nested, + self.obj._unflatten(flat, sep='\\') + ) + + +class Pull(TestCase): + + def test_add_remote(self): + """Remote additions should be added to local""" + remote = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d'}}, + 'x': {'y': {'z': 'x/y/z'}} + } + local = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d'}}, + } + + plan = helpers.DiffResolver( + remote, + local, + ) + + self.assertEqual( + remote, + plan.merge() + ) + + def test_add_local(self): + """Local additions should be preserved so we won't see any changes to local""" + remote = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d'}}, + } + local = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d'}}, + 'x': {'y': {'z': 'x/y/z'}} + } + + diff = helpers.DiffResolver( + remote, + local, + ) + + self.assertEqual( + local, + diff.merge() + ) + + def test_change_local_force(self): + """Local changes should be overwritten if force+True""" + remote = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d'}}, + } + local = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d_new'}}, + } + + diff = helpers.DiffResolver.configure(force=True)( + remote, + local, + ) + + self.assertEqual( + remote, + diff.merge() + ) + + def test_change_local_no_force(self): + """Local changes should be preserved if force=False""" + remote = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d'}}, + } + local = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d_new'}}, + } + + diff = helpers.DiffResolver.configure(force=False)( + remote, + local, + ) + + self.assertEqual( + local, + diff.merge() + )