Skip to content

Commit

Permalink
Merge pull request #14 from fernandojunior/iss13
Browse files Browse the repository at this point in the history
issue #13
  • Loading branch information
fernandojunior authored Oct 10, 2018
2 parents 3a7ee71 + 40b4cea commit 7c79694
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 9 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,6 @@ target/
db.*

TODO*
**.sqlite
**.sqlite-journal
**.env
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ Following a basic tutorial to demonstrate how to use the ORM.
[]
```

## Linter

Check code lint:

```sh
pip install pylint
pylint orm.py
```

## Contributing

See [CONTRIBUTING](/CONTRIBUTING.md).
Expand Down
39 changes: 31 additions & 8 deletions orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,21 @@

def attrs(obj):
''' Return attribute values dictionary for an object '''
return dict(i for i in vars(obj).items() if i[0][0] is not '_')
return dict(i for i in vars(obj).items() if i[0][0] != '_')


def copy_attrs(obj, remove=[]):
def copy_attrs(obj, remove=None):
''' Copy attribute values for an object '''
if remove is None:
remove = []
return dict(i for i in attrs(obj).items() if i[0] not in remove)


def render_column_definitions(model):
''' Create SQLite column definitions for an entity model '''
return ['%s %s' % (k, DATA_TYPES[v]) for k, v in attrs(model).items()]
model_attrs = attrs(model).items()
model_attrs = {k: v for k, v in model_attrs if k != 'db'}
return ['%s %s' % (k, DATA_TYPES[v]) for k, v in model_attrs.items()]


def render_create_table_stmt(model):
Expand All @@ -37,17 +41,19 @@ def render_create_table_stmt(model):
return sql.format(**params)


class Database(object):
class Database(object): # pylint: disable=R0205
''' Proxy class to access sqlite3.connect method '''

def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self._connection = None
self.connected = False
self.Model = type('Model%s' % str(self), (Model,), {'db': self})
self.Model = type('Model%s' % str(self), (Model,), {'db': self}) # pylint: disable=C0103

@property
def connection(self):
''' Create SQL connection '''
if self.connected:
return self._connection
self._connection = sqlite3.connect(*self.args, **self.kwargs)
Expand All @@ -56,22 +62,26 @@ def connection(self):
return self._connection

def close(self):
''' Close SQL connection '''
if self.connected:
self.connection.close()
self.connected = False

def commit(self):
''' Commit SQL changes '''
self.connection.commit()

def execute(self, sql, *args):
''' Execute SQL '''
return self.connection.execute(sql, args)

def executescript(self, script):
''' Execute SQL script '''
self.connection.cursor().executescript(script)
self.commit()


class Manager(object):
class Manager(object): # pylint: disable=R0205
''' Data mapper interface (generic repository) for models '''

def __init__(self, db, model, type_check=True):
Expand All @@ -83,19 +93,23 @@ def __init__(self, db, model, type_check=True):
self.db.executescript(render_create_table_stmt(self.model))

def all(self):
''' Get all model objects from database '''
result = self.db.execute('SELECT * FROM %s' % self.table_name)
return (self.create(**row) for row in result.fetchall())

def create(self, **kwargs):
''' Create a model object '''
obj = object.__new__(self.model)
obj.__dict__ = kwargs
return obj

def delete(self, obj):
''' Delete a model object from database '''
sql = 'DELETE from %s WHERE id = ?'
self.db.execute(sql % self.table_name, obj.id)

def get(self, id):
''' Get a model object from database by its id '''
sql = 'SELECT * FROM %s WHERE id = ?' % self.table_name
result = self.db.execute(sql, id)
row = result.fetchone()
Expand All @@ -105,24 +119,28 @@ def get(self, id):
return self.create(**row)

def has(self, id):
''' Check if a model object exists in database by its id '''
sql = 'SELECT id FROM %s WHERE id = ?' % self.table_name
result = self.db.execute(sql, id)
return True if result.fetchall() else False

def save(self, obj):
''' Save a model object '''
if 'id' in obj.__dict__ and self.has(obj.id):
msg = 'Object%s id already registred: %s' % (self.model, obj.id)
raise ValueError(msg)
clone = copy_attrs(obj, remove=['id'])
self.type_check and self._isvalid(clone)
column_names = '%s' % ', '.join(clone.keys())
column_references = '%s' % ', '.join('?' for i in range(len(clone)))
sql = 'INSERT INTO %s (%s) VALUES (%s)' % (self.table_name, column_names, column_references) # noqa
sql = 'INSERT INTO %s (%s) VALUES (%s)'
sql = sql % (self.table_name, column_names, column_references)
result = self.db.execute(sql, *clone.values())
obj.id = result.lastrowid
return obj

def update(self, obj):
''' Update a model object '''
clone = copy_attrs(obj, remove=['id'])
self.type_check and self._isvalid(clone)
where_expressions = '= ?, '.join(clone.keys()) + '= ?'
Expand All @@ -146,27 +164,32 @@ def _isvalid(self, attr_values):
raise TypeError(msg % (attr, attr_types[attr], value_type))


class Model(object):
class Model(object): # pylint: disable=R0205
''' Abstract entity model with an active record interface '''

db = None

def delete(self, type_check=True):
''' Delete this model object '''
return self.__class__.manager(type_check=type_check).delete(self)

def save(self, type_check=True):
''' Save this model object '''
return self.__class__.manager(type_check=type_check).save(self)

def update(self, type_check=True):
''' Update this model object '''
return self.__class__.manager(type_check=type_check).update(self)

@property
def public(self):
''' Return the public model attributes '''
return attrs(self)

def __repr__(self):
return str(self.public)

@classmethod
def manager(cls, db=None, type_check=True):
''' Create a database managet '''
return Manager(db if db else cls.db, cls, type_check)
1 change: 0 additions & 1 deletion tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@


class Post(db.Model):

random = float
text = str

Expand Down

0 comments on commit 7c79694

Please sign in to comment.