diff --git a/.gitignore b/.gitignore index a742aa9..a337052 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,6 @@ target/ db.* TODO* +**.sqlite +**.sqlite-journal +**.env \ No newline at end of file diff --git a/README.md b/README.md index 6327a3b..9037e68 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,15 @@ Following a basic tutorial to demonstrate how to use the ORM. [] ``` +## Linter + +Check code lint: + +```sh +pip install pylint +pylint orm.py +``` + ## Contributing See [CONTRIBUTING](/CONTRIBUTING.md). diff --git a/orm.py b/orm.py index 3800c3f..69ae8ca 100644 --- a/orm.py +++ b/orm.py @@ -16,17 +16,21 @@ def attrs(obj): ''' Return attribute values dictionary for an object ''' - return dict(i for i in vars(obj).items() if i[0][0] is not '_') + return dict(i for i in vars(obj).items() if i[0][0] != '_') -def copy_attrs(obj, remove=[]): +def copy_attrs(obj, remove=None): ''' Copy attribute values for an object ''' + if remove is None: + remove = [] return dict(i for i in attrs(obj).items() if i[0] not in remove) def render_column_definitions(model): ''' Create SQLite column definitions for an entity model ''' - return ['%s %s' % (k, DATA_TYPES[v]) for k, v in attrs(model).items()] + model_attrs = attrs(model).items() + model_attrs = {k: v for k, v in model_attrs if k != 'db'} + return ['%s %s' % (k, DATA_TYPES[v]) for k, v in model_attrs.items()] def render_create_table_stmt(model): @@ -37,17 +41,19 @@ def render_create_table_stmt(model): return sql.format(**params) -class Database(object): +class Database(object): # pylint: disable=R0205 ''' Proxy class to access sqlite3.connect method ''' def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs + self._connection = None self.connected = False - self.Model = type('Model%s' % str(self), (Model,), {'db': self}) + self.Model = type('Model%s' % str(self), (Model,), {'db': self}) # pylint: disable=C0103 @property def connection(self): + ''' Create SQL connection ''' if self.connected: return self._connection self._connection = sqlite3.connect(*self.args, **self.kwargs) @@ -56,22 +62,26 @@ def connection(self): return self._connection def close(self): + ''' Close SQL connection ''' if self.connected: self.connection.close() self.connected = False def commit(self): + ''' Commit SQL changes ''' self.connection.commit() def execute(self, sql, *args): + ''' Execute SQL ''' return self.connection.execute(sql, args) def executescript(self, script): + ''' Execute SQL script ''' self.connection.cursor().executescript(script) self.commit() -class Manager(object): +class Manager(object): # pylint: disable=R0205 ''' Data mapper interface (generic repository) for models ''' def __init__(self, db, model, type_check=True): @@ -83,19 +93,23 @@ def __init__(self, db, model, type_check=True): self.db.executescript(render_create_table_stmt(self.model)) def all(self): + ''' Get all model objects from database ''' result = self.db.execute('SELECT * FROM %s' % self.table_name) return (self.create(**row) for row in result.fetchall()) def create(self, **kwargs): + ''' Create a model object ''' obj = object.__new__(self.model) obj.__dict__ = kwargs return obj def delete(self, obj): + ''' Delete a model object from database ''' sql = 'DELETE from %s WHERE id = ?' self.db.execute(sql % self.table_name, obj.id) def get(self, id): + ''' Get a model object from database by its id ''' sql = 'SELECT * FROM %s WHERE id = ?' % self.table_name result = self.db.execute(sql, id) row = result.fetchone() @@ -105,11 +119,13 @@ def get(self, id): return self.create(**row) def has(self, id): + ''' Check if a model object exists in database by its id ''' sql = 'SELECT id FROM %s WHERE id = ?' % self.table_name result = self.db.execute(sql, id) return True if result.fetchall() else False def save(self, obj): + ''' Save a model object ''' if 'id' in obj.__dict__ and self.has(obj.id): msg = 'Object%s id already registred: %s' % (self.model, obj.id) raise ValueError(msg) @@ -117,12 +133,14 @@ def save(self, obj): self.type_check and self._isvalid(clone) column_names = '%s' % ', '.join(clone.keys()) column_references = '%s' % ', '.join('?' for i in range(len(clone))) - sql = 'INSERT INTO %s (%s) VALUES (%s)' % (self.table_name, column_names, column_references) # noqa + sql = 'INSERT INTO %s (%s) VALUES (%s)' + sql = sql % (self.table_name, column_names, column_references) result = self.db.execute(sql, *clone.values()) obj.id = result.lastrowid return obj def update(self, obj): + ''' Update a model object ''' clone = copy_attrs(obj, remove=['id']) self.type_check and self._isvalid(clone) where_expressions = '= ?, '.join(clone.keys()) + '= ?' @@ -146,22 +164,26 @@ def _isvalid(self, attr_values): raise TypeError(msg % (attr, attr_types[attr], value_type)) -class Model(object): +class Model(object): # pylint: disable=R0205 ''' Abstract entity model with an active record interface ''' db = None def delete(self, type_check=True): + ''' Delete this model object ''' return self.__class__.manager(type_check=type_check).delete(self) def save(self, type_check=True): + ''' Save this model object ''' return self.__class__.manager(type_check=type_check).save(self) def update(self, type_check=True): + ''' Update this model object ''' return self.__class__.manager(type_check=type_check).update(self) @property def public(self): + ''' Return the public model attributes ''' return attrs(self) def __repr__(self): @@ -169,4 +191,5 @@ def __repr__(self): @classmethod def manager(cls, db=None, type_check=True): + ''' Create a database managet ''' return Manager(db if db else cls.db, cls, type_check) diff --git a/tests.py b/tests.py index 2c4b157..f8bb845 100644 --- a/tests.py +++ b/tests.py @@ -6,7 +6,6 @@ class Post(db.Model): - random = float text = str