# Copyright 2016 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration schema for CloudFerry discovery/migration.

Turns a configuration dictionary (normally parsed from YAML) into typed
objects: a ``Configuration`` holding named ``OpenstackCloud`` and
``Migration`` instances, validated with marshmallow schemas.
"""
import collections
import contextlib
import logging

import marshmallow
from marshmallow import fields

from cloudferrylib.os import clients
from cloudferrylib.utils import bases
from cloudferrylib.utils import query
from cloudferrylib.utils import remote
from cloudferrylib.utils import utils

LOG = logging.getLogger(__name__)

# Discovery model classes used when a cloud doesn't override ``discover``.
MODEL_LIST = [
    'cloudferrylib.os.discovery.keystone.Tenant',
    'cloudferrylib.os.discovery.glance.Image',
    'cloudferrylib.os.discovery.cinder.Volume',
    'cloudferrylib.os.discovery.nova.Server',
]


class DictField(fields.Field):
    """Marshmallow field that loads a mapping, validating keys and values.

    :param key_field: field used to deserialize every key
    :param nested_field: field used to deserialize every value
    """

    def __init__(self, key_field, nested_field, **kwargs):
        super(DictField, self).__init__(**kwargs)
        self.key_field = key_field
        self.nested_field = nested_field

    def _deserialize(self, value, attr, data):
        """Return ``{key_field(k): nested_field(v)}`` for each item."""
        if not isinstance(value, dict):
            self.fail('type')
        result = {}
        for key, val in value.items():
            result[self.key_field.deserialize(key)] = \
                self.nested_field.deserialize(val)
        return result


class FirstFit(fields.Field):
    """Field that tries several variant fields; first success wins.

    With ``many=True`` the input must be a list and each element is
    deserialized independently.
    """

    def __init__(self, *args, **kwargs):
        many = kwargs.pop('many', False)
        super(FirstFit, self).__init__(**kwargs)
        self.many = many
        self.variants = args

    def _deserialize(self, value, attr, data):
        if self.many:
            return [self._do_deserialize(item) for item in value]
        return self._do_deserialize(value)

    def _do_deserialize(self, value):
        """Return the result of the first variant accepting ``value``.

        :raises marshmallow.ValidationError: with the collected messages of
            every variant when none of them fits.
        """
        errors = []
        for field in self.variants:
            try:
                return field.deserialize(value)
            except marshmallow.ValidationError as ex:
                errors.append(ex)
        raise marshmallow.ValidationError([e.messages for e in errors])


class OneOrMore(fields.Field):
    """Field accepting a single value or a list; always loads a list."""

    def __init__(self, base_type, **kwargs):
        super(OneOrMore, self).__init__(**kwargs)
        self.base_type = base_type

    def _deserialize(self, value, attr, data):
        # pylint: disable=protected-access
        if isinstance(value, collections.Sequence) and \
                not isinstance(value, basestring):
            return [self.base_type._deserialize(item, attr, data)
                    for item in value]
        return [self.base_type._deserialize(value, attr, data)]


class SshSettings(bases.Hashable, bases.Representable,
                  bases.ConstructableFromDict):
    """SSH connection parameters of a cloud (the ``ssh`` config section)."""

    class Schema(marshmallow.Schema):
        username = fields.String()
        sudo_password = fields.String(missing=None)
        gateway = fields.String(missing=None)
        connection_attempts = fields.Integer(missing=1)
        cipher = fields.String(missing=None)
        key_file = fields.String(missing=None)

        @marshmallow.post_load
        def to_ssh_settings(self, data):
            # BUG FIX: this post_load previously returned ``Scope(data)``
            # (copy/paste from the Scope schema), producing an object of
            # the wrong type for the ``ssh`` section.
            return SshSettings(data)


class Scope(bases.Hashable, bases.Representable, bases.ConstructableFromDict):
    """Authentication scope: project (by name or id) and/or domain."""

    class Schema(marshmallow.Schema):
        project_name = fields.String(missing=None)
        project_id = fields.String(missing=None)
        domain_id = fields.String(missing=None)

        @marshmallow.post_load
        def to_scope(self, data):
            return Scope(data)

        @marshmallow.validates_schema(skip_on_field_errors=True)
        def check_scope_is_not_empty(self, data):
            """Require at least one scope attribute to be set."""
            if all(data[k] is None for k in self.declared_fields.keys()):
                raise marshmallow.ValidationError(
                    'At least one of {keys} shouldn\'t be None'.format(
                        keys=self.declared_fields.keys()))


class Credential(bases.Hashable, bases.Representable,
                 bases.ConstructableFromDict):
    """Keystone authentication credentials of a cloud."""

    class Schema(marshmallow.Schema):
        auth_url = fields.Url()
        username = fields.String()
        password = fields.String()
        region_name = fields.String(missing=None)
        domain_id = fields.String(missing=None)
        https_insecure = fields.Boolean(missing=False)
        https_cacert = fields.String(missing=None)
        endpoint_type = fields.String(missing='admin')

        @marshmallow.post_load
        def to_credential(self, data):
            return Credential(data)


class OpenstackCloud(bases.Hashable, bases.Representable,
                     bases.ConstructableFromDict):
    """A configured OpenStack cloud: credentials, scope, SSH access.

    ``name`` is filled in by ``Configuration`` post-load with the key the
    cloud is registered under.
    """

    class Schema(marshmallow.Schema):
        credential = fields.Nested(Credential.Schema)
        scope = fields.Nested(Scope.Schema)
        ssh_settings = fields.Nested(SshSettings.Schema, load_from='ssh')
        # BUG FIX: was ``default=MODEL_LIST``; in marshmallow ``default``
        # only applies when dumping. ``missing`` is the load-time default,
        # so clouds without an explicit ``discover`` list get MODEL_LIST.
        discover = OneOrMore(fields.String(), missing=MODEL_LIST)

        @marshmallow.post_load
        def to_cloud(self, data):
            return OpenstackCloud(data)

    def __init__(self, data):
        super(OpenstackCloud, self).__init__(data)
        self.name = None

    def image_client(self, scope=None):
        """Return a glance client authenticated for this cloud."""
        # pylint: disable=no-member
        return clients.image_client(self.credential, scope or self.scope)

    def identity_client(self, scope=None):
        """Return a keystone client authenticated for this cloud."""
        # pylint: disable=no-member
        return clients.identity_client(self.credential, scope or self.scope)

    def volume_client(self, scope=None):
        """Return a cinder client authenticated for this cloud."""
        # pylint: disable=no-member
        return clients.volume_client(self.credential, scope or self.scope)

    def compute_client(self, scope=None):
        """Return a nova client authenticated for this cloud."""
        # pylint: disable=no-member
        return clients.compute_client(self.credential, scope or self.scope)

    @contextlib.contextmanager
    def remote_executor(self, hostname, key_file=None, ignore_errors=False):
        """Context manager yielding a RemoteExecutor for ``hostname``.

        Extra ``key_file`` (if given) is only added to the ssh-agent via
        ``ensure_ssh_key_added``; the executor itself is created with the
        configured ``settings.key_file``.
        """
        # pylint: disable=no-member
        key_files = []
        settings = self.ssh_settings
        if settings.key_file is not None:
            key_files.append(settings.key_file)
        if key_file is not None:
            key_files.append(key_file)
        if key_files:
            utils.ensure_ssh_key_added(key_files)
        try:
            yield remote.RemoteExecutor(
                hostname, settings.username,
                sudo_password=settings.sudo_password,
                gateway=settings.gateway,
                connection_attempts=settings.connection_attempts,
                cipher=settings.cipher,
                key_file=settings.key_file,
                ignore_errors=ignore_errors)
        finally:
            remote.RemoteExecutor.close_connection(hostname)


class Migration(bases.Hashable, bases.Representable):
    """Source cloud, destination cloud and an object-selection query."""

    class Schema(marshmallow.Schema):
        source = fields.String(required=True)
        destination = fields.String(required=True)
        objects = DictField(
            fields.String(),
            FirstFit(
                fields.String(),
                DictField(
                    fields.String(),
                    OneOrMore(fields.Raw())),
                many=True),
            required=True)

        @marshmallow.post_load
        def to_migration(self, data):
            return Migration(**data)

    def __init__(self, source, destination, objects):
        self.source = source
        self.destination = destination
        self.query = query.Query(objects)


class Configuration(bases.Hashable, bases.Representable,
                    bases.ConstructableFromDict):
    """Top-level configuration: named clouds and named migrations."""

    class Schema(marshmallow.Schema):
        clouds = DictField(
            fields.String(allow_none=False),
            fields.Nested(OpenstackCloud.Schema, default=dict))
        # NOTE: removed stray ``debug=True`` kwarg that was silently
        # swallowed as field metadata (debugging leftover).
        migrations = DictField(
            fields.String(allow_none=False),
            fields.Nested(Migration.Schema, default=dict))

        @marshmallow.validates_schema(skip_on_field_errors=True)
        def check_migration_have_correct_source_and_dict(self, data):
            """Every migration must reference configured cloud names."""
            clouds = data['clouds']
            migrations = data['migrations']
            for migration_name, migration in migrations.items():
                if migration.source not in clouds:
                    raise marshmallow.ValidationError(
                        'Migration "{0}" source "{1}" should be defined '
                        'in clouds'.format(migration_name, migration.source))
                if migration.destination not in clouds:
                    raise marshmallow.ValidationError(
                        'Migration "{0}" destination "{1}" should be defined '
                        'in clouds'.format(migration_name,
                                           migration.destination))

        @marshmallow.post_load
        def to_configuration(self, data):
            # Propagate the registry key into each cloud object.
            for name, cloud in data['clouds'].items():
                cloud.name = name
            return Configuration(data)


def load(data):
    """
    Loads and validates configuration
    :param data: dictionary file loaded from discovery YAML
    :return: Configuration instance
    """
    schema = Configuration.Schema(strict=True)
    return schema.load(data).data
return attr def __call__(self, *args, **kwargs): + # pylint: disable=broad-except for retry in (True, False): try: method = self._get_attr(self._path) @@ -133,8 +84,8 @@ def _get_authenticated_v2_client(credential, scope): region_name=credential.region_name, domain_id=credential.domain_id, endpoint_type=credential.endpoint_type, - https_insecure=credential.https_insecure, - https_cacert=credential.https_cacert, + insecure=credential.https_insecure, + cacert=credential.https_cacert, project_domain_id=scope.domain_id, project_name=scope.project_name, project_id=scope.project_id, diff --git a/cloudferrylib/os/compute/nova_compute.py b/cloudferrylib/os/compute/nova_compute.py index 6ebc6ba1..2393d15f 100644 --- a/cloudferrylib/os/compute/nova_compute.py +++ b/cloudferrylib/os/compute/nova_compute.py @@ -383,6 +383,8 @@ def convert_instance(instance, cfg, cloud): else: server_group = None + config_drive = utl.get_disk_path(instance, instance_block_info, + disk=utl.DISK_CONFIG) inst = {'instance': {'name': instance.name, 'instance_name': instance_name, 'id': instance.id, @@ -408,7 +410,8 @@ def convert_instance(instance, cfg, cloud): 'is_ephemeral': is_ephemeral, 'volumes': volumes, 'user_id': instance.user_id, - 'server_group': server_group + 'server_group': server_group, + 'config_drive': config_drive is not None, }, 'ephemeral': ephemeral_path, 'diff': diff, @@ -685,6 +688,7 @@ def _deploy_instances(self, info_compute): instance, 'key_name'), 'nics': instance['nics'], 'image': instance['image_id'], + 'config_drive': instance['config_drive'], # user_id matches user_id on source 'user_id': instance.get('user_id'), 'availability_zone': self.attr_override.get_attr( diff --git a/cloudferrylib/os/consts.py b/cloudferrylib/os/consts.py index fe575138..d89344f0 100644 --- a/cloudferrylib/os/consts.py +++ b/cloudferrylib/os/consts.py @@ -34,12 +34,6 @@ def items(cls): return [(f, getattr(cls, f)) for f in cls.names()] -class EndpointType(EnumType): - INTERNAL = "internal" - 
ADMIN = "admin" - PUBLIC = "public" - - class ServiceType(EnumType): IDENTITY = 'identity' COMPUTE = 'compute' diff --git a/cloudferrylib/os/discovery/cinder.py b/cloudferrylib/os/discovery/cinder.py index 7b424693..a2262b9b 100644 --- a/cloudferrylib/os/discovery/cinder.py +++ b/cloudferrylib/os/discovery/cinder.py @@ -28,10 +28,11 @@ class Schema(model.Schema): device = fields.String(required=True) +@model.type_alias('volumes') class Volume(model.Model): class Schema(model.Schema): object_id = model.PrimaryKey('id') - name = fields.String(required=True) + name = fields.String(required=True, allow_none=True) description = fields.String(required=True, allow_none=True) availability_zone = fields.String(required=True) encrypted = fields.Boolean(missing=False) @@ -65,7 +66,7 @@ def discover(cls, cloud): volume_client = cloud.volume_client() volumes_list = volume_client.volumes.list( search_opts={'all_tenants': True}) - with model.Transaction() as tx: + with model.Session() as session: for raw_volume in volumes_list: volume = Volume.load_from_cloud(cloud, raw_volume) - tx.store(volume) + session.store(volume) diff --git a/cloudferrylib/os/discovery/glance.py b/cloudferrylib/os/discovery/glance.py index 2d763784..29ba941a 100644 --- a/cloudferrylib/os/discovery/glance.py +++ b/cloudferrylib/os/discovery/glance.py @@ -31,15 +31,15 @@ class Schema(model.Schema): @classmethod def load_from_cloud(cls, cloud, data, overrides=None): - return cls.get(cloud, data.image_id, data.member_id) + return cls._get(cloud, data.image_id, data.member_id) @classmethod def load_missing(cls, cloud, object_id): image_id, member_id = object_id.id.split(':') - return cls.get(cls, image_id, member_id) + return cls._get(cls, image_id, member_id) @classmethod - def get(cls, cloud, image_id, member_id): + def _get(cls, cloud, image_id, member_id): return super(ImageMember, cls).load_from_cloud(cloud, { 'object_id': '{0}:{1}'.format(image_id, member_id), 'image': image_id, @@ -47,6 +47,7 @@ def 
get(cls, cloud, image_id, member_id): }) +@model.type_alias('images') class Image(model.Model): class Schema(model.Schema): object_id = model.PrimaryKey('id') @@ -78,17 +79,17 @@ def load_missing(cls, cloud, object_id): @classmethod def discover(cls, cloud): image_client = cloud.image_client() - with model.Transaction() as tx: + with model.Session() as session: for raw_image in image_client.images.list( filters={"is_public": None}): try: image = Image.load_from_cloud(cloud, raw_image) - tx.store(image) + session.store(image) members_list = image_client.image_members.list( image=raw_image) for raw_member in members_list: member = ImageMember.load_from_cloud(cloud, raw_member) - tx.store(member) + session.store(member) image.members.append(member) except exceptions.ValidationError as e: LOG.warning('Invalid image %s: %s', raw_image.id, e) diff --git a/cloudferrylib/os/discovery/keystone.py b/cloudferrylib/os/discovery/keystone.py index 49e8922d..d2d4e5c5 100644 --- a/cloudferrylib/os/discovery/keystone.py +++ b/cloudferrylib/os/discovery/keystone.py @@ -21,6 +21,7 @@ LOG = logging.getLogger(__name__) +@model.type_alias('tenants') class Tenant(model.Model): class Schema(model.Schema): object_id = model.PrimaryKey('id') @@ -40,6 +41,6 @@ def load_missing(cls, cloud, object_id): @classmethod def discover(cls, cloud): identity_client = cloud.identity_client() - with model.Transaction() as tx: + with model.Session() as session: for tenant in identity_client.tenants.list(): - tx.store(Tenant.load_from_cloud(cloud, tenant)) + session.store(Tenant.load_from_cloud(cloud, tenant)) diff --git a/cloudferrylib/os/discovery/model.py b/cloudferrylib/os/discovery/model.py index f16ba257..8be5adb6 100644 --- a/cloudferrylib/os/discovery/model.py +++ b/cloudferrylib/os/discovery/model.py @@ -72,13 +72,13 @@ def discover(cls, cloud): volume_client = cloud.volume_client() volumes_list = volume_client.volumes.list( search_opts={'all_tenants': True}) - with model.Transaction() as tx: + 
with model.Session() as session: for raw_volume in volumes_list: volume = Volume.load_from_cloud(cloud, raw_volume) - tx.store(volume) + session.store(volume) -Example using ``Transaction`` class to store and retrieve data from database:: +Example using ``Session`` class to store and retrieve data from database:: from cloudferrylib.os.discovery import model @@ -97,20 +97,20 @@ class Schema(model.Schema): }, 'name': 'foobar' }) - with model.Transaction() as tx: - tx.store(new_tenant) + with model.Session() as session: + session.store(new_tenant) # Retrieving previously stored item - with model.Transaction() as tx: + with model.Session() as session: object_id = model.ObjectId('ed388ba9-dea3-4017-987b-92f7915f33bb', 'us-west1') - stored_tenant = tx.retrieve(Tenant, object_id) + stored_tenant = session.retrieve(Tenant, object_id) assert stored_tenant.name == 'foobar' # Getting list of items - with model.Transaction() as tx: + with model.Session() as session: found_tenant = None - for tenant in tx.list(Tenant): + for tenant in session.list(Tenant): if tenant.id == object_id: found_tenant = tenant assert found_tenant is not None @@ -119,17 +119,21 @@ class Schema(model.Schema): """ import collections +import contextlib import json import logging -import sqlite3 +import sys import threading import marshmallow from marshmallow import fields from oslo_utils import importutils +from cloudferrylib.utils import local_db + LOG = logging.getLogger(__name__) -CREATE_OBJECT_TABLE_SQL = """ +type_aliases = {} +local_db.execute_once(""" CREATE TABLE IF NOT EXISTS objects ( uuid TEXT, cloud TEXT, @@ -137,9 +141,7 @@ class Schema(model.Schema): json TEXT, PRIMARY KEY (uuid, cloud, type) ) -""" - -registry = {} +""") class ObjectId(collections.namedtuple('ObjectId', ('id', 'cloud'))): @@ -285,6 +287,7 @@ def create(cls, values, schema=None, mark_dirty=False): database when transaction completes) :return: instance of model class """ + # pylint: disable=protected-access if schema is 
None: schema = cls.get_schema() @@ -412,6 +415,7 @@ def load_missing(cls, cloud, object_id): :param object_id: identifier of missing object :return: model class instance """ + # pylint: disable=unused-argument raise NotFound(cls, object_id) @classmethod @@ -423,6 +427,7 @@ def discover(cls, cloud): clients, etc... :return: model class instance """ + # pylint: disable=unused-argument return def is_dirty(self): @@ -455,6 +460,29 @@ def mark_dirty(self): """ self._original.clear() + def clear_dirty(self): + """ + Update internal state of object so it's not considered dirty (e.g. + don't need to be saved to database). + """ + schema = self.get_schema() + for name, field in schema.fields.items(): + if isinstance(field, Nested): + value = getattr(self, name, None) + if value is not None: + if field.many: + for elem in value: + elem.clear_dirty() + else: + value.clear_dirty() + else: + value = getattr(self, name, None) + if isinstance(field, Reference): + self._original[name] = \ + field.get_significant_value(value) + else: + self._original[name] = value + def dependencies(self): """ Return list of other model instances that current object depend upon. @@ -476,21 +504,21 @@ def find(cls, cloud, object_id): :param object_id: object identifier :return object or None if it's missing even in cloud """ - with Transaction() as tx: + with Session.current() as session: try: - if tx.is_missing(cls, object_id): + if session.is_missing(cls, object_id): return None else: - return tx.retrieve(cls, object_id) + return session.retrieve(cls, object_id) except NotFound: LOG.debug('Trying to load missing %s value: %s', _type_name(cls), object_id) obj = cls.load_missing(cloud, object_id) if obj is None: - tx.store_missing(cls, object_id) + session.store_missing(cls, object_id) return None else: - tx.store(obj) + session.store(obj) return obj @classmethod @@ -500,6 +528,12 @@ def get_class(cls): """ return cls + def get(self, name, default=None): + """ + Returns object attribute by name. 
+ """ + return getattr(self, name, default) + def __repr__(self): schema = self.get_schema() obj_fields = sorted(schema.declared_fields.keys()) @@ -635,8 +669,14 @@ def __repr__(self): def _retrieve_obj(self): if self._object is None: - with Transaction() as tx: - self._object = tx.retrieve(self._model, self._object_id) + with Session.current() as session: + self._object = session.retrieve(self._model, self._object_id) + + def get(self, name): + """ + Returns object attribute by name. + """ + return getattr(self, name, None) def get_class(self): """ @@ -645,65 +685,71 @@ def get_class(self): return self._model -class Transaction(object): +class Session(object): """ - Transaction objects are used to store and retrieve objects to database. It + Session objects are used to store and retrieve objects to database. It tracks already loaded object to prevent loading same object twice and to prevent losing changes made to already loaded objects. - Transactions should be used as context managers (e.g. inside ``with`` - block). On exit from this block all changes made using transaction will + Sessions should be used as context managers (e.g. inside ``with`` + block). On exit from this block all changes made using session will be saved to disk. 
""" - tls = threading.local() - tls.current = None + _tls = threading.local() + _tls.current = None def __init__(self): self.session = None - self.connection = None - self.cursor = None + self.previous = None + self.tx = None def __enter__(self): - if self.tls.current is not None: - # TODO: use save points for nested transactions - current_tx = self.tls.current - self.connection = current_tx.connection - self.cursor = current_tx.cursor - self.session = current_tx.session - return self - filepath = 'migration_data.db' - self.connection = sqlite3.connect(filepath) - self.connection.isolation_level = None - self.cursor = self.connection.cursor() - self.cursor.execute('BEGIN') - self.cursor.execute(CREATE_OBJECT_TABLE_SQL) - self.tls.current = self + # pylint: disable=protected-access + self.previous = self._tls.current + self._tls.current = self + if self.previous is not None: + # Store outer TX values for savepoint + self.previous._dump_objects() + self.tx = local_db.Transaction() + self.tx.__enter__() self.session = {} return self def __exit__(self, exc_type, exc_val, exc_tb): - if self.tls.current is not self: - # TODO: use save points for nested transactions - return - - self.tls.current = None - if exc_type is not None or exc_val is not None or exc_tb is not None: - self.cursor.execute('ROLLBACK') - return + self._tls.current = self.previous try: - for (cls, pk), obj in self.session.items(): - if obj is None: - self._store_none(cls, pk) - continue - if obj.is_dirty(): - self._update_row(obj) - self.cursor.execute('COMMIT') + if exc_type is None and exc_val is None and exc_tb is None: + self._dump_objects() except Exception: - self.cursor.execute('ROLLBACK') + LOG.error('Exception dumping objects', exc_info=True) + exc_type, exc_val, exc_tb = sys.exc_info() raise finally: - self.cursor.close() - self.connection.close() + self.tx.__exit__(exc_type, exc_val, exc_tb) + + def _dump_objects(self): + for (cls, pk), obj in self.session.items(): + if obj is None: + 
self._store_none(cls, pk) + continue + if obj.is_dirty(): + self._update_row(obj) + + @classmethod + def current(cls): + """ + Returns current session or create new session if there is no session + started yet. + :return: Session instance + """ + current = cls._tls.current + if current is not None: + @contextlib.contextmanager + def noop_ctx_mgr(): + yield current + return noop_ctx_mgr() + else: + return Session() def store(self, obj): """ @@ -738,10 +784,11 @@ def retrieve(self, cls, object_id): key = (cls, object_id) if key in self.session: return self.session[key] - self.cursor.execute('SELECT json FROM objects WHERE uuid=? AND ' - 'cloud=? AND type=?', - (object_id.id, object_id.cloud, _type_name(cls))) - result = self.cursor.fetchone() + result = self.tx.query_one('SELECT json FROM objects WHERE uuid=:uuid ' + 'AND cloud=:cloud AND type=:type_name', + uuid=object_id.id, + cloud=object_id.cloud, + type_name=_type_name(cls)) if not result or not result[0]: raise NotFound(cls, object_id) obj = cls.load(json.loads(result[0])) @@ -758,10 +805,11 @@ def is_missing(self, cls, object_id): key = (cls, object_id) if key in self.session: return self.session[key] is None - self.cursor.execute('SELECT json FROM objects WHERE uuid=? AND ' - 'cloud=? AND type=?', - (object_id.id, object_id.cloud, _type_name(cls))) - result = self.cursor.fetchone() + result = self.tx.query_one('SELECT json FROM objects WHERE uuid=:uuid ' + 'AND cloud=:cloud AND type=:type_name', + uuid=object_id.id, + cloud=object_id.cloud, + type_name=_type_name(cls)) if not result: raise NotFound(cls, object_id) return result[0] is None @@ -775,20 +823,19 @@ def list(self, cls, cloud=None): :return: list of model instances """ if cloud is None: - self.cursor.execute('SELECT uuid, cloud, json ' - 'FROM objects WHERE type=?', - (_type_name(cls),)) + query = 'SELECT uuid, cloud, json ' \ + 'FROM objects WHERE type=:type_name' else: - self.cursor.execute('SELECT uuid, cloud, json ' - 'FROM objects WHERE cloud=? 
AND type=?', - (cloud, _type_name(cls))) + query = 'SELECT uuid, cloud, json ' \ + 'FROM objects WHERE cloud=:cloud AND type=:type_name' result = [] for obj in self.session.values(): if isinstance(obj, cls) and \ (cloud is None or cloud == obj.primary_key.cloud): result.append(obj) - for row in self.cursor.fetchall(): + for row in self.tx.query(query, type_name=_type_name(cls), + cloud=cloud): uuid, cloud, json_data = row key = (cls, ObjectId(uuid, cloud)) if key in self.session or not json_data: @@ -798,22 +845,93 @@ def list(self, cls, cloud=None): result.append(obj) return result + def delete(self, cls=None, cloud=None, object_id=None): + """ + Deletes all objects that have cls or cloud or object_id that are equal + to values passed as arguments. Arguments that are None are ignored. + """ + if cloud is not None and object_id is not None: + assert object_id.cloud == cloud + for key in self.session.keys(): + obj_cls, obj_pk = key + matched = True + if cls is not None and cls is not obj_cls: + matched = False + if cloud is not None and obj_pk.cloud != cloud: + matched = False + if object_id is not None and object_id != obj_pk: + matched = False + if matched: + del self.session[key] + self._delete_rows(cls, cloud, object_id) + def _update_row(self, obj): pk = obj.primary_key uuid = pk.id cloud = pk.cloud type_name = _type_name(obj.get_class()) - self.cursor.execute('INSERT OR REPLACE INTO objects ' - 'VALUES (?, ?, ?, ?)', - (uuid, cloud, type_name, json.dumps(obj.dump()))) + self.tx.execute('INSERT OR REPLACE INTO objects ' + 'VALUES (:uuid, :cloud, :type_name, :data)', + uuid=uuid, cloud=cloud, type_name=type_name, + data=json.dumps(obj.dump())) + obj.clear_dirty() + assert not obj.is_dirty() def _store_none(self, cls, pk): uuid = pk.id cloud = pk.cloud type_name = _type_name(cls) - self.cursor.execute('INSERT OR REPLACE INTO objects ' - 'VALUES (?, ?, ?, NULL)', - (uuid, cloud, type_name)) + self.tx.execute('INSERT OR REPLACE INTO objects ' + 'VALUES (:uuid, 
:cloud, :type_name, NULL)', + uuid=uuid, cloud=cloud, type_name=type_name) + + def _delete_rows(self, cls, cloud, object_id): + predicates = [] + kwargs = {} + if cls is not None: + predicates.append('type=:type_name') + kwargs['type_name'] = _type_name(cls) + if object_id is not None: + predicates.append('uuid=:uuid') + kwargs['uuid'] = object_id.id + if cloud is None: + cloud = object_id.cloud + else: + assert cloud == object_id.cloud + if cloud is not None: + predicates.append('cloud=:cloud') + kwargs['cloud'] = cloud + statement = 'DELETE FROM objects WHERE' + if predicates: + statement += ' AND '.join(predicates) + else: + statement += ' 1' + self.tx.execute(statement, **kwargs) + + +def type_alias(name): + """ + Decorator function that add alias for some model class + :param name: alias name + """ + + def wrapper(cls): + assert issubclass(cls, Model) + type_aliases[name] = cls + return cls + return wrapper + + +def get_model(type_name): + """ + Return model class instance using either alias or fully qualified name. 
+ :param type_name: alias or fully qualified class name + :return: subclass of Model + """ + if type_name in type_aliases: + return type_aliases[type_name] + else: + return importutils.import_class(type_name) def _type_name(cls): diff --git a/cloudferrylib/os/discovery/nova.py b/cloudferrylib/os/discovery/nova.py index 55f8be95..d082ff77 100644 --- a/cloudferrylib/os/discovery/nova.py +++ b/cloudferrylib/os/discovery/nova.py @@ -50,6 +50,7 @@ class Schema(model.Schema): size = fields.Integer(required=True) +@model.type_alias('vms') class Server(model.Model): class Schema(model.Schema): object_id = model.PrimaryKey('id') @@ -90,11 +91,11 @@ class Schema(model.Schema): def discover(cls, cloud): compute_client = cloud.compute_client() avail_hosts = list_available_compute_hosts(compute_client) - with model.Transaction() as tx: + with model.Session() as session: servers = [] # Collect servers using API - for tenant in tx.list(keystone.Tenant, cloud.name): + for tenant in session.list(keystone.Tenant, cloud.name): server_list = compute_client.servers.list( search_opts={ 'all_tenants': True, @@ -136,7 +137,7 @@ def discover(cls, cloud): ephemeral_disks = _list_ephemeral(remote, srv) if ephemeral_disks is not None: srv.ephemeral_disks = ephemeral_disks - tx.store(srv) + session.store(srv) def _list_ephemeral(remote, server): @@ -158,8 +159,7 @@ def _list_ephemeral(remote, server): if len(split) != 2: continue target, path = split - if target in volume_targets or not path.startswith('/') or \ - path.endswith('disk.config'): + if target in volume_targets or not path.startswith('/'): continue size_str = remote.sudo('stat -c %s {path}', path=path) if not size_str.succeeded: diff --git a/cloudferrylib/os/discovery/stages.py b/cloudferrylib/os/discovery/stages.py new file mode 100644 index 00000000..5838e908 --- /dev/null +++ b/cloudferrylib/os/discovery/stages.py @@ -0,0 +1,78 @@ +# Copyright 2016 Mirantis Inc. 
# Copyright 2016 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Discovery stage: populates the local database from configured clouds."""
import logging

from oslo_utils import importutils

from cloudferrylib import stage
from cloudferrylib.os.discovery import model

LOG = logging.getLogger(__name__)


class DiscoverStage(stage.Stage):
    """Stage that discovers cloud objects and caches them locally.

    ``invalidate`` computes which clouds need (re-)discovery based on
    signature changes; ``execute`` then runs every configured discovery
    class for those clouds.
    """

    def __init__(self):
        super(DiscoverStage, self).__init__()
        # Clouds whose cached data is absent/stale; None means "all"
        # (invalidate() was never called).
        self.missing_clouds = None

    def invalidate(self, old_signature, new_signature, force=False):
        """
        Remove data related to any cloud that changed signature.
        """
        if force:
            with model.Session() as session:
                session.delete()
            return

        self.missing_clouds = []

        # Clouds present before but absent now: their data is invalid.
        old_clouds = set(old_signature.keys())
        invalid_clouds = old_clouds.difference(new_signature.keys())
        for name, signature in new_signature.items():
            if name not in old_signature:
                # Brand-new cloud: nothing to delete, just discover it.
                self.missing_clouds.append(name)
                continue
            if old_signature[name] != signature:
                # Signature changed: re-discover and drop stale data.
                self.missing_clouds.append(name)
                invalid_clouds.add(name)

        with model.Session() as session:
            for cloud in invalid_clouds:
                session.delete(cloud=cloud)

    def signature(self, config):
        """
        Discovery signature is based on configuration. Each configured cloud
        have it's own signature.
        """
        return {n: [c.credential.auth_url, c.credential.region_name]
                for n, c in config.clouds.items()}

    def execute(self, config):
        """
        Execute discovery.
        """
        if self.missing_clouds is None:
            self.missing_clouds = config.clouds.keys()

        for cloud_name in self.missing_clouds:
            cloud = config.clouds[cloud_name]
            for class_name in cloud.discover:
                cls = importutils.import_class(class_name)
                LOG.info('Starting discover %s objects in %s cloud',
                         cls.__name__, cloud_name)
                cls.discover(cloud)
                LOG.info('Done discovering %s objects in %s cloud',
                         cls.__name__, cloud_name)
for image in list_filtered(tx, glance.Image, cloud_name, tenant): + + for image in query.search(session, src_cloud, glance.Image): if image.object_id not in accounted_images: total_image_size += image.size + print 'Migration', migration_name, 'estimates:' print 'Images:' print ' Size:', sizeof_format.sizeof_fmt(total_image_size) print 'Ephemeral disks:' @@ -59,7 +66,7 @@ def estimate_copy(cloud_name, tenant): print ' Size:', sizeof_format.sizeof_fmt(total_volume_size, 'G') -def show_largest_servers(count, cloud_name, tenant): +def show_largest_servers(cfg, count, migration_name): def server_size(server): size = 0 if server.image is not None: @@ -71,15 +78,19 @@ def server_size(server): return size output = [] - with model.Transaction() as tx: + migration = cfg.migrations[migration_name] + with model.Session() as session: for index, server in enumerate( heapq.nlargest( count, - list_filtered(tx, nova.Server, cloud_name, tenant), + migration.query.search(session, migration.source, + nova.Server), key=server_size), start=1): output.append( - ' {0}. {1.object_id.id} {1.name}'.format(index, server)) + ' {0}. 
{1.object_id.id} {1.name} - {2}'.format( + index, server, + sizeof_format.sizeof_fmt(server_size(server)))) if output: print '\n{0} largest servers:'.format(len(output)) for line in output: @@ -87,10 +98,10 @@ def server_size(server): def show_largest_unused_resources(count, cloud_name, tenant): - with model.Transaction() as tx: + with model.Session() as session: used_volumes = set() used_images = set() - servers = list_filtered(tx, nova.Server, cloud_name, tenant) + servers = list_filtered(session, nova.Server, cloud_name, tenant) for server in servers: if server.image is not None: used_images.add(server.image.object_id) @@ -100,7 +111,7 @@ def show_largest_unused_resources(count, cloud_name, tenant): # Find unused volumes volumes_output = [] volumes_size = 0 - volumes = list_filtered(tx, cinder.Volume, cloud_name, tenant) + volumes = list_filtered(session, cinder.Volume, cloud_name, tenant) for index, volume in enumerate( heapq.nlargest(count, (v for v in volumes @@ -116,11 +127,12 @@ def show_largest_unused_resources(count, cloud_name, tenant): # Find unused images images_output = [] images_size = 0 - images = list_filtered(tx, glance.Image, cloud_name, tenant) + images = list_filtered(session, glance.Image, cloud_name, tenant) for index, image in enumerate( heapq.nlargest(count, (i for i in images - if i.object_id not in used_images)), + if i.object_id not in used_images), + key=lambda i: i.size), start=1): images_size += image.size size = sizeof_format.sizeof_fmt(image.size) diff --git a/cloudferrylib/os/image/glance_image.py b/cloudferrylib/os/image/glance_image.py index 19738819..bb1b6f40 100644 --- a/cloudferrylib/os/image/glance_image.py +++ b/cloudferrylib/os/image/glance_image.py @@ -228,9 +228,12 @@ def delete_image(self, image_id): self.glance_client.images.delete(image_id) def get_image_by_id(self, image_id): - for glance_image in self.get_image_list(): - if glance_image.id == image_id: - return glance_image + try: + return 
self.glance_client.images.get(image_id) + except glance_exceptions.NotFound: + LOG.warning('Image %s not found on %s', image_id, + self.cloud.position) + return None def get_image_by_name(self, image_name): for glance_image in self.get_image_list(): @@ -290,9 +293,8 @@ def convert(self, glance_image, cloud): resource = cloud.resources[utl.IMAGE_RESOURCE] keystone = cloud.resources["identity"] - gl_image = { - k: w for k, w in glance_image.to_dict().items( - ) if k in CREATE_PARAMS} + image_dict = glance_image.to_dict() + gl_image = {k: image_dict.get(k) for k in CREATE_PARAMS} # we need to pass resource to destination to copy image gl_image['resource'] = resource diff --git a/cloudferrylib/stage.py b/cloudferrylib/stage.py new file mode 100644 index 00000000..0ae83426 --- /dev/null +++ b/cloudferrylib/stage.py @@ -0,0 +1,104 @@ +# Copyright 2016 Mirantis Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc + +from oslo_utils import importutils + +from cloudferrylib.utils import local_db + +local_db.execute_once(""" +CREATE TABLE IF NOT EXISTS stages ( + stage TEXT, + signature JSON, + PRIMARY KEY (stage) +) +""") + + +class Stage(object): + __metaclass__ = abc.ABCMeta + dependencies = [] + + @abc.abstractmethod + def signature(self, config): + """ + Returns signature for data that will be produced during this stage. If + the signature differ from the one stored in database, then invalidate + method will be called. 
+ :param config: cloudferrylib.config.Configuration instance + :return: + """ + return + + @abc.abstractmethod + def execute(self, config): + """ + Should contain any code that is required to be executed during this + stage. + :param config: cloudferrylib.config.Configuration instance + """ + return + + @abc.abstractmethod + def invalidate(self, old_signature, new_signature, force=False): + """ + Should destroy any stale data based on signature difference. + :param old_signature: old signature stored in DB + :param new_signature: new signature + """ + return + + +def execute_stage(class_name, config, force=False): + """ + Execute stage specified by `class_name` argument. + :param class_name: fully qualified stage class name + :param config: config.Configuration instance + """ + + # Create stage object + cls = importutils.import_class(class_name) + assert issubclass(cls, Stage) + stage = cls() + + # Execute dependency stages + for dependency in stage.dependencies: + execute_stage(dependency, config) + + # Check if there is data from this stage in local DB + new_signature = stage.signature(config) + old_signature = None + need_invalidate = False + need_execute = False + with local_db.Transaction() as tx: + row = tx.query_one('SELECT signature FROM stages WHERE stage=:stage', + stage=class_name) + if row is None: + need_execute = True + else: + old_signature = row['signature'].data + need_invalidate = (old_signature != new_signature) + + # Run invalidate and execute if needed + with local_db.Transaction() as tx: + if need_invalidate or force: + stage.invalidate(old_signature, new_signature, force=force) + tx.execute('DELETE FROM stages WHERE stage=:stage', + stage=class_name) + need_execute = True + if need_execute: + stage.execute(config) + tx.execute('INSERT INTO stages VALUES (:stage, :signature)', + stage=class_name, + signature=local_db.Json(new_signature)) diff --git a/cloudferrylib/utils/bases.py b/cloudferrylib/utils/bases.py new file mode 100644 index 
00000000..e87b5c13 --- /dev/null +++ b/cloudferrylib/utils/bases.py @@ -0,0 +1,98 @@ +# Copyright 2016 Mirantis Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections +import sys + + +def sorted_field_names(obj): + """ + Returns alphabetically sorted list of public object field names (i.e. their + names don't start with '_') + """ + + return sorted( + f for f in dir(obj) + if not f.startswith('_') and not hasattr(getattr(obj, f), '__call__')) + + +def compute_hash(obj): + """ + Hash function that is able to compute hashes for lists and dictionaries. + """ + + if isinstance(obj, dict): + return hash_iterable(sorted(obj.items())) + elif hasattr(obj, '__iter__'): + return hash_iterable(obj) + else: + return hash(obj) + + +def hash_iterable(iterable): + """ + Compute hash for some iterable value + """ + value = hash(iterable.__class__) + for item in iterable: + value = ((value * 1000003) & sys.maxint) ^ compute_hash(item) + return value + + +class Hashable(object): + """ + Mixin class that make objects hashable based on their public fields (i.e. 
+ which name don't start with '_') + """ + + def __eq__(self, other): + if other.__class__ != self.__class__: + return False + for field in sorted_field_names(self): + if getattr(self, field) != getattr(other, field, None): + return False + return True + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return compute_hash(getattr(self, f) for f in sorted_field_names(self)) + + +class Representable(object): + """ + Mixin class that implement __repr__ method that will show all field values + that are not None. + """ + + def __repr__(self): + cls = self.__class__ + return '<{module}.{cls} {fields}>'.format( + module=cls.__module__, + cls=cls.__name__, + fields=' '.join('{0}:{1}'.format(f, repr(getattr(self, f))) + for f in sorted_field_names(self) + if getattr(self, f) is not None)) + + +class ConstructableFromDict(object): + """ + Mixin class with __init__ method that just assign values from dictionary + to object attributes with names identical to keys from dictionary. + """ + + def __init__(self, data): + assert isinstance(data, collections.Mapping) + for name, value in data.items(): + setattr(self, name, value) diff --git a/cloudferrylib/utils/local_db.py b/cloudferrylib/utils/local_db.py new file mode 100644 index 00000000..e27a445e --- /dev/null +++ b/cloudferrylib/utils/local_db.py @@ -0,0 +1,187 @@ +# Copyright (c) 2014 Mirantis Inc. +# +# Licensed under the Apache License, Version 2.0 (the License); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and# +# limitations under the License. 
+import contextlib +import json +import logging +import os +import random +import sqlite3 +import threading +import time + +LOG = logging.getLogger(__name__) +SQLITE3_DATABASE_FILE = os.environ.get('CF_LOCAL_DB', 'migration_data.db') +_execute_once_statements = [] +_executed_statements = set() +_execute_once_mutex = threading.Lock() + + +class Transaction(object): + _tls = threading.local() + _tls.top_level = None + _tls.depth = 0 + + def __init__(self): + self._conn = None + self._cursor = None + self._name = self._generate_name() + self._depth = None + + def __enter__(self): + self._depth = self._tls.depth + self._tls.depth += 1 + self._initialize() + LOG.debug('Transaction started [depth=%d,name=%s]', + self._depth, self._name) + self._do_begin() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + exc_info = (exc_type, exc_val, exc_tb) + try: + if exc_info == (None, None, None): + self._do_commit() + else: + self._do_rollback() + LOG.debug('Transaction rollback because of exception', + exc_info=(exc_type, exc_val, exc_tb)) + finally: + self._cursor.close() + if self._depth == 0: + self._conn.close() + self._tls.depth -= 1 + LOG.debug('Transaction completed [depth=%d,name=%s]', + self._depth, self._name) + + def _initialize(self): + # pylint: disable=protected-access + if self._depth == 0: + self._tls.top_level = self + self._conn = sqlite3.connect( + SQLITE3_DATABASE_FILE, + detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES) + self._conn.row_factory = sqlite3.Row + self._conn.isolation_level = None + execute_saved_statements(self._conn) + else: + self._conn = self._tls.top_level._conn + self._cursor = self._conn.cursor() + + def _do_begin(self): + if self._depth == 0: + self.execute('BEGIN EXCLUSIVE') + else: + self.execute('SAVEPOINT tx_{}'.format(self._name)) + + def _do_commit(self): + if self._depth == 0: + self.execute('COMMIT') + else: + self.execute('RELEASE tx_{}'.format(self._name)) + + def _do_rollback(self): + if self._depth 
== 0: + self.execute('ROLLBACK') + else: + self.execute('ROLLBACK TO SAVEPOINT tx_{}'.format(self._name)) + + @staticmethod + def _generate_name(): + rnd = random.Random(time.time()) + return '{:032x}'.format(rnd.getrandbits(8 * 16)) + + def execute(self, sql, **kwargs): + """ + Execute SQL statement without returning result. + """ + LOG.debug('SQL execute [%s]: %s %r', self._name, sql, kwargs) + self._cursor.execute(sql, kwargs) + + def query(self, sql, **kwargs): + """ + Execute SQL sql query and return all rows. Any values in SQL query + of :name format will be replaced by values passed as kwargs. + """ + LOG.debug('SQL query [%s]: %s %r', self._name, sql, kwargs) + self._cursor.execute(sql, kwargs) + return self._cursor.fetchall() + + def query_one(self, sql, **kwargs): + """ + Execute SQL sql query and return one rows. Any values in SQL query + of :name format will be replaced by values passed as kwargs. + It is error if query returns more than one row. + """ + LOG.debug('SQL query one [%s]: %s %r', self._name, sql, kwargs) + self._cursor.execute(sql, kwargs) + assert self._cursor.rowcount <= 1 + return self._cursor.fetchone() + + +class Json(object): + """ + Objects of this class become JSON when converted to string or unicode + """ + + def __init__(self, data): + self.data = data + + def __repr__(self): + try: + return json.dumps(self.data, indent=2, sort_keys=True) + except TypeError: + return '' + + @classmethod + def adapt(cls, obj): + assert isinstance(obj, cls) + return json.dumps(obj.data, sort_keys=True) + + @classmethod + def convert(cls, value): + return Json(json.loads(value)) + + +def execute_once(sql, **kwargs): + """ + This function register SQL statement that will be executed only once during + program lifetime. The idea behind this function is to execute table + creation SQL only once. 
+ :param sql: SQL statement string + :param kwargs: SQL statement arguments + """ + with _execute_once_mutex: + _execute_once_statements.append((sql.strip(), kwargs)) + + +def execute_saved_statements(conn, force=False): + """ + Execute statements registered with execute_once. + :param conn: SQLite3 connection + :param force: if force is set to True, then all previously executed SQL + statements will be ignored and will executed again. + This parameter is introduced for testing. + :return: + """ + with _execute_once_mutex, contextlib.closing(conn.cursor()) as cursor: + for sql, kwargs in _execute_once_statements: + key = (sql, tuple(sorted(kwargs.items()))) + if key in _executed_statements or force: + LOG.debug('SQL execute once: %s %r', sql, kwargs) + cursor.execute(sql, kwargs) + _executed_statements.add(key) + + +sqlite3.register_adapter(Json, Json.adapt) +sqlite3.register_converter('json', Json.convert) diff --git a/cloudferrylib/utils/query.py b/cloudferrylib/utils/query.py new file mode 100644 index 00000000..ebc97e3b --- /dev/null +++ b/cloudferrylib/utils/query.py @@ -0,0 +1,131 @@ +# Copyright 2016 Mirantis Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import jmespath +import jmespath.exceptions + +from cloudferrylib.os.discovery import model + + +class DictSubQuery(object): + """ + Simplified query that use JMESPath queries as dictionary keys to get value + from tested object and then compare it to list of objects specified in + value. 
+ + Example: + + objects: + images: + - tenant.name: demo + container_format: bare + + This query will select all images with owner named 'demo' and container + format 'bare' + + Adding !in front of query will negate match result. For example: + + objects: + vms: + - !tenant.name: rally_test + + This query will select all VMs in cloud except for rally_test tenant + """ + + def __init__(self, pattern): + assert isinstance(pattern, dict) + self.pattern = [self._compile_query(k, v) for k, v in pattern.items()] + + def search(self, values): + """ + Return subset of values that match query parameters. + :param values: list of objects that have get method + :return: list of objects that matched query + """ + return [v for v in values if self._matches(v)] + + def _matches(self, value): + return all(match(value) for match in self.pattern) + + @staticmethod + def _compile_query(key, expected): + assert isinstance(key, basestring) + + negative = False + if key.startswith('!'): + negative = True + key = key[1:] + try: + query = jmespath.compile(key) + + def match(value): + return negative ^ (query.search(value) in expected) + return match + except jmespath.exceptions.ParseError as ex: + raise AssertionError( + 'Failed to compile "{0}": {1}'.format(key, str(ex))) + + +class Query(object): + """ + Parsed and compiled query using which it is possible to filter instances of + model.Model class stored in database. + """ + + def __init__(self, query): + """ + Accept dict as specified in configuration, compile all the JMESPath + queries, and store it as internal immutable state. 
+ :param query: query dictionary + """ + + assert isinstance(query, dict) + + self.queries = {} + for type_name, subqueries in query.items(): + cls = model.get_model(type_name) + for subquery in subqueries: + if isinstance(subquery, basestring): + subquery = jmespath.compile(subquery) + else: + subquery = DictSubQuery(subquery) + cls_queries = self.queries.setdefault(cls, []) + cls_queries.append(subquery) + + def search(self, session, cloud=None, cls=None): + """ + Search through list of objects from database of class that specified + in cls argument (if cls is none, then all classes are considered) that + are collected from cloud specified in cloud argument (if cloud is none, + then all clouds are considered) for objects matching this query. + + :param session: active model.Session instance + :param cloud: cloud name + :param cls: class object + :return: list of objects that match query + """ + result = set() + if cls is None: + for cls, queries in self.queries.items(): + objects = session.list(cls, cloud) + for query in queries: + result.update(query.search(objects)) + return result + else: + queries = self.queries.get(cls) + if queries is None: + return [] + objects = session.list(cls, cloud) + for query in queries: + result.update(query.search(objects)) + return result diff --git a/cloudferrylib/os/context.py b/cloudferrylib/utils/remote.py similarity index 50% rename from cloudferrylib/os/context.py rename to cloudferrylib/utils/remote.py index 548c1786..d4315fae 100644 --- a/cloudferrylib/os/context.py +++ b/cloudferrylib/utils/remote.py @@ -18,52 +18,29 @@ from fabric import state as fab_state from fabric import network as fab_network -from cloudferrylib.os import clients -from cloudferrylib.utils import utils - LOG = logging.getLogger(__name__) -MODEL_LIST = [ - 'cloudferrylib.os.discovery.keystone.Tenant', - 'cloudferrylib.os.discovery.glance.Image', - 'cloudferrylib.os.discovery.cinder.Volume', - 'cloudferrylib.os.discovery.nova.Server', -] -class 
SshSettings(object): - def __init__(self, username, sudo_password=None, gateway=None, - connection_attempts=1, cipher=None, key_file=None): +class RemoteExecutor(object): + """ + Remote executor with minimal number of dependencies. + """ + + def __init__(self, hostname, username, sudo_password=None, gateway=None, + connection_attempts=1, cipher=None, key_file=None, + ignore_errors=False): self.username = username self.sudo_password = sudo_password self.gateway = gateway self.connection_attempts = connection_attempts self.cipher = cipher self.key_file = key_file - - -class Context(object): - def __init__(self, clouds=None): - self.clouds = {} - for name, cloud in (clouds or {}).items(): - credential = clients.Credential(**cloud['credential']) - scope = clients.Scope(**cloud['scope']) - ssh_settings = SshSettings(**cloud['ssh']) - self.clouds[name] = OpenstackCloud(name, credential, scope, - ssh_settings) - - def get_cloud(self, name): - return self.clouds[name] - - -class RemoteExecutor(object): - def __init__(self, ssh_settings, hostname, ignore_errors): - self.ssh_settings = ssh_settings self.hostname = hostname self.ignore_errors = ignore_errors def sudo(self, cmd, **kwargs): formatted_cmd = cmd.format(**kwargs) - if self.ssh_settings.username != 'root': + if self.username != 'root': return self._run(fab_api.sudo, formatted_cmd) else: return self._run(fab_api.run, formatted_cmd) @@ -72,6 +49,7 @@ def run(self, cmd, **kwargs): return self._run(fab_api.run, cmd.format(**kwargs)) def _run(self, run_function, command): + # TODO: rewrite using plain paramiko for multithreading support LOG.debug('[%s] running command "%s"', self.hostname, command) abort_exception = None if self.ignore_errors: @@ -80,13 +58,13 @@ def _run(self, run_function, command): fab_api.hide('warnings', 'running', 'stdout', 'stderr'), warn_only=self.ignore_errors, host_string=self.hostname, - user=self.ssh_settings.username, - password=self.ssh_settings.sudo_password, + user=self.username, + 
password=self.sudo_password, abort_exception=abort_exception, reject_unkown_hosts=False, combine_stderr=False, - gateway=self.ssh_settings.gateway, - connection_attempts=self.ssh_settings.connection_attempts): + gateway=self.gateway, + connection_attempts=self.connection_attempts): return run_function(command) def scp(self, src_path, host, dst_path, username=None, flags=None): @@ -97,8 +75,8 @@ def scp(self, src_path, host, dst_path, username=None, flags=None): command += ' ' + flags # Add cipher option - if self.ssh_settings.cipher is not None: - command += ' -c ' + self.ssh_settings.cipher + if self.cipher is not None: + command += ' -c ' + self.cipher # Put source path command += ' \'{0}\''.format(src_path) @@ -126,48 +104,11 @@ def tmpdir(self, prefix='cloudferry'): if path is not None: self.sudo('rm -rf {path}', path=path) - -class OpenstackCloud(object): - def __init__(self, name, credential, scope, ssh_settings, discover=None): - if discover is None: - discover = MODEL_LIST - self.name = name - self.credential = credential - self.scope = scope - self.ssh_settings = ssh_settings - self.discover = discover - - def image_client(self, scope=None): - return clients.image_client(self.credential, scope or self.scope) - - def identity_client(self, scope=None): - return clients.identity_client(self.credential, scope or self.scope) - - def volume_client(self, scope=None): - return clients.volume_client(self.credential, scope or self.scope) - - def compute_client(self, scope=None): - return clients.compute_client(self.credential, scope or self.scope) - - @contextlib.contextmanager - def remote_executor(self, hostname, key_file=None, ignore_errors=False): - key_files = [] - if self.ssh_settings.key_file is not None: - key_files.append(self.ssh_settings.key_file) - if key_file is not None: - key_files.append(key_file) - if key_files: - utils.ensure_ssh_key_added(key_files) - try: - yield RemoteExecutor(self.ssh_settings, hostname, ignore_errors) - finally: - 
_close_connection(hostname) - - -def _close_connection(hostname): - for key, conn in fab_state.connections.items(): - _, conn_hostname = fab_network.normalize(key, True) - if conn_hostname == hostname: - conn.close() - del fab_state.connections[key] - break + @staticmethod + def close_connection(hostname): + for key, conn in fab_state.connections.items(): + _, conn_hostname = fab_network.normalize(key, True) + if conn_hostname == hostname: + conn.close() + del fab_state.connections[key] + break diff --git a/cloudferrylib/utils/utils.py b/cloudferrylib/utils/utils.py index 128ecc09..ec1f17de 100644 --- a/cloudferrylib/utils/utils.py +++ b/cloudferrylib/utils/utils.py @@ -52,6 +52,7 @@ DISK = "disk" DISK_EPHEM = "disk.local" +DISK_CONFIG = "disk.config" LEN_UUID_INSTANCE = 36 HOST_SRC = 'host_src' diff --git a/discover.yaml b/discover.yaml index ebb3a57d..76031306 100644 --- a/discover.yaml +++ b/discover.yaml @@ -1,27 +1,53 @@ -context: - clouds: - grizzly_ewr2: - credential: - auth_url: https://keystone.example.com/v2.0/ - username: foobar - password: foobar - scope: - project_id: 00000000000000000000000000000000 - ssh: - username: foobar - sudo_password: foobar - connection_attempts: 3 +clouds: + grizzly: + credential: + auth_url: https://keystone.example.com/v2.0/ + username: admin + password: admin + region_name: grizzly + scope: + project_id: 00000000000000000000000000000000 + ssh: + username: foobar + sudo_password: foobar + connection_attempts: 3 + + liberty: + credential: + auth_url: https://keystone.example.com/v2.0/ + username: admin + password: admin + region_name: liberty + scope: + project_id: 00000000000000000000000000000000 + ssh: + username: foobar + sudo_password: foobar + connection_attempts: 3 # All objects are discovered by default, but list of discovered objects can # be specified using discover parameter to cloud -# discover: -# - cloudferrylib.os.discovery.keystone.Tenant -# - cloudferrylib.os.discovery.glance.Image -# - 
cloudferrylib.os.discovery.cinder.Volume -# - cloudferrylib.os.discovery.nova.Server +# discover: +# - cloudferrylib.os.discovery.keystone.Tenant +# - cloudferrylib.os.discovery.glance.Image +# - cloudferrylib.os.discovery.cinder.Volume +# - cloudferrylib.os.discovery.nova.Server # Import legacy configuration -# grizzly_src: -# legacy: configuration.grizzly-juno.ini:src -# juno_dst: -# legacy: configuration.ini:dst +# grizzly_src: +# legacy: configuration.grizzly-juno.ini:src +# juno_dst: +# legacy: configuration.ini:dst + +migrations: + grizzly_to_liberty: + source: grizzly + destination: liberty + objects: + vms: + - tenant.name: demo # Include VMs owned by tenant named "demo" into migration + images: + - tenant.name: demo # Include images owned by tenant named "demo" into migration + - is_public: True # Also include any public images (no matter owned by which tenant) into migration + volumes: + - tenant.name: demo # Include volumes owned by tenant named "demo" into migration diff --git a/fabfile.py b/fabfile.py index 80581adf..e03c571f 100644 --- a/fabfile.py +++ b/fabfile.py @@ -18,12 +18,12 @@ from fabric.api import task, env import yaml -from oslo_utils import importutils import oslo_config.cfg import oslo_config.types import cfglib -from cloudferrylib.os import context +from cloudferrylib import config +from cloudferrylib import stage from cloudferrylib.os.estimation import procedures from cloudferrylib.scheduler.namespace import Namespace from cloudferrylib.scheduler.scheduler import Scheduler @@ -200,27 +200,31 @@ def discover(config_path, debug=False): """ :config_name - name of config yaml-file, example 'config.yaml' """ - config = load_yaml_config(config_path, debug) - ctx = context.Context(**config['context']) - for cloud_name, cloud in ctx.clouds.items(): - for fq_class_name in cloud.discover: - cls = importutils.import_class(fq_class_name) - LOG.info('Starting discover %s objects in %s cloud', - cls.__name__, cloud_name) - cls.discover(cloud) - 
LOG.info('Done discovering %s objects in %s cloud', - cls.__name__, cloud_name) + cfg = config.load(load_yaml_config(config_path, debug)) + stage.execute_stage('cloudferrylib.os.discovery.stages.DiscoverStage', cfg, + force=True) @task -def estimate_migration(source, tenant=None): - procedures.estimate_copy(source, tenant) - procedures.show_largest_servers(10, source, tenant) - procedures.show_largest_unused_resources(10, source, tenant) +def estimate_migration(config_path, migration, debug=False): + cfg = config.load(load_yaml_config(config_path, debug)) + if migration not in cfg.migrations: + print 'No such migration:', migration + print '\nPlease choose one of this:' + for name in sorted(cfg.migrations.keys()): + print ' -', name + return -1 + + stage.execute_stage('cloudferrylib.os.discovery.stages.DiscoverStage', cfg) + procedures.estimate_copy(cfg, migration) + procedures.show_largest_servers(cfg, 10, migration) @task -def show_unused_resources(cloud, count=100, tenant=None): +def show_unused_resources(config_path, cloud, count=100, tenant=None, + debug=False): + cfg = config.load(load_yaml_config(config_path, debug)) + stage.execute_stage('cloudferrylib.os.discovery.stages.DiscoverStage', cfg) procedures.show_largest_unused_resources(int(count), cloud, tenant) @@ -264,16 +268,16 @@ def import_legacy(cloud, cfg): prev_legacy_config_path = None with open(yaml_path, 'r') as config_file: - config = yaml.load(config_file) - clouds = config.setdefault('context', {}).setdefault('clouds', {}) - for name, value in clouds.items(): + cfg = yaml.load(config_file) + clouds = cfg.setdefault('clouds', {}) + for value in clouds.values(): if 'legacy' not in value: continue legacy_config_path, section = value.pop('legacy').split(':') if prev_legacy_config_path != legacy_config_path: init(legacy_config_path, debug) import_legacy(value, getattr(cfglib.CONF, section)) - return config + return cfg if __name__ == '__main__': diff --git a/requirements.txt b/requirements.txt index 
8763c472..9bbdb968 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,8 @@ python-neutronclient==2.3.4 python-novaclient==2.20.0 python-swiftclient==2.3.1 pyyaml -pywbem +pywbem==0.7.0 redis sqlalchemy marshmallow==2.4.2 +jmespath==0.9.0 diff --git a/tests/cloudferrylib/os/discovery/test_model.py b/tests/cloudferrylib/os/discovery/test_model.py index 3b3017d1..272c9930 100644 --- a/tests/cloudferrylib/os/discovery/test_model.py +++ b/tests/cloudferrylib/os/discovery/test_model.py @@ -12,18 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import mock -import sqlite3 import uuid from cloudferrylib.os.discovery import model - -from tests import test - - -class UnclosableConnection(sqlite3.Connection): - def close(self, *args, **kwargs): - if kwargs.get('i_mean_it'): - super(UnclosableConnection, self).close() +from tests.cloudferrylib.utils import test_local_db class ExampleReferenced(model.Model): @@ -141,17 +133,9 @@ def generate_clean_data(cls): } -class ModelTest(test.TestCase): +class ModelTestCase(test_local_db.DatabaseMockingTestCase): def setUp(self): - super(ModelTest, self).setUp() - - connection = sqlite3.connect(':memory:', factory=UnclosableConnection) - self.addCleanup(connection.close, i_mean_it=True) - - sqlite3_connect_patcher = mock.patch('sqlite3.connect') - self.addCleanup(sqlite3_connect_patcher.stop) - connect_mock = sqlite3_connect_patcher.start() - connect_mock.return_value = connection + super(ModelTestCase, self).setUp() self.cloud = mock.MagicMock() self.cloud.name = 'test_cloud' @@ -159,9 +143,10 @@ def setUp(self): self.cloud2 = mock.MagicMock() self.cloud2.name = 'test_cloud2' - def _validate_example_obj(self, object_id, obj, validate_refs=True): + def _validate_example_obj(self, object_id, obj, validate_refs=True, + bar_value='some non-random string'): self.assertEqual(object_id, obj.object_id) - self.assertEqual('some non-random string', obj.bar) + 
self.assertEqual(bar_value, obj.bar) self.assertEqual('other non-random string', obj.baz.foo) if validate_refs: self.assertEqual(1337, obj.ref.qux) @@ -233,27 +218,27 @@ def test_store_retrieve(self): data = Example.generate_cloud_data() orig_obj = Example.load_from_cloud(self.cloud, data) object_id = orig_obj.object_id - with model.Transaction() as tx: - tx.store(orig_obj) + with model.Session() as session: + session.store(orig_obj) # Validate retrieve working before commit self._validate_example_obj( - object_id, tx.retrieve(Example, object_id)) - with model.Transaction() as tx: + object_id, session.retrieve(Example, object_id)) + with model.Session() as session: # Validate retrieve working after commit self._validate_example_obj( - object_id, tx.retrieve(Example, object_id)) + object_id, session.retrieve(Example, object_id)) def test_store_list(self): data = Example.generate_cloud_data() orig_obj = Example.load_from_cloud(self.cloud, data) object_id = orig_obj.object_id - with model.Transaction() as tx: - tx.store(orig_obj) + with model.Session() as session: + session.store(orig_obj) # Validate retrieve working before commit - self._validate_example_obj(object_id, tx.list(Example)[0]) - with model.Transaction() as tx: + self._validate_example_obj(object_id, session.list(Example)[0]) + with model.Session() as session: # Validate retrieve working after commit - self._validate_example_obj(object_id, tx.list(Example)[0]) + self._validate_example_obj(object_id, session.list(Example)[0]) def test_store_list_cloud(self): data = Example.generate_cloud_data() @@ -261,35 +246,35 @@ def test_store_list_cloud(self): object1_id = orig_obj1.object_id orig_obj2 = Example.load_from_cloud(self.cloud2, data) object2_id = orig_obj2.object_id - with model.Transaction() as tx: - tx.store(orig_obj1) - tx.store(orig_obj2) + with model.Session() as session: + session.store(orig_obj1) + session.store(orig_obj2) # Validate retrieve working before commit 
self._validate_example_obj(object1_id, - tx.list(Example, 'test_cloud')[0]) + session.list(Example, 'test_cloud')[0]) self._validate_example_obj(object2_id, - tx.list(Example, 'test_cloud2')[0]) + session.list(Example, 'test_cloud2')[0]) # Validate retrieve working after commit - with model.Transaction() as tx: + with model.Session() as session: self._validate_example_obj(object1_id, - tx.list(Example, 'test_cloud')[0]) - with model.Transaction() as tx: + session.list(Example, 'test_cloud')[0]) + with model.Session() as session: self._validate_example_obj(object2_id, - tx.list(Example, 'test_cloud2')[0]) + session.list(Example, 'test_cloud2')[0]) def test_load_store(self): data = Example.generate_cloud_data() orig_obj = Example.load_from_cloud(self.cloud, data) object_id = orig_obj.object_id - with model.Transaction() as tx: - tx.store(orig_obj) - with model.Transaction() as tx: - obj = tx.retrieve(Example, object_id) + with model.Session() as session: + session.store(orig_obj) + with model.Session() as session: + obj = session.retrieve(Example, object_id) self._validate_example_obj(object_id, obj) obj.baz.foo = 'changed' obj.bar = 'changed too' - with model.Transaction() as tx: - loaded_obj = tx.retrieve(Example, object_id) + with model.Session() as session: + loaded_obj = session.retrieve(Example, object_id) self.assertEqual('changed', loaded_obj.baz.foo) self.assertEqual('changed too', loaded_obj.bar) @@ -310,10 +295,12 @@ class Schema(model.Schema): self.assertEqual('foo', many.many[0].foo) self.assertEqual('bar', many.many[1].foo) self.assertEqual('baz', many.many[2].foo) - with model.Transaction() as tx: - tx.store(many) - with model.Transaction() as tx: - obj = tx.retrieve(ExampleMany, model.ObjectId('foo', 'test_cloud')) + with model.Session() as session: + session.store(many) + + with model.Session() as session: + obj = session.retrieve( + ExampleMany, model.ObjectId('foo', 'test_cloud')) self.assertEqual('foo', obj.many[0].foo) self.assertEqual('bar', 
obj.many[1].foo) self.assertEqual('baz', obj.many[2].foo) @@ -330,3 +317,47 @@ class Schema(model.Schema): 'ref': str('foo-bar-baz'), }) self.assertIs(Example, obj.ref.get_class()) + + def test_nested_sessions(self): + data = Example.generate_cloud_data() + orig_obj1 = Example.load_from_cloud(self.cloud, data) + object1_id = orig_obj1.object_id + orig_obj2 = Example.load_from_cloud(self.cloud2, data) + object2_id = orig_obj2.object_id + + with model.Session() as s1: + s1.store(orig_obj1) + with model.Session() as s2: + s2.store(orig_obj2) + self._validate_example_obj( + object1_id, s2.retrieve(Example, object1_id)) + self._validate_example_obj( + object2_id, s2.retrieve(Example, object2_id)) + with model.Session() as s: + self._validate_example_obj( + object1_id, s.retrieve(Example, object1_id)) + self._validate_example_obj( + object2_id, s2.retrieve(Example, object2_id)) + + def test_nested_sessions_save_updates_after_nested(self): + data = Example.generate_cloud_data() + orig_obj1 = Example.load_from_cloud(self.cloud, data) + object1_id = orig_obj1.object_id + orig_obj2 = Example.load_from_cloud(self.cloud2, data) + object2_id = orig_obj2.object_id + + with model.Session() as s1: + s1.store(orig_obj1) + with model.Session() as s2: + s2.store(orig_obj2) + self._validate_example_obj( + object1_id, s2.retrieve(Example, object1_id)) + self._validate_example_obj( + object2_id, s2.retrieve(Example, object2_id)) + orig_obj1.bar = 'some other non-random string' + with model.Session() as s: + self._validate_example_obj( + object1_id, s.retrieve(Example, object1_id), + bar_value='some other non-random string') + self._validate_example_obj( + object2_id, s2.retrieve(Example, object2_id)) diff --git a/tests/cloudferrylib/os/image/test_glance_image.py b/tests/cloudferrylib/os/image/test_glance_image.py index ffa0ff8b..31929589 100644 --- a/tests/cloudferrylib/os/image/test_glance_image.py +++ b/tests/cloudferrylib/os/image/test_glance_image.py @@ -110,6 +110,7 @@ def 
setUp(self): 'disk_format': 'qcow2', 'id': 'fake_image_id_1', 'is_public': True, + 'location': None, 'owner': 'fake_tenant_id', 'owner_name': 'fake_tenant_name', 'name': 'fake_image_name_1', @@ -118,6 +119,10 @@ def setUp(self): 'resource': self.image_mock, 'members': {}, 'properties': {}, + 'copy_from': None, + 'min_disk': None, + 'min_ram': None, + 'store': None, 'deleted': False}, 'meta': {'img_loc': None}}}, 'tags': {}, @@ -167,8 +172,7 @@ def test_delete_image(self): fake_image_id) def test_get_image_by_id(self): - fake_images = [self.fake_image_1, self.fake_image_2] - self.glance_mock_client().images.list.return_value = fake_images + self.glance_mock_client().images.get.return_value = self.fake_image_1 self.assertEquals(self.fake_image_1, self.glance_image.get_image_by_id('fake_image_id_1')) @@ -192,8 +196,7 @@ def test_get_image(self): self.glance_image.get_image('fake_image_name_2')) def test_get_image_status(self): - fake_images = [self.fake_image_1, self.fake_image_2] - self.glance_mock_client().images.list.return_value = fake_images + self.glance_mock_client().images.get.return_value = self.fake_image_1 self.assertEquals(self.fake_image_1.status, self.glance_image.get_image_status( @@ -207,8 +210,7 @@ def test_get_ref_image(self): self.glance_image.get_ref_image('fake_image_id_1')) def test_get_image_checksum(self): - fake_images = [self.fake_image_1, self.fake_image_2] - self.glance_mock_client().images.list.return_value = fake_images + self.glance_mock_client().images.get.return_value = self.fake_image_1 self.assertEquals(self.fake_image_1.checksum, self.glance_image.get_image_checksum( diff --git a/tests/cloudferrylib/test_stage.py b/tests/cloudferrylib/test_stage.py new file mode 100644 index 00000000..be291ab0 --- /dev/null +++ b/tests/cloudferrylib/test_stage.py @@ -0,0 +1,96 @@ +# Copyright 2016 Mirantis Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from cloudferrylib import stage + +from tests.cloudferrylib.utils import test_local_db + +import mock + +call_checker = None + + +def fqname(cls): + return cls.__module__ + '.' + cls.__name__ + + +class TestStage(stage.Stage): + def __init__(self): + self.invalidated = False + + def signature(self, config): + return config + + def execute(self, config): + self._execute(config, self.invalidated) + + def invalidate(self, old_signature, new_signature, force=False): + self._invalidate(old_signature, new_signature) + self.invalidated = True + + def _execute(self, config, invalidated): + pass + + def _invalidate(self, old_signature, new_signature): + pass + + +class StageOne(TestStage): + pass + + +class StageTwo(TestStage): + dependencies = [ + fqname(StageOne) + ] + + +class StageTestCase(test_local_db.DatabaseMockingTestCase): + def setUp(self): + super(StageTestCase, self).setUp() + self.config1 = {'marker': 1} + self.config2 = {'marker': 2} + + @mock.patch.object(StageOne, '_execute') + def test_dependencies_execute(self, execute): + stage.execute_stage(fqname(StageOne), self.config1) + execute.assert_called_once_with(self.config1, False) + + @mock.patch.object(StageOne, '_execute') + @mock.patch.object(StageTwo, '_execute') + def test_dependencies_execute_once(self, execute_two, execute_one): + stage.execute_stage(fqname(StageOne), self.config1) + stage.execute_stage(fqname(StageTwo), self.config1) + execute_one.assert_called_once_with(self.config1, False) + execute_two.assert_called_once_with(self.config1, False) + + @mock.patch.object(StageOne, 
'_execute') + @mock.patch.object(StageTwo, '_execute') + def test_dependencies_execute_deps(self, execute_two, execute_one): + stage.execute_stage(fqname(StageTwo), self.config1) + execute_one.assert_called_once_with(self.config1, False) + execute_two.assert_called_once_with(self.config1, False) + + @mock.patch.object(StageOne, '_invalidate') + @mock.patch.object(StageOne, '_execute') + @mock.patch.object(StageTwo, '_execute') + def test_invalidate_dependencies_on_configuration_change( + self, execute_two, execute_one, invalidate_one): + stage.execute_stage(fqname(StageOne), self.config1) + stage.execute_stage(fqname(StageTwo), self.config2) + execute_one.assert_has_calls([ + mock.call(self.config1, False), + mock.call(self.config2, True), + ]) + execute_two.assert_called_once_with(self.config2, False) + invalidate_one.assert_called_once_with(self.config1, self.config2) diff --git a/tests/cloudferrylib/utils/test_local_db.py b/tests/cloudferrylib/utils/test_local_db.py new file mode 100644 index 00000000..4e396308 --- /dev/null +++ b/tests/cloudferrylib/utils/test_local_db.py @@ -0,0 +1,135 @@ +# Copyright (c) 2014 Mirantis Inc. +# +# Licensed under the Apache License, Version 2.0 (the License); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and# +# limitations under the License. 
+import sqlite3 + +import mock + +from tests import test +from cloudferrylib.utils import local_db + + +class UnclosableConnection(sqlite3.Connection): + def close(self, *args, **kwargs): + # pylint: disable=unused-argument + if kwargs.get('i_mean_it'): + super(UnclosableConnection, self).close() + + +class DatabaseMockingTestCase(test.TestCase): + def setUp(self): + super(DatabaseMockingTestCase, self).setUp() + + connection = sqlite3.connect( + ':memory:', factory=UnclosableConnection, + detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES) + self.addCleanup(connection.close, i_mean_it=True) + + sqlite3_connect_patcher = mock.patch('sqlite3.connect') + self.addCleanup(sqlite3_connect_patcher.stop) + connect_mock = sqlite3_connect_patcher.start() + connect_mock.return_value = connection + + local_db.execute_saved_statements(connection, force=True) + + +class LocalDbTestCase(DatabaseMockingTestCase): + def setUp(self): + super(LocalDbTestCase, self).setUp() + with local_db.Transaction() as tx: + tx.execute("""CREATE TABLE IF NOT EXISTS tests ( + key TEXT, + value TEXT, + PRIMARY KEY (key) + )""") + + def test_write_read_same_tx(self): + with local_db.Transaction() as tx: + tx.execute('INSERT INTO tests VALUES (:k, :v)', k='foo', v='bar') + row = tx.query_one('SELECT value FROM tests WHERE key=:k', k='foo') + self.assertEqual('bar', row['value']) + + def test_write_read_different_tx(self): + with local_db.Transaction() as tx: + tx.execute('INSERT INTO tests VALUES (:k, :v)', k='foo', v='bar') + with local_db.Transaction() as tx: + row = tx.query_one('SELECT value FROM tests WHERE key=:k', k='foo') + self.assertEqual('bar', row['value']) + + def test_write_write_read(self): + with local_db.Transaction() as tx: + tx.execute('INSERT INTO tests VALUES (:k, :v)', k='foo', v='bar') + with local_db.Transaction() as tx: + tx.execute('UPDATE tests SET value=:v WHERE key=:k', + k='foo', v='baz') + with local_db.Transaction() as tx: + row = tx.query_one('SELECT value 
FROM tests WHERE key=:k', k='foo') + self.assertEqual('baz', row['value']) + + def test_write_write_rollback_read_first_value(self): + with local_db.Transaction() as tx: + tx.execute('INSERT INTO tests VALUES (:k, :v)', k='foo', v='bar') + try: + with local_db.Transaction() as tx: + tx.execute('UPDATE tests SET value=:v WHERE key=:k', + k='foo', v='baz') + raise RuntimeError() + except RuntimeError: + pass + with local_db.Transaction() as tx: + row = tx.query_one('SELECT value FROM tests WHERE key=:k', k='foo') + self.assertEqual('bar', row['value']) + + def test_nested_tx(self): + with local_db.Transaction() as tx1: + tx1.execute('INSERT INTO tests VALUES (:k, :v)', k='foo', v='bar') + with local_db.Transaction() as tx2: + tx2.execute('UPDATE tests SET value=:v WHERE key=:k', + k='foo', v='baz') + with local_db.Transaction() as tx: + row = tx.query_one('SELECT value FROM tests WHERE key=:k', k='foo') + self.assertEqual('baz', row['value']) + + def test_nested_tx_rollback_inner(self): + with local_db.Transaction() as tx1: + tx1.execute('INSERT INTO tests VALUES (:k, :v)', k='foo', v='bar') + try: + with local_db.Transaction() as tx2: + tx2.execute('UPDATE tests SET value=:v WHERE key=:k', + k='foo', v='baz') + raise RuntimeError() + except RuntimeError: + pass + with local_db.Transaction() as tx: + row = tx.query_one('SELECT value FROM tests WHERE key=:k', k='foo') + self.assertEqual('bar', row['value']) + + def test_nested_tx_rollback_outer(self): + # Prepare state + with local_db.Transaction() as tx: + tx.execute('INSERT INTO tests VALUES (:k, :v)', k='foo', v='bar') + + # Run outer rollback from inner tx + try: + with local_db.Transaction() as tx1: + tx1.execute('UPDATE tests SET value=:v WHERE key=:k', + k='foo', v='baz') + with local_db.Transaction() as tx2: + tx2.execute('UPDATE tests SET value=:v WHERE key=:k', + k='foo', v='qux') + raise RuntimeError() + except RuntimeError: + pass + with local_db.Transaction() as tx: + row = tx.query_one('SELECT value FROM 
tests WHERE key=:k', k='foo') + self.assertEqual('bar', row['value']) diff --git a/tests/cloudferrylib/utils/test_query.py b/tests/cloudferrylib/utils/test_query.py new file mode 100644 index 00000000..349aae44 --- /dev/null +++ b/tests/cloudferrylib/utils/test_query.py @@ -0,0 +1,133 @@ +# Copyright 2016 Mirantis Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from cloudferrylib.os.discovery import model +from cloudferrylib.utils import query +from marshmallow import fields +from tests.cloudferrylib.utils import test_local_db +import mock + + +class TestMode(model.Model): + class Schema(model.Schema): + object_id = model.PrimaryKey('id') + field1 = fields.String() + field2 = fields.String() + +CLASS_FQN = TestMode.__module__ + '.' 
+ TestMode.__name__ + + +class StageTestCase(test_local_db.DatabaseMockingTestCase): + def setUp(self): + super(StageTestCase, self).setUp() + + self.cloud = mock.MagicMock() + self.cloud.name = 'test_cloud' + + self.obj1 = TestMode.load_from_cloud(self.cloud, { + 'id': 'id1', + 'field1': 'a', + 'field2': 'a', + }) + self.obj2 = TestMode.load_from_cloud(self.cloud, { + 'id': 'id2', + 'field1': 'a', + 'field2': 'b', + }) + self.obj3 = TestMode.load_from_cloud(self.cloud, { + 'id': 'id3', + 'field1': 'b', + 'field2': 'a', + }) + self.obj4 = TestMode.load_from_cloud(self.cloud, { + 'id': 'id4', + 'field1': 'b', + 'field2': 'b', + }) + + with model.Session() as s: + s.store(self.obj1) + s.store(self.obj2) + s.store(self.obj3) + s.store(self.obj4) + + def test_simple_query1(self): + q = query.Query({ + CLASS_FQN: [ + { + 'field1': ['a'], + } + ] + }) + with model.Session() as session: + objs = sorted(q.search(session), key=lambda x: x.object_id.id) + self.assertEqual(2, len(objs)) + self.assertEqual(objs[0].object_id.id, 'id1') + self.assertEqual(objs[1].object_id.id, 'id2') + + def test_simple_query2(self): + q = query.Query({ + CLASS_FQN: [ + { + 'field1': ['b'], + 'field2': ['b'], + } + ] + }) + with model.Session() as session: + objs = sorted(q.search(session), key=lambda x: x.object_id.id) + self.assertEqual(1, len(objs)) + self.assertEqual(objs[0].object_id.id, 'id4') + + def test_simple_query3(self): + q = query.Query({ + CLASS_FQN: [ + { + 'field1': ['a'], + }, + { + 'field2': ['b'], + }, + ] + }) + with model.Session() as session: + objs = sorted(q.search(session), key=lambda x: x.object_id.id) + self.assertEqual(3, len(objs)) + self.assertEqual(objs[0].object_id.id, 'id1') + self.assertEqual(objs[1].object_id.id, 'id2') + self.assertEqual(objs[2].object_id.id, 'id4') + + def test_simple_query_negative(self): + q = query.Query({ + CLASS_FQN: [ + { + '!field1': ['b'], + 'field2': ['b'], + } + ] + }) + with model.Session() as session: + objs = 
sorted(q.search(session), key=lambda x: x.object_id.id) + self.assertEqual(1, len(objs)) + self.assertEqual(objs[0].object_id.id, 'id2') + + def test_jmespath_query(self): + q = query.Query({ + CLASS_FQN: [ + '[? field1 == `b` && field2 == `a` ]' + ] + }) + with model.Session() as session: + objs = sorted(q.search(session), key=lambda x: x.object_id.id) + self.assertEqual(1, len(objs)) + self.assertEqual(objs[0].object_id.id, 'id3')