Skip to content

Commit

Permalink
filter on queries
Browse files Browse the repository at this point in the history
  • Loading branch information
CrispenGari committed Feb 5, 2024
1 parent 9953cd1 commit 0c6f337
Show file tree
Hide file tree
Showing 11 changed files with 1,602 additions and 67 deletions.
1,139 changes: 1,139 additions & 0 deletions dataloom.sql

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion dataloom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
UnsupportedDialectException,
UnsupportedTypeException,
)
from dataloom.types import Order, Include
from dataloom.types import Order, Include, Filter
from dataloom.model import Model
from dataloom.model import (
PrimaryKeyColumn,
Expand All @@ -21,6 +21,7 @@
)

__all__ = [
Filter,
Order,
Include,
MySQLConfig,
Expand Down
4 changes: 4 additions & 0 deletions dataloom/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ class InvalidColumnValuesException(ValueError):

class UnknownColumnException(ValueError):
pass


class InvalidOperatorException(ValueError):
pass
50 changes: 37 additions & 13 deletions dataloom/loom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from dataloom.conn import ConnectionOptionsFactory
from dataloom.utils import logger_function, get_child_table_columns
from typing import Optional
from dataloom.types import Order, Include
from dataloom.types import DIALECT_LITERAL
from dataloom.types import Order, Include, DIALECT_LITERAL, Filter


class Dataloom:
Expand Down Expand Up @@ -300,17 +299,24 @@ def find_many(
) -> list:
sql, params, fields = instance._get_select_where_stm(
dialect=self.dialect,
args=filters,
filters=filters,
select=select,
limit=limit,
offset=offset,
order=order,
include=include,
)
data = list()
data = []
rows = self._execute_sql(sql, fetchall=True, args=params)
for row in rows:
json = dict(zip(fields, row))
data.append(json if return_dict else instance(**json))
res = self.__map_relationships(
instance=instance,
row=row,
parent_fields=fields,
include=include,
return_dict=return_dict,
)
data.append(res)
return data

def find_all(
Expand All @@ -323,6 +329,7 @@ def find_all(
offset: Optional[int] = None,
order: Optional[list[Order]] = [],
) -> list:
return_dict = True
sql, params, fields = instance._get_select_where_stm(
dialect=self.dialect,
select=select,
Expand All @@ -331,11 +338,17 @@ def find_all(
order=order,
include=include,
)
data = list()
data = []
rows = self._execute_sql(sql, fetchall=True)
for row in rows:
json = dict(zip(fields, row))
data.append(json if return_dict else instance(**json))
res = self.__map_relationships(
instance=instance,
row=row,
parent_fields=fields,
include=include,
return_dict=return_dict,
)
data.append(res)
return data

def __map_relationships(
Expand Down Expand Up @@ -368,6 +381,7 @@ def find_by_pk(
include: list[Include] = [],
return_dict: bool = True,
):
return_dict = True
# what is the name of the primary key column? well we will find out
sql, fields = instance._get_select_by_pk_stm(
dialect=self.dialect, select=select, include=include
Expand All @@ -386,20 +400,30 @@ def find_by_pk(
def find_one(
self,
instance: Model,
filters: dict = {},
filters: Optional[Filter | list[Filter]] = None,
select: list[str] = [],
include: list[Include] = [],
return_dict: bool = True,
offset: Optional[int] = None,
):
return_dict = True
sql, params, fields = instance._get_select_where_stm(
dialect=self.dialect, args=filters, select=select, offset=offset
dialect=self.dialect,
filters=filters,
select=select,
offset=offset,
include=include,
)
row = self._execute_sql(sql, args=params, fetchone=True)
if row is None:
return None
json = dict(zip(fields, row))
return json if return_dict else instance(**json)
return self.__map_relationships(
instance=instance,
row=row,
parent_fields=fields,
include=include,
return_dict=return_dict,
)

def update_by_pk(self, instance: Model, pk, values: dict = {}):
sql, args = instance._get_update_by_pk_stm(dialect=self.dialect, args=values)
Expand Down
78 changes: 57 additions & 21 deletions dataloom/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from dataloom.statements import GetStatement
from dataloom.types import Order, Include
from typing import Optional
from dataloom.types import DIALECT_LITERAL
from dataloom.types import DIALECT_LITERAL, Filter
from dataloom.utils import get_operator


class Model:
Expand Down Expand Up @@ -200,18 +201,20 @@ def _get_child_table_params(include: Include):
def _get_select_where_stm(
cls,
dialect: DIALECT_LITERAL,
args: dict = {},
filters: Optional[Filter | list[Filter]] = None,
select: list[str] = [],
limit: Optional[int] = None,
offset: Optional[int] = None,
order: Optional[list[Order]] = [],
include: list[Include] = [],
):
pk_name = None
orders = list()
fields = []
filters = []
params = []
query_params = []
includes = []
# what are the foreign keys?
fks = dict()

for _include in include:
includes.append(cls._get_child_table_params(_include))
Expand All @@ -221,8 +224,11 @@ def _get_select_where_stm(
fields.append(name)
elif isinstance(field, ForeignKeyColumn):
fields.append(name)
table_name = field.table._get_table_name()
fks[table_name] = name
elif isinstance(field, PrimaryKeyColumn):
fields.append(name)
pk_name = name
elif isinstance(field, CreatedAtColumn):
fields.append(name)
elif isinstance(field, UpdatedAtColumn):
Expand All @@ -244,46 +250,76 @@ def _get_select_where_stm(
raise UnknownColumnException(
f'The table "{cls._get_table_name()}" does not have a column "{column}".'
)
for key, value in args.items():
_key = (
f'"{key}" = %s'
if dialect == "postgres"
else f"`{
key}` = {'%s' if dialect == 'mysql' else '?'}"
)
if key not in fields:
raise UnknownColumnException(
f"Table {cls._get_table_name()} does not have column '{key}'."
)
else:
filters.append(_key)
params.append(value)

if filters is not None:
if isinstance(filters, list):
for idx, filter in enumerate(filters):
key = filter.column
if key not in fields:
raise UnknownColumnException(
f"Table {cls._get_table_name()} does not have column '{key}'."
)
op = get_operator(filter.operator)
join = (
""
if len(filters) == idx + 1
else f" {filter.join_next_filter_with}"
)
_key = (
f'"{key}" {op} %s {join}'
if dialect == "postgres"
else f"`{key}` {op} {'%s' if dialect == 'mysql' else '?'} {join}"
)
query_params.append((_key, filter.value))
else:
filter = filters
key = filter.column
if key not in fields:
raise UnknownColumnException(
f"Table {cls._get_table_name()} does not have column '{key}'."
)
op = get_operator(filter.operator)
_key = (
f'"{key}" {op} %s'
if dialect == "postgres"
else f"`{key}` {op} {'%s' if dialect == 'mysql' else '?'}"
)
query_params.append((_key, filter.value))
if dialect == "postgres" or "mysql" or "sqlite":
if len(filters) == 0:
if len(query_params) == 0:
sql = GetStatement(
dialect=dialect, model=cls, table_name=cls._get_table_name()
)._get_select_command(
fields=fields if len(select) == 0 else select,
limit=limit,
offset=offset,
orders=orders,
includes=includes,
fks=fks,
pk_name=pk_name,
)
else:
sql = GetStatement(
dialect=dialect, model=cls, table_name=cls._get_table_name()
)._get_select_where_command(
filters=filters,
query_params=query_params,
fields=fields if len(select) == 0 else select,
limit=limit,
offset=offset,
orders=orders,
includes=includes,
fks=fks,
pk_name=pk_name,
)
else:
raise UnsupportedDialectException(
"The dialect passed is not supported the supported dialects are: {'postgres', 'mysql', 'sqlite'}"
)
return sql, params, fields if len(select) == 0 else select
return (
sql,
[qp[1] for qp in query_params],
fields if len(select) == 0 else select,
)

@classmethod
def _get_select_by_pk_stm(
Expand Down
Loading

0 comments on commit 0c6f337

Please sign in to comment.