Skip to content

Commit

Permalink
⭐ Improve parameterized query support - fixes #793 (#794)
Browse files Browse the repository at this point in the history
* Add parameterized query support

* Revert base Parameter constructor back to it's original signature

* Fix a few typehints and make code more DRY

* add test for PyformatParameter

* fix linting issues

---------

Co-authored-by: Lars Schwegmann <[email protected]>
  • Loading branch information
mvanderlee and larsschwegmann authored Oct 21, 2024
1 parent 53a77eb commit 4072bfb
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 23 deletions.
148 changes: 126 additions & 22 deletions pypika/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,21 @@
import uuid
from datetime import date
from enum import Enum
from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
)

from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order
from pypika.utils import (
Expand Down Expand Up @@ -288,57 +302,111 @@ def get_sql(self, **kwargs: Any) -> str:
raise NotImplementedError()


def idx_placeholder_gen(idx: int) -> str:
return str(idx + 1)


def named_placeholder_gen(idx: int) -> str:
return f'param{idx + 1}'


class Parameter(Term):
is_aggregate = None

def __init__(self, placeholder: Union[str, int]) -> None:
super().__init__()
self.placeholder = placeholder
self._placeholder = placeholder

@property
def placeholder(self):
return self._placeholder

def get_sql(self, **kwargs: Any) -> str:
return str(self.placeholder)

def update_parameters(self, param_key: Any, param_value: Any, **kwargs):
pass

class QmarkParameter(Parameter):
"""Question mark style, e.g. ...WHERE name=?"""
def get_param_key(self, placeholder: Any, **kwargs):
return placeholder

def __init__(self) -> None:
pass

def get_sql(self, **kwargs: Any) -> str:
return "?"
class ListParameter(Parameter):
def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = idx_placeholder_gen) -> None:
super().__init__(placeholder=placeholder)
self._parameters = list()

@property
def placeholder(self) -> str:
if callable(self._placeholder):
return self._placeholder(len(self._parameters))

class NumericParameter(Parameter):
"""Numeric, positional style, e.g. ...WHERE name=:1"""
return str(self._placeholder)

def get_sql(self, **kwargs: Any) -> str:
return ":{placeholder}".format(placeholder=self.placeholder)
def get_parameters(self, **kwargs):
return self._parameters

def update_parameters(self, value: Any, **kwargs):
self._parameters.append(value)

class NamedParameter(Parameter):
"""Named style, e.g. ...WHERE name=:name"""

class DictParameter(Parameter):
def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = named_placeholder_gen) -> None:
super().__init__(placeholder=placeholder)
self._parameters = dict()

@property
def placeholder(self) -> str:
if callable(self._placeholder):
return self._placeholder(len(self._parameters))

return str(self._placeholder)

def get_parameters(self, **kwargs):
return self._parameters

def get_param_key(self, placeholder: Any, **kwargs):
return placeholder[1:]

def update_parameters(self, param_key: Any, value: Any, **kwargs):
self._parameters[param_key] = value


class QmarkParameter(ListParameter):
def get_sql(self, **kwargs):
return '?'


class NumericParameter(ListParameter):
"""Numeric, positional style, e.g. ...WHERE name=:1"""

def get_sql(self, **kwargs: Any) -> str:
return ":{placeholder}".format(placeholder=self.placeholder)


class FormatParameter(Parameter):
class FormatParameter(ListParameter):
"""ANSI C printf format codes, e.g. ...WHERE name=%s"""

def __init__(self) -> None:
pass

def get_sql(self, **kwargs: Any) -> str:
return "%s"


class PyformatParameter(Parameter):
class NamedParameter(DictParameter):
"""Named style, e.g. ...WHERE name=:name"""

def get_sql(self, **kwargs: Any) -> str:
return ":{placeholder}".format(placeholder=self.placeholder)


class PyformatParameter(DictParameter):
"""Python extended format codes, e.g. ...WHERE name=%(name)s"""

def get_sql(self, **kwargs: Any) -> str:
return "%({placeholder})s".format(placeholder=self.placeholder)

def get_param_key(self, placeholder: Any, **kwargs):
return placeholder[2:-2]


class Negative(Term):
def __init__(self, term: Term) -> None:
Expand Down Expand Up @@ -385,9 +453,44 @@ def get_formatted_value(cls, value: Any, **kwargs):
return "null"
return str(value)

def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", **kwargs: Any) -> str:
sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs)
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)
def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]:
param_sql = parameter.get_sql(**kwargs)
param_key = parameter.get_param_key(placeholder=param_sql)

return param_sql, param_key

def get_sql(
self,
quote_char: Optional[str] = None,
secondary_quote_char: str = "'",
parameter: Parameter = None,
**kwargs: Any,
) -> str:
if parameter is None:
sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs)
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)

# Don't stringify numbers when using a parameter
if isinstance(self.value, (int, float)):
value_sql = self.value
else:
value_sql = self.get_value_sql(quote_char=quote_char, **kwargs)
param_sql, param_key = self._get_param_data(parameter, **kwargs)
parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs)

return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs)


class ParameterValueWrapper(ValueWrapper):
def __init__(self, parameter: Parameter, value: Any, alias: Optional[str] = None) -> None:
super().__init__(value, alias)
self._parameter = parameter

def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]:
param_sql = self._parameter.get_sql(**kwargs)
param_key = self._parameter.get_param_key(placeholder=param_sql)

return param_sql, param_key


class JSON(Term):
Expand Down Expand Up @@ -551,6 +654,7 @@ def __init__(
if isinstance(table, str):
# avoid circular import at load time
from pypika.queries import Table

table = Table(table)
self.table = table

Expand Down
112 changes: 112 additions & 0 deletions pypika/tests/test_parameter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from datetime import date

from pypika import (
FormatParameter,
Expand All @@ -10,6 +11,7 @@
Query,
Tables,
)
from pypika.terms import ListParameter, ParameterValueWrapper


class ParametrizedTests(unittest.TestCase):
Expand Down Expand Up @@ -92,3 +94,113 @@ def test_format_parameter(self):

def test_pyformat_parameter(self):
self.assertEqual('%(buz)s', PyformatParameter('buz').get_sql())


class ParametrizedTestsWithValues(unittest.TestCase):
table_abc, table_efg = Tables("abc", "efg")

def test_param_insert(self):
q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, 'foo')

parameter = QmarkParameter()
sql = q.get_sql(parameter=parameter)
self.assertEqual('INSERT INTO "abc" ("a","b","c") VALUES (?,?,?)', sql)
self.assertEqual([1, 2.2, 'foo'], parameter.get_parameters())

def test_param_select_join(self):
q = (
Query.from_(self.table_abc)
.select("*")
.where(self.table_abc.category == 'foobar')
.join(self.table_efg)
.on(self.table_abc.id == self.table_efg.abc_id)
.where(self.table_efg.date >= date(2024, 2, 22))
.limit(10)
)

parameter = FormatParameter()
sql = q.get_sql(parameter=parameter)
self.assertEqual(
'SELECT * FROM "abc" JOIN "efg" ON "abc"."id"="efg"."abc_id" WHERE "abc"."category"=%s AND "efg"."date">=%s LIMIT 10',
sql,
)
self.assertEqual(['foobar', '2024-02-22'], parameter.get_parameters())

def test_param_select_subquery(self):
q = (
Query.from_(self.table_abc)
.select("*")
.where(self.table_abc.category == 'foobar')
.where(
self.table_abc.id.isin(
Query.from_(self.table_efg)
.select(self.table_efg.abc_id)
.where(self.table_efg.date >= date(2024, 2, 22))
)
)
.limit(10)
)

parameter = ListParameter(placeholder=lambda idx: f'&{idx+1}')
sql = q.get_sql(parameter=parameter)
self.assertEqual(
'SELECT * FROM "abc" WHERE "category"=&1 AND "id" IN (SELECT "abc_id" FROM "efg" WHERE "date">=&2) LIMIT 10',
sql,
)
self.assertEqual(['foobar', '2024-02-22'], parameter.get_parameters())

def test_join(self):
subquery = (
Query.from_(self.table_efg)
.select(self.table_efg.fiz, self.table_efg.buz)
.where(self.table_efg.buz == 'buz')
)

q = (
Query.from_(self.table_abc)
.join(subquery)
.on(self.table_abc.bar == subquery.buz)
.select(self.table_abc.foo, subquery.fiz)
.where(self.table_abc.bar == 'bar')
)

parameter = NamedParameter()
sql = q.get_sql(parameter=parameter)
self.assertEqual(
'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:param1)'
' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:param2',
sql,
)
self.assertEqual({'param1': 'buz', 'param2': 'bar'}, parameter.get_parameters())

def test_join_with_parameter_value_wrapper(self):
subquery = (
Query.from_(self.table_efg)
.select(self.table_efg.fiz, self.table_efg.buz)
.where(self.table_efg.buz == ParameterValueWrapper(Parameter(':buz'), 'buz'))
)

q = (
Query.from_(self.table_abc)
.join(subquery)
.on(self.table_abc.bar == subquery.buz)
.select(self.table_abc.foo, subquery.fiz)
.where(self.table_abc.bar == ParameterValueWrapper(NamedParameter('bar'), 'bar'))
)

parameter = NamedParameter()
sql = q.get_sql(parameter=parameter)
self.assertEqual(
'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:buz)'
' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:bar',
sql,
)
self.assertEqual({':buz': 'buz', 'bar': 'bar'}, parameter.get_parameters())

def test_pyformat_parameter(self):
q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, 'foo')

parameter = PyformatParameter()
sql = q.get_sql(parameter=parameter)
self.assertEqual('INSERT INTO "abc" ("a","b","c") VALUES (%(param1)s,%(param2)s,%(param3)s)', sql)
self.assertEqual({"param1": 1, "param2": 2.2, "param3": "foo"}, parameter.get_parameters())
2 changes: 1 addition & 1 deletion pypika/tests/test_terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_init_with_str_table(self):
test_table_name = "test_table"
field = Field(name="name", table=test_table_name)
self.assertEqual(field.table, Table(name=test_table_name))


class FieldHashingTests(TestCase):
def test_tabled_eq_fields_equally_hashed(self):
Expand Down

0 comments on commit 4072bfb

Please sign in to comment.