From 8a964d18cf442d36258af1626ba155fcfab02685 Mon Sep 17 00:00:00 2001 From: M1ha Date: Mon, 30 Apr 2018 11:12:50 +0500 Subject: [PATCH 1/8] 1) Moved values to WITH section in update query. 2) Restricted getting protected functions from query.py --- src/django_pg_bulk_update/query.py | 94 ++++++++++++++++++++++-------- 1 file changed, 69 insertions(+), 25 deletions(-) diff --git a/src/django_pg_bulk_update/query.py b/src/django_pg_bulk_update/query.py index 26ae0ca..20596dc 100644 --- a/src/django_pg_bulk_update/query.py +++ b/src/django_pg_bulk_update/query.py @@ -18,6 +18,8 @@ TSetFunctionsValid from .utils import batched_operation +__all__ = ['pdnf_clause', 'bulk_update', 'bulk_update_or_create'] + def _validate_field_names(parameter_name, field_names): # type: (str, TFieldNames) -> List[str] @@ -247,11 +249,10 @@ def pdnf_clause(key_fields, field_values, key_fields_ops=()): return or_cond -def _bulk_update_no_validation(model, values, conn, set_functions, key_fields_ops): - # type: (Type[Model], TUpdateValuesValid, DefaultConnectionProxy, TSetFunctionsValid, TOperatorsValid) -> int +def _with_values_query_part(model, values, conn, set_functions, key_fields_ops): + # type: (Type[Model], TUpdateValuesValid, DefaultConnectionProxy, TSetFunctionsValid, TOperatorsValid) -> Tuple[List[str], str, List[Any]] """ - Does bulk update, skipping parameters validation. - It is used for speed up in bulk_update_or_create, where parameters are already formatted. + Forms query part, selecting input values :param model: Model to update, a subclass of django.db.models.Model :param values: Data to update. All items must update same fields!!! Dict of key_values_tuple: update_fields_dict @@ -260,27 +261,15 @@ def _bulk_update_no_validation(model, values, conn, set_functions, key_fields_op Should be a dict of field name as key, function class as value. :param key_fields_ops: Key fields compare operators. 
A tuple with (field_name from key_fields, operation name) elements - :return: Number of records updated - """ - upd_keys_tuple = tuple(set_functions.keys()) - - # No any values to update. Return that everything is done. - if not upd_keys_tuple or not values: - return len(values) - - # Query template. We will form its substitutes in next sections - query = """ - UPDATE %s AS t SET %s - FROM ( - VALUES %s - ) AS sel(%s) - WHERE %s; + :return: A tuple of sql and it's parameters """ + tpl = "WITH vals(%s) AS (VALUES %s)" # Table we save data to db_table = model._meta.db_table + upd_keys_tuple = tuple(set_functions.keys()) - # Form data for VALUES() select. + # Form data for VALUES section # It includes both keys and update data: keys will be used in WHERE section, while update data in SET section values_items = [] values_update_params = [] @@ -326,12 +315,42 @@ def _bulk_update_no_validation(model, values, conn, set_functions, key_fields_op sel_sql = ', '.join(sel_items + sel_key_items) + return sel_key_items, tpl % (sel_sql, values_sql), values_update_params + + +def _bulk_update_query_part(model, sel_key_items, conn, set_functions, key_fields_ops): + # type: (Type[Model], List[str], DefaultConnectionProxy, TSetFunctionsValid, TOperatorsValid) -> Tuple[str, List[Any]] + """ + Forms bulk update query part without values, counting that all keys and values are already in vals table + :param model: Model to update, a subclass of django.db.models.Model + :param sel_key_items: Names of field in vals table. + Key fields are prefixed with key_%d__ + Values fields are prefixed with upd__ + :param conn: Database connection used + :param set_functions: Functions to set values. + Should be a dict of field name as key, function class as value. + :param key_fields_ops: Key fields compare operators. 
+ A tuple with (field_name from key_fields, operation name) elements + :return: A tuple of sql and it's parameters + """ + upd_keys_tuple = tuple(set_functions.keys()) + + # Query template. We will form its substitutes in next sections + query = """ + UPDATE %s AS t SET %s + FROM vals + WHERE %s; + """ + + # Table we save data to + db_table = model._meta.db_table + # Form data for WHERE section # Remember that field names in sel table have prefixes. where_items = [] for (key_field, op), sel_field in zip(key_fields_ops, sel_key_items): table_field = '"t"."%s"' % model._meta.get_field(key_field).column - prefixed_sel_field = '"sel"."%s"' % sel_field + prefixed_sel_field = '"vals"."%s"' % sel_field where_items.append(op.get_sql(table_field, prefixed_sel_field)) where_sql = ' AND '.join(where_items) @@ -340,17 +359,42 @@ def _bulk_update_no_validation(model, values, conn, set_functions, key_fields_op for field_name in upd_keys_tuple: func_cls = set_functions[field_name] f = model._meta.get_field(field_name) - func_sql, params = func_cls.get_sql(f, '"sel"."upd__%s"' % field_name, conn, val_as_param=False) + func_sql, params = func_cls.get_sql(f, '"vals"."upd__%s"' % field_name, conn, val_as_param=False) set_items.append(func_sql) set_params.extend(params) set_sql = ', '.join(set_items) - # Substitute query placeholders - query = query % ('"%s"' % db_table, set_sql, values_sql, sel_sql, where_sql) + # Substitute query placeholders and concatenate with VALUES section + query = query % ('"%s"' % db_table, set_sql, where_sql) + return query, set_params + + +def _bulk_update_no_validation(model, values, conn, set_functions, key_fields_ops): + # type: (Type[Model], TUpdateValuesValid, DefaultConnectionProxy, TSetFunctionsValid, TOperatorsValid) -> int + """ + Does bulk update, skipping parameters validation. + It is used for speed up in bulk_update_or_create, where parameters are already formatted. 
+ :param model: Model to update, a subclass of django.db.models.Model + :param values: Data to update. All items must update same fields!!! + Dict of key_values_tuple: update_fields_dict + :param conn: Database connection used + :param set_functions: Functions to set values. + Should be a dict of field name as key, function class as value. + :param key_fields_ops: Key fields compare operators. + A tuple with (field_name from key_fields, operation name) elements + :return: Number of records updated + """ + # No any values to update. Return that everything is done. + if not set_functions or not values: + return len(values) + + sel_key_items, values_sql, values_params = _with_values_query_part(model, values, conn, set_functions, + key_fields_ops) + upd_sql, upd_params = _bulk_update_query_part(model, sel_key_items, conn, set_functions, key_fields_ops) # Execute query cursor = conn.cursor() - cursor.execute(query, params=set_params + values_update_params) + cursor.execute(values_sql + upd_sql, params=values_params + upd_params) return cursor.rowcount From c5d1b1bbb2a2729a2ccc1a7254bb391cad53d9c4 Mon Sep 17 00:00:00 2001 From: M1ha Date: Sat, 8 Sep 2018 12:23:19 +0500 Subject: [PATCH 2/8] 1) Refactoring of all code: added FieldDescriptor class in order to pass parameters easier 2) Added logging 3) bulk_update_or_create now works in on query on PostgreSQL 9.5+ via INSERT ... ON CONFLICT. 
--- src/django_pg_bulk_update/clause_operators.py | 34 +- src/django_pg_bulk_update/compatibility.py | 23 +- src/django_pg_bulk_update/manager.py | 16 +- src/django_pg_bulk_update/query.py | 572 +++++++++++------- src/django_pg_bulk_update/set_functions.py | 95 ++- src/django_pg_bulk_update/types.py | 117 +++- src/django_pg_bulk_update/utils.py | 35 +- tests/migrations/0001_initial.py | 3 +- tests/models.py | 4 + tests/{test_settings.py => settings.py} | 15 + tests/test_bulk_update.py | 90 ++- tests/test_bulk_update_or_create.py | 142 +++-- tests/test_pdnf_clause.py | 20 +- 13 files changed, 725 insertions(+), 441 deletions(-) rename tests/{test_settings.py => settings.py} (76%) diff --git a/src/django_pg_bulk_update/clause_operators.py b/src/django_pg_bulk_update/clause_operators.py index 7aa7be2..b22dddb 100644 --- a/src/django_pg_bulk_update/clause_operators.py +++ b/src/django_pg_bulk_update/clause_operators.py @@ -4,7 +4,7 @@ from typing import Type, Optional, Any, Tuple, Iterable, Dict from django.db import DefaultConnectionProxy -from django.db.models import Field, Model +from django.db.models import Field from .compatibility import array_available, get_field_db_type from .utils import get_subclasses, format_field_value @@ -17,7 +17,7 @@ class AbstractClauseOperator(object): def get_django_filters(self, name, value): # type: (str, Any) -> Dict[str, Any] """ - This method should return parameter name to use in django QuerySet.fillter() kwargs + This method should return parameter name to use in django QuerySet.filter() kwargs :param name: Name of the parameter :param value: Value of the parameter :return: kwargs to pass to Q() object constructor @@ -25,7 +25,7 @@ def get_django_filters(self, name, value): raise NotImplementedError("%s must implement get_django_filter method" % self.__class__.__name__) @classmethod - def get_operation_by_name(cls, name): # type: (str) -> Optional[Type[AbstractClauseOperator]] + def get_operator_by_name(cls, name): # type: 
(str) -> Optional[Type[AbstractClauseOperator]] """ Finds subclass of AbstractOperation applicable to given operation name :param name: String name to search @@ -34,7 +34,7 @@ def get_operation_by_name(cls, name): # type: (str) -> Optional[Type[AbstractCl try: return next(sub_cls for sub_cls in get_subclasses(cls, recursive=True) if name in sub_cls.names) except StopIteration: - raise AssertionError("Operator with name '%s' doesn't exist" % name) + raise ValueError("Operator with name '%s' doesn't exist" % name) def get_sql_operator(self): # type: () -> str """ @@ -52,42 +52,24 @@ def get_sql(self, table_field, value): # type: (str, str) -> str """ return "%s %s %s" % (table_field, self.get_sql_operator(), value) - def get_null_fix_sql(self, model, field_name, connection): - # type: (Type[Model], str, DefaultConnectionProxy) -> str - """ - Bug fix. Postgres wants to know exact type of field to save it - This fake update value is used for each saved column in order to get it's type - :param model: Django model subclass - :param field_name: Name of field fix is got for - :param connection: Database connection used - :return: SQL string - """ - db_table = model._meta.db_table - field = model._meta.get_field(field_name) - return '(SELECT "{key}" FROM "{table}" LIMIT 0)'.format(key=field.column, table=db_table) - - def format_field_value(self, field, val, connection, **kwargs): - # type: (Field, Any, DefaultConnectionProxy, **Any) -> Tuple[str, Tuple[Any]] + def format_field_value(self, field, val, connection, cast_type=False, **kwargs): + # type: (Field, Any, DefaultConnectionProxy, bool, **Any) -> Tuple[str, Tuple[Any]] """ Formats value, according to field rules :param field: Django field to take format from :param val: Value to format :param connection: Connection used to update data + :param cast_type: Adds type casting to sql if flag is True :param kwargs: Additional arguments, if needed :return: A tuple: sql, replacing value in update and a tuple of parameters to 
pass to cursor """ - return format_field_value(field, val, connection) + return format_field_value(field, val, connection, cast_type=cast_type) class AbstractArrayValueOperator(AbstractClauseOperator): """ Abstract class partial, that handles an array of field values as input """ - def get_null_fix_sql(self, model, field_name, connection): - field = model._meta.get_field(field_name) - db_type = get_field_db_type(field, connection) - return '(SELECT ARRAY[]::%s[] LIMIT 0)' % db_type - def format_field_value(self, field, val, connection, **kwargs): assert isinstance(val, Iterable), "'%s' value must be iterable" % self.__class__.__name__ diff --git a/src/django_pg_bulk_update/compatibility.py b/src/django_pg_bulk_update/compatibility.py index 15c5f4f..cdd6f0c 100644 --- a/src/django_pg_bulk_update/compatibility.py +++ b/src/django_pg_bulk_update/compatibility.py @@ -63,11 +63,12 @@ def hstore_serialize(value): # type: (Dict[Any, Any]) -> Dict[str, str] return val -def get_postgres_version(using=None, as_tuple=True): # type: (Optional[str], bool) -> Union(Tuple[int], int) +def get_postgres_version(using=None, as_tuple=True): + # type: (Optional[str], bool) -> Union[Tuple[int], int] """ - Returns Postgres server verion used + Returns Postgres server version used :param using: Connection alias to use - :param as_tuple: If true, returns result as tuple, otherwize as concatenated integer + :param as_tuple: If true, returns result as tuple, otherwise as concatenated integer :return: Database version as tuple (major, minor, revision) if as_tuple is true. A single number major*10000 + minor*100 + revision if false. 
""" @@ -76,23 +77,19 @@ def get_postgres_version(using=None, as_tuple=True): # type: (Optional[str], bo return (num / 10000, num % 10000 / 100, num % 100) if as_tuple else num -def get_field_db_type(field, connection): # type: (models.Field, DefaultConnectionProxy) -> str +def get_field_db_type(field, conn): + # type: (models.Field, DefaultConnectionProxy) -> str """ Get database field type used for this field. :param field: django.db.models.Field instance - :param connection: Datbase connection used + :param conn: Database connection used :return: Database type name (str) """ # We should resolve value as array for IN operator. # db_type() as id field returned 'serial' instead of 'integer' here - # reL_db_type() return integer, but it is not available before django 1.10 - db_type = field.db_type(connection) - if db_type == 'serial': - db_type = 'integer' - elif db_type == 'bigserial': - db_type = 'biginteger' - - return db_type + # rel_db_type() return integer, but it is not available before django 1.10 + db_type = field.db_type(conn) + return db_type.replace('serial', 'integer') # Postgres 9.4 has JSONB support, but doesn't support concat operator (||) diff --git a/src/django_pg_bulk_update/manager.py b/src/django_pg_bulk_update/manager.py index 30ac80e..39c9a4b 100644 --- a/src/django_pg_bulk_update/manager.py +++ b/src/django_pg_bulk_update/manager.py @@ -99,12 +99,12 @@ class TestModel(models.Model): self._for_write = True using = self.db - return bulk_update(self.model, values, key_fields=key_fields, using=using, set_functions=set_functions, + return bulk_update(self.model, values, key_fds=key_fields, using=using, set_functions=set_functions, key_fields_ops=key_fields_ops, batch_size=batch_size, batch_delay=batch_delay) - def bulk_update_or_create(self, values, key_fields='id', set_functions=None, update=True, batch_size=None, - batch_delay=0): - # type: (TUpdateValues, TFieldNames, TSetFunctions, bool, Optional[int], float) -> Tuple[int, int] + def 
bulk_update_or_create(self, values, key_fields='id', set_functions=None, update=True, key_is_unique=True, + batch_size=None, batch_delay=0): + # type: (TUpdateValues, TFieldNames, TSetFunctions, bool, bool, Optional[int], float) -> int """ Searches for records, given in values by key_fields. If records are found, updates them from values. If not found - creates them from values. Note, that all fields without default value must be present in values. @@ -125,17 +125,19 @@ def bulk_update_or_create(self, values, key_fields='id', set_functions=None, upd Functions: [eq, =; incr, +; concat, ||] Example: {'name': 'eq', 'int_fields': 'incr'} :param update: If this flag is not set, existing records will not be updated + :param key_is_unique: Settings this flag to False forces library to use 3-query transactional update, + not INSERT ... ON CONFLICT. :param batch_size: Optional. If given, data is split it into batches of given size. Each batch is queried independently. :param batch_delay: Delay in seconds between batches execution, if batch_size is not None. 
- :return: A tuple (number of records created, number of records updated) + :return: Number of records created or updated """ self._for_write = True using = self.db return bulk_update_or_create(self.model, values, key_fields=key_fields, using=using, - set_functions=set_functions, update=update, batch_size=batch_size, - batch_delay=batch_delay) + set_functions=set_functions, update=update, key_is_unique=key_is_unique, + batch_size=batch_size, batch_delay=batch_delay) class BulkUpdateManager(models.Manager, BulkUpdateManagerMixin): diff --git a/src/django_pg_bulk_update/query.py b/src/django_pg_bulk_update/query.py index 461adca..9e7904f 100644 --- a/src/django_pg_bulk_update/query.py +++ b/src/django_pg_bulk_update/query.py @@ -5,82 +5,79 @@ import inspect import json from collections import Iterable +from itertools import chain +from logging import getLogger import six from django.db import transaction, connection, connections, DefaultConnectionProxy from django.db.models import Model, Q from typing import Any, Type, Iterable as TIterable, Union, Optional, List, Tuple -from .clause_operators import AbstractClauseOperator, EqualClauseOperator -from .compatibility import zip_longest, get_postgres_version -from .set_functions import EqualSetFunction, AbstractSetFunction +from .compatibility import get_postgres_version +from .set_functions import AbstractSetFunction from .types import TOperators, TFieldNames, TUpdateValues, TSetFunctions, TOperatorsValid, TUpdateValuesValid, \ - TSetFunctionsValid + TSetFunctionsValid, FieldDescriptor from .utils import batched_operation + __all__ = ['pdnf_clause', 'bulk_update', 'bulk_update_or_create'] +logger = getLogger('django-pg-bulk-update') -def _validate_field_names(parameter_name, field_names): - # type: (str, TFieldNames) -> List[str] +def _validate_field_names(field_names): + # type: (TFieldNames) -> Tuple[FieldDescriptor] """ Validates field_names. 
It can be a string for single field or an iterable of strings for multiple fields. - :param parameter_name: A name of parameter validated to output in exception :param field_names: Field names to validate - :return: A list of strings - formatted field types + :return: A tuple of strings - formatted field types :raises AssertionError: If validation is not passed """ - error_message = "'%s' parameter must be iterable of strings" % parameter_name + error_message = "'key_fields' parameter must be iterable of strings" if isinstance(field_names, six.string_types): - return [field_names] + return FieldDescriptor(field_names), # comma is not a bug, I need tuple returned elif isinstance(field_names, Iterable): field_names = list(field_names) for name in field_names: - assert isinstance(name, six.string_types), error_message - return field_names + if not isinstance(name, six.string_types): + raise TypeError(error_message) + return tuple(FieldDescriptor(name) for name in field_names) else: - raise AssertionError(error_message) + raise TypeError(error_message) -def _validate_operators(field_names, operators, param_name='key_fields_ops'): - # type: (List[str], TOperators, str) -> TOperatorsValid +def _validate_operators(key_fds, operators): + # type: (Tuple[FieldDescriptor], TOperators) -> TOperatorsValid """ Validates operators and gets a dict of database filters with field_name as key Order of dict is equal to field_names order - :param field_names: A list of field_names, already validated + :param key_fds: A tuple of FieldDescriptor objects. These objects will be modified. :param operators: Operations, not validated. 
- :param param_name: Name of parameter to output in exception - :return: An ordered dict of field_name: (db_filter pairs, inverse) + :return: A tuple of field descriptors with operators """ - # Format operations as tuple (field name, AbstractClauseOperator()) + # Format operators (field name, AbstractClauseOperator()) if isinstance(operators, dict): - assert len(set(operators.keys()) - set(field_names)) == 0,\ - "Some operators are not present in %s" % param_name - operators = tuple( - (name, EqualClauseOperator() if name not in operators else operators[name]) - for name in field_names - ) + if len(set(operators.keys()) - {f.name for f in key_fds}) != 0: + raise ValueError("Some operators are not present in 'key_field_ops'") + for fd in key_fds: + fd.key_operator = operators.get(fd.name) else: - assert isinstance(operators, Iterable), \ - "'%s' parameter must be iterable of strings or AbstractClauseOperator instances" % param_name - operators = tuple(zip_longest(field_names, operators, fillvalue=EqualClauseOperator())) + if not isinstance(operators, Iterable): + raise TypeError("'key_field_ops' parameter must be iterable of strings or AbstractClauseOperator instances") + operators = tuple(operators) + for i, fd in enumerate(key_fds): + fd.key_operator = operators[i] if i < len(operators) else None - res = [] - for field_name, op in operators: - if isinstance(op, AbstractClauseOperator): - res.append((field_name, op)) - elif isinstance(op, six.string_types): - res.append((field_name, AbstractClauseOperator.get_operation_by_name(op)())) - else: - raise AssertionError("Invalid operator '%s'" % str(op)) + # Add prefix to all descriptors + for i, fd in enumerate(key_fds): + fd.set_prefix('key', index=i) - return tuple(res) + return key_fds -def _validate_update_values(key_fields, values, param_name='values'): - # type: (List[str], TUpdateValues, str) -> Tuple[Tuple[str], TUpdateValuesValid] +def _validate_update_values(key_fds, values): + # type: 
(Tuple[FieldDescriptor], TUpdateValues) -> Tuple[Tuple[FieldDescriptor], TUpdateValuesValid] """ Parses and validates input data for bulk_update and bulk_update_or_create. It can come in 2 forms: @@ -90,11 +87,10 @@ def _validate_update_values(key_fields, values, param_name='values'): - key_values can be iterable or single object. - If iterable, key_values length must be equal to key_fields length. - If single object, key_fields is expected to have 1 element - :param key_fields: Field names, by which items would be selected, already validated. + :param key_fds: A tuple of FieldDescriptor objects, by which data will be selected :param values: Input data as given - :param param_name: Name of parameter containing values to use in exception :return: Returns a tuple: - + A tuple with names of keys to update (which are not in key_fields) + + A tuple with FieldDescriptor objects to update (which are not in key_field_descriptors) + A dict, keys are tuples of key_fields values, and values are update_values """ upd_keys_tuple = tuple() @@ -106,8 +102,8 @@ def _validate_update_values(key_fields, values, param_name='values'): if not isinstance(keys, tuple): keys = (keys,) - if len(keys) != len(key_fields): - raise AssertionError("Length of key tuple is not equal to key_fields length") + if len(keys) != len(key_fds): + raise ValueError("Length of key tuple is not equal to key_fields length") # First element. Let's think, that it's fields are updates if not upd_keys_tuple: @@ -115,83 +111,81 @@ def _validate_update_values(key_fields, values, param_name='values'): # Not first element. 
Check that all updates have equal fields elif tuple(sorted(updates.keys())) != upd_keys_tuple: - raise AssertionError("All update data must update same fields") + raise ValueError("All update data must update same fields") # keys may have changed it's format result[keys] = updates elif isinstance(values, Iterable): for item in values: - assert isinstance(item, dict), "All items of iterable must be dicts" - - if set(key_fields) - set(item.keys()): - raise AssertionError("One of update items doesn't contain all key fields") + if not isinstance(item, dict): + raise TypeError("All items of iterable must be dicts") # First element. Let's think, that it's fields are updates + key_field_names = {f.name for f in key_fds} + if key_field_names - set(item.keys()): + raise ValueError("One of update items doesn't contain all key fields") + if not upd_keys_tuple: - upd_keys_tuple = tuple(set(item.keys()) - set(key_fields)) + upd_keys_tuple = tuple(set(item.keys()) - key_field_names) # Not first element. 
Check that all updates have equal fields - elif set(upd_keys_tuple) | set(key_fields) != set(item.keys()): - raise AssertionError("All update data must update same fields") + elif set(upd_keys_tuple) | key_field_names != set(item.keys()): + raise ValueError("All update data must update same fields") # Split into keys and update values upd_key_values = [] - for f in key_fields: - if isinstance(item[f], dict): - raise AssertionError("Dict is currently not supported as key field") - elif isinstance(item[f], Iterable) and not isinstance(item[f], six.string_types): - upd_key_values.append(tuple(item[f])) + for fd in key_fds: + if isinstance(item[fd.name], dict): + raise TypeError("Dict is currently not supported as key field") + elif isinstance(item[fd.name], Iterable) and not isinstance(item[fd.name], six.string_types): + upd_key_values.append(tuple(item[fd.name])) else: - upd_key_values.append(item[f]) + upd_key_values.append(item[fd.name]) upd_values = {f: item[f] for f in upd_keys_tuple} result[tuple(upd_key_values)] = upd_values else: - raise AssertionError("'%s' parameter must be dict or Iterable" % param_name) + raise TypeError("'values' parameter must be dict or Iterable") + + descriptors = tuple(FieldDescriptor(name) for name in upd_keys_tuple) - return upd_keys_tuple, result + # Add prefix to all descriptors + for name in descriptors: + name.set_prefix('upd') + return descriptors, result -def _validate_set_functions(model, upd_keys_tuple, functions, param_name='set_functions'): - # type: (Type[Model], Tuple[str], TSetFunctions, str) -> TSetFunctionsValid + +def _validate_set_functions(model, upd_fds, functions): + # type: (Type[Model], Tuple[FieldDescriptor], TSetFunctions) -> TSetFunctionsValid """ Validates set functions. 
It should be a dict with field name as key and function name or AbstractSetFunction instance as value Default set function is EqualSetFunction :param model: Model updated - :param upd_keys_tuple: A tuple of field names to update + :param upd_fds: A tuple of FieldDescriptors to update. It will be modified. :param functions: Functions to validate - :param param_name: Name of the parameter to use in exception - :return: A dict with field name as key and AbstractSetFunction instance as value + :return: A tuple of FieldDescriptor objects with set functions. """ functions = functions or {} - assert isinstance(functions, dict), "'%s' must be a dict instance" % param_name - upd_keys_set = set(upd_keys_tuple) + if not isinstance(functions, dict): + raise TypeError("'set_functions' must be a dict instance") - res = {} - for field_key, func in functions.items(): - assert field_key in upd_keys_set, "'%s' parameter has field name '%s' which will not be updated" \ - % (param_name, field_key) - if isinstance(func, six.string_types): - set_func = AbstractSetFunction.get_function_by_name(func)() - elif isinstance(func, AbstractSetFunction): - set_func = func - else: - raise AssertionError("'%s[%s]' parameter must be string or AbstractSetFunction subclass" - % (param_name, field_key)) + for k, v in functions.items(): + if not isinstance(k, six.string_types): + raise ValueError("'set_functions' keys must be strings") - field = model._meta.get_field(field_key) - assert set_func.field_is_supported(field), "'%s' doesn't support '%s' field" \ - % (set_func.__class__.__name__, field_key) + if not isinstance(v, (six.string_types, AbstractSetFunction)): + raise ValueError("'set_functions' values must be string or AbstractSetFunction instance") - res[field_key] = set_func - # Set default function - for key in upd_keys_set - set(res.keys()): - res[key] = EqualSetFunction() + for f in upd_fds: + f.set_function = functions.get(f.name) + if not 
f.set_function.field_is_supported(f.get_field(model)): + raise ValueError("'%s' doesn't support '%s' field" % (f.set_function.__class__.__name__, f.name)) - return res + return upd_fds def pdnf_clause(key_fields, field_values, key_fields_ops=()): @@ -212,10 +206,11 @@ def pdnf_clause(key_fields, field_values, key_fields_ops=()): https://docs.djangoproject.com/en/2.0/topics/db/queries/#complex-lookups-with-q-objects """ # Validate input data - key_fields = _validate_field_names("key_fields", key_fields) - key_fields_ops = _validate_operators(key_fields, key_fields_ops) + key_fds = _validate_field_names(key_fields) + key_fds = _validate_operators(key_fds, key_fields_ops) - assert isinstance(field_values, Iterable), "field_values must be iterable of tuples or dicts" + if not isinstance(field_values, Iterable): + raise TypeError("field_values must be iterable of tuples or dicts") field_values = list(field_values) if len(field_values) == 0: @@ -224,116 +219,128 @@ def pdnf_clause(key_fields, field_values, key_fields_ops=()): or_cond = Q() for values_item in field_values: - assert isinstance(values_item, (dict, Iterable)), "Each field_values item must be dict or iterable" - assert len(values_item) == len(key_fields), \ - "All field_values must contain all fields from 'field_names' parameter" + if not isinstance(values_item, (dict, Iterable)): + raise TypeError("Each field_values item must be dict or iterable") + if len(values_item) != len(key_fds): + raise ValueError("All field_values must contain all fields from 'field_names' parameter") and_cond = Q() - for i, name in enumerate(key_fields): + for i, fd in enumerate(key_fds): if isinstance(values_item, dict): - assert name in values_item, "field_values dict '%s' doesn't have key '%s'" \ - % (json.dumps(values_item), name) - value = values_item[name] + if fd.name not in values_item: + raise ValueError("field_values dict '%s' doesn't have key '%s'" % (json.dumps(values_item), fd.name)) + value = values_item[fd.name] elif 
isinstance(values_item, Iterable): values_item = list(values_item) value = values_item[i] else: - raise AssertionError("Each field_values item must be dict or iterable") + raise TypeError("Each field_values item must be dict or iterable") - op = key_fields_ops[i][1] - kwargs = op.get_django_filters(name, value) - and_cond &= ~Q(**kwargs) if op.inverse else Q(**kwargs) + kwargs = fd.key_operator.get_django_filters(fd.name, value) + and_cond &= ~Q(**kwargs) if fd.key_operator.inverse else Q(**kwargs) or_cond |= and_cond return or_cond -def _with_values_query_part(model, values, conn, set_functions, key_fields_ops): - # type: (Type[Model], TUpdateValuesValid, DefaultConnectionProxy, TSetFunctionsValid, TOperatorsValid) -> Tuple[List[str], str, List[Any]] +def _get_default_fds(model, existing_fds): + # type: (Type[Model], Tuple[FieldDescriptor]) -> Tuple[FieldDescriptor] + """ + Finds model fields not absent in existing_fds and returns a Tuple of FieldDescriptors for them + :param model: Model instance + :param existing_fds: Already defined FileDescriptor objects + :return: A tuple of FileDescriptor objects + """ + existing_fields = {fd.get_field(model) for fd in existing_fds} + result = [] + for f in model._meta.get_fields(): + if f not in existing_fields: + desc = FieldDescriptor(f.attname) + desc.set_prefix('def') + result.append(desc) + return tuple(result) + + +def _generate_fds_sql(model, conn, fds, values, for_set, cast_type): + # type: (Type[Model], DefaultConnectionProxy, Tuple[FieldDescriptor], Iterable[Any], bool, bool) -> Tuple[List[str], List[Any]] + """ + Generates + :param fds: + :param for_set: + :return: + """ + sql_list, params_list = [], [] + for fd, val in zip(fds, values): + # These would not be different for different update objects and can be generated once + field = fd.get_field(model) + format_base = fd.set_function if for_set else fd.key_operator + item_sql, item_upd_params = format_base.format_field_value(field, val, conn, 
cast_type=cast_type) + sql_list.append(item_sql) + params_list.extend(item_upd_params) + + return sql_list, params_list + + +def _with_values_query_part(model, values, conn, key_fds, upd_fds, default_fds=()): + # type: (Type[Model], TUpdateValuesValid, DefaultConnectionProxy, Tuple[FieldDescriptor], Tuple[FieldDescriptor], Tuple[FieldDescriptor]) -> Tuple[str, List[Any]] """ Forms query part, selecting input values :param model: Model to update, a subclass of django.db.models.Model :param values: Data to update. All items must update same fields!!! Dict of key_values_tuple: update_fields_dict :param conn: Database connection used - :param set_functions: Functions to set values. - Should be a dict of field name as key, function class as value. - :param key_fields_ops: Key fields compare operators. - A tuple with (field_name from key_fields, operation name) elements - :return: A tuple of sql and it's parameters + :return: Names of fields in select. A tuple of sql and it's parameters """ tpl = "WITH vals(%s) AS (VALUES %s)" - # Table we save data to - db_table = model._meta.db_table - upd_keys_tuple = tuple(set_functions.keys()) - # Form data for VALUES section # It includes both keys and update data: keys will be used in WHERE section, while update data in SET section values_items = [] values_update_params = [] - # Bug fix. Postgres wants to know exact type of field to save it - # This fake update value is used for each saved column in order to get it's type - select_type_query = '(SELECT "{key}" FROM "{table}" LIMIT 0)' - null_fix_value_item = [select_type_query.format(key=k, table=db_table) for k in upd_keys_tuple] - null_fix_value_item.extend([op.get_null_fix_sql(model, name, conn) for name, op in key_fields_ops]) - values_items.append(null_fix_value_item) + # Prepare default values to insert into database, if they are not provided in updates or keys + # Dictionary keys list all db column names to be inserted. 
+ if default_fds: + default_vals = [fd.get_field(model).get_default() for fd in default_fds] + defaults_sql_items, defaults_params = _generate_fds_sql(model, conn, default_fds, default_vals, True, True) + else: + defaults_sql_items = '' + defaults_params = [] + first = True for keys, updates in values.items(): - upd_item = [] - - # Prepare update fields values - for name in upd_keys_tuple: - val = updates[name] - f = model._meta.get_field(name) - set_func = set_functions[name] - item_sql, item_upd_params = set_func.format_field_value(f, val, conn) - values_update_params.extend(item_upd_params) - upd_item.append(item_sql) - - # Prepare key fields values - for (name, op), val in zip(key_fields_ops, keys): - f = model._meta.get_field(name) - item_sql, item_upd_params = op.format_field_value(f, val, conn) - values_update_params.extend(item_upd_params) - upd_item.append(item_sql) - - values_items.append(upd_item) + # For field sql and params + upd_values = [updates[fd.name] for fd in upd_fds] + upd_sql_items, upd_params = _generate_fds_sql(model, conn, upd_fds, upd_values, True, first) + key_sql_items, key_params = _generate_fds_sql(model, conn, key_fds, keys, False, first) + + sql_items = key_sql_items + upd_sql_items + if default_fds: + sql_items.extend(defaults_sql_items) + + values_items.append(sql_items) + values_update_params.extend(chain(key_params, upd_params, defaults_params)) + first = False + values_items_sql = ['(%s)' % ', '.join(item) for item in values_items] # NOTE. 
No extra brackets here or VALUES will return nothing values_sql = '%s' % ', '.join(values_items_sql) - # Form data for AS sel() section - # Names in key_fields can intersect with upd_keys_tuple and should be prefixed - sel_items = ["upd__%s" % field_name for field_name in upd_keys_tuple] - - # enumerate is used to prevent same keys for duplicated key_field names - sel_key_items = ["key_%d__%s" % (i, name) for i, (name, _) in enumerate(key_fields_ops)] + sel_sql = ', '.join([fd.prefixed_name for fd in chain(key_fds, upd_fds, default_fds)]) - sel_sql = ', '.join(sel_items + sel_key_items) + return tpl % (sel_sql, values_sql), values_update_params - return sel_key_items, tpl % (sel_sql, values_sql), values_update_params - -def _bulk_update_query_part(model, sel_key_items, conn, set_functions, key_fields_ops): - # type: (Type[Model], List[str], DefaultConnectionProxy, TSetFunctionsValid, TOperatorsValid) -> Tuple[str, List[Any]] +def _bulk_update_query_part(model, conn, key_fds, upd_fds): + # type: (Type[Model], DefaultConnectionProxy,Tuple[FieldDescriptor], Tuple[FieldDescriptor]) -> Tuple[str, List[Any]] """ Forms bulk update query part without values, counting that all keys and values are already in vals table :param model: Model to update, a subclass of django.db.models.Model - :param sel_key_items: Names of field in vals table. - Key fields are prefixed with key_%d__ - Values fields are prefixed with upd__ :param conn: Database connection used - :param set_functions: Functions to set values. - Should be a dict of field name as key, function class as value. - :param key_fields_ops: Key fields compare operators. - A tuple with (field_name from key_fields, operation name) elements :return: A tuple of sql and it's parameters """ - upd_keys_tuple = tuple(set_functions.keys()) # Query template. 
We will form its substitutes in next sections query = """ @@ -348,18 +355,17 @@ def _bulk_update_query_part(model, sel_key_items, conn, set_functions, key_field # Form data for WHERE section # Remember that field names in sel table have prefixes. where_items = [] - for (key_field, op), sel_field in zip(key_fields_ops, sel_key_items): - table_field = '"t"."%s"' % model._meta.get_field(key_field).column - prefixed_sel_field = '"vals"."%s"' % sel_field - where_items.append(op.get_sql(table_field, prefixed_sel_field)) + for fd in key_fds: + table_field = '"t"."%s"' % fd.get_field(model).column + prefixed_sel_field = '"vals"."%s"' % fd.prefixed_name + where_items.append(fd.key_operator.get_sql(table_field, prefixed_sel_field)) where_sql = ' AND '.join(where_items) # Form data for SET section set_items, set_params = [], [] - for field_name in upd_keys_tuple: - func_cls = set_functions[field_name] - f = model._meta.get_field(field_name) - func_sql, params = func_cls.get_sql(f, '"vals"."upd__%s"' % field_name, conn, val_as_param=False) + for fd in upd_fds: + func_sql, params = fd.set_function.get_sql(fd.get_field(model), + '"vals"."%s"' % fd.prefixed_name, conn, val_as_param=False) set_items.append(func_sql) set_params.extend(params) set_sql = ', '.join(set_items) @@ -369,8 +375,8 @@ def _bulk_update_query_part(model, sel_key_items, conn, set_functions, key_field return query, set_params -def _bulk_update_no_validation(model, values, conn, set_functions, key_fields_ops): - # type: (Type[Model], TUpdateValuesValid, DefaultConnectionProxy, TSetFunctionsValid, TOperatorsValid) -> int +def _bulk_update_no_validation(model, values, conn, key_fds, upd_fds): + # type: (Type[Model], TUpdateValuesValid, DefaultConnectionProxy, Tuple[FieldDescriptor], Tuple[FieldDescriptor]) -> int """ Does bulk update, skipping parameters validation. It is used for speed up in bulk_update_or_create, where parameters are already formatted. 
@@ -378,27 +384,23 @@ def _bulk_update_no_validation(model, values, conn, set_functions, key_fields_op :param values: Data to update. All items must update same fields!!! Dict of key_values_tuple: update_fields_dict :param conn: Database connection used - :param set_functions: Functions to set values. - Should be a dict of field name as key, function class as value. - :param key_fields_ops: Key fields compare operators. - A tuple with (field_name from key_fields, operation name) elements :return: Number of records updated """ # No any values to update. Return that everything is done. - if not set_functions or not values: + if not upd_fds or not values: return len(values) - - sel_key_items, values_sql, values_params = _with_values_query_part(model, values, conn, set_functions, - key_fields_ops) - upd_sql, upd_params = _bulk_update_query_part(model, sel_key_items, conn, set_functions, key_fields_ops) + values_sql, values_params = _with_values_query_part(model, values, conn, key_fds, upd_fds) + upd_sql, upd_params = _bulk_update_query_part(model, conn, key_fds, upd_fds) # Execute query + logger.debug('EXECUTING STATEMENT:\n %sWITH PARAMETERS [%s]\n' + % (values_sql + upd_sql, ', '.join(str(v) for v in values_params + upd_params))) cursor = conn.cursor() cursor.execute(values_sql + upd_sql, params=values_params + upd_params) return cursor.rowcount -def bulk_update(model, values, key_fields='id', using=None, set_functions=None, key_fields_ops=(), +def bulk_update(model, values, key_fds='id', using=None, set_functions=None, key_fields_ops=(), batch_size=None, batch_delay=0): # type: (Type[Model], TUpdateValues, TFieldNames, Optional[str], TSetFunctions, TOperators, Optional[int], float) -> int """ @@ -413,7 +415,7 @@ def bulk_update(model, values, key_fields='id', using=None, set_functions=None, - key_values can be iterable or single object. - If iterable, key_values length must be equal to key_fields length. 
- If single object, key_fields is expected to have 1 element - :param key_fields: Field names, by which items would be selected. + :param key_fds: Field names, by which items would be selected. It can be a string, if there's only one key field or iterable of strings for multiple keys :param using: Database alias to make query to. :param set_functions: Functions to set values. @@ -433,29 +435,33 @@ def bulk_update(model, values, key_fields='id', using=None, set_functions=None, :return: Number of records updated """ # Validate data - assert inspect.isclass(model), "model must be django.db.models.Model subclass" - assert issubclass(model, Model), "model must be django.db.models.Model subclass" - assert using is None or isinstance(using, six.string_types) and using in connections, \ - "using parameter must be None or existing database alias" - - key_fields = _validate_field_names("key_fields", key_fields) - upd_keys_tuple, values = _validate_update_values(key_fields, values) + if not inspect.isclass(model): + raise TypeError("model must be django.db.models.Model subclass") + if not issubclass(model, Model): + raise TypeError("model must be django.db.models.Model subclass") + if using is not None and not isinstance(using, six.string_types): + raise TypeError("using parameter must be None or string") + if using and using not in connections: + raise ValueError("using parameter must be existing database alias") + + key_fds = _validate_field_names(key_fds) + upd_fds, values = _validate_update_values(key_fds, values) if len(values) == 0: return 0 - key_fields_ops = _validate_operators(key_fields, key_fields_ops) - set_functions = _validate_set_functions(model, upd_keys_tuple, set_functions) + key_fds = _validate_operators(key_fds, key_fields_ops) + upd_fds = _validate_set_functions(model, upd_fds, set_functions) conn = connection if using is None else connections[using] batched_result = batched_operation(_bulk_update_no_validation, values, - args=(model, None, conn, 
set_functions, key_fields_ops), + args=(model, None, conn, key_fds, upd_fds), data_arg_index=1, batch_size=batch_size, batch_delay=batch_delay) return sum(batched_result) -def _bulk_update_or_create_no_validation(model, values, key_fields, using, set_functions, update): - # type: (Type[Model], TUpdateValues, TFieldNames, Optional[str], TSetFunctions, bool) -> Tuple[int, int] +def _bulk_update_or_create_no_validation(model, values, key_fds, upd_fds, using, update): + # type: (Type[Model], TUpdateValues, Tuple[FieldDescriptor], Tuple[FieldDescriptor], Optional[str], bool) -> int """ Searches for records, given in values by key_fields. If records are found, updates them from values. If not found - creates them from values. Note, that all fields without default value must be present in values. @@ -466,7 +472,7 @@ def _bulk_update_or_create_no_validation(model, values, key_fields, using, set_f - key_values can be iterable or single object. - If iterable, key_values length must be equal to key_fields length. - If single object, key_fields is expected to have 1 element - :param key_fields: Field names, by which items would be selected (tuple) + :param key_fds: Field names, by which items would be selected (tuple) :param using: Database alias to make query to. :param set_functions: Functions to set values. Should be a dict of field name as key, function as value. @@ -478,15 +484,12 @@ def _bulk_update_or_create_no_validation(model, values, key_fields, using, set_f """ conn = connection if using is None else connections[using] - # bulk_update_or_create searches values by key equality only. 
No difficult filters here - key_fields_ops = tuple((key, EqualClauseOperator()) for key in key_fields) - with transaction.atomic(using=using): # Find existing values key_items = list(values.keys()) - qs = model.objects.filter(pdnf_clause(key_fields, key_items)).using(using).select_for_update() + qs = model.objects.filter(pdnf_clause([fd.name for fd in key_fds], key_items)).using(using).select_for_update() existing_values_dict = { - tuple([item[key] for key in key_fields]): item + tuple([item[fd.name] for fd in key_fds]): item for item in qs.values() } @@ -499,26 +502,139 @@ def _bulk_update_or_create_no_validation(model, values, key_fields, using, set_f update_items[key] = updates else: # Form a list of model objects for bulk_create() method - kwargs = dict(zip(key_fields, key)) - kwargs.update(updates) + # Insert on conflict and bulk update should work in a same way. + # So key values will be prior over update on insert + kwargs = updates + kwargs.update(dict(zip([fd.name for fd in key_fds], key))) - for k, sf in set_functions.items(): - sf.modify_create_params(model, k, kwargs) + for fd in upd_fds: + fd.set_function.modify_create_params(model, fd.name, kwargs) create_items.append(model(**kwargs)) # Update existing records - updated = _bulk_update_no_validation(model, update_items, conn, set_functions, key_fields_ops) + updated = _bulk_update_no_validation(model, update_items, conn, key_fds, upd_fds) - # Create abscent records + # Create absent records created = len(model.objects.db_manager(using).bulk_create(create_items)) - return created, updated + return created + updated + + +def _insert_on_conflict_query_part(model, conn, key_fds, upd_fds, default_fds, update): + # type: (Type[Model], DefaultConnectionProxy, Tuple[FieldDescriptor], Tuple[FieldDescriptor], Tuple[FieldDescriptor], bool) -> Tuple[str, List[Any]] + """ + Forms bulk update query part without values, counting that all keys and values are already in vals table + :param model: Model to update, a 
subclass of django.db.models.Model + :param sel_key_items: Names of field in vals table. + Key fields are prefixed with key_%d__ + Values fields are prefixed with upd__ + :param conn: Database connection used + :param set_functions: Functions to set values. + Should be a dict of field name as key, function class as value. + :param update: If this flag is not set, existing records will not be updated + :return: A tuple of sql and it's parameters + """ + query = """ + INSERT INTO %s (%s) + SELECT %s FROM vals + ON CONFLICT (%s) %s + """ + + # Table we save data to + db_table = model._meta.db_table + + # Form update data. It would be used in SET section, if values updated and INSERT section if created + set_items, set_params = [], [] + set_columns = [] + for fd in upd_fds: + set_columns.append('"%s"' % fd.get_field(model).column) + func_sql, params = fd.set_function.get_sql_value(fd.get_field(model), '"vals"."%s"' % fd.prefixed_name, conn, + val_as_param=False, with_table=True) + set_items.append(func_sql) + set_params.extend(params) + + where_columns = [] + where_items = [] + for fd in key_fds: + where_columns.append('"%s"' % fd.prefixed_name) + where_items.append('EXCLUDED."%s"' % fd.get_field(model).column) + + set_sql = '(%s) = (SELECT %s FROM "vals" WHERE (%s) = (%s))' \ + % (', '.join(set_columns), ', '.join(set_items), ', '.join(where_columns), ', '.join(where_items)) + + if update and upd_fds: + conflict_action ='DO UPDATE SET %s' % set_sql + conflict_action_params = set_params + else: + conflict_action = 'DO NOTHING' + conflict_action_params = [] + # Columns to insert to table + #upd_fields = {fd.get_field(model) for fd in upd_fds} + #insert_fds = list(chain([fd for fd in key_fds if fd.get_field(model) not in upd_fields], upd_fds, default_fds)) + key_fields = {fd.get_field(model) for fd in key_fds} + insert_fds = list(chain(key_fds, [fd for fd in upd_fds if fd.get_field(model) not in key_fields], default_fds)) -def bulk_update_or_create(model, values, 
key_fields='id', using=None, set_functions=None, update=True, batch_size=None, - batch_delay=0): - # type: (Type[Model], TUpdateValues, TFieldNames, Optional[str], TSetFunctions, bool, Optional[int], float) -> Tuple[int, int] + columns = ['"%s"' % fd.get_field(model).column for fd in insert_fds] + columns = ', '.join(columns) + + # Columns to select from values + val_columns, val_columns_params = [], [] + for fd in insert_fds: + val = '"vals"."%s"' % fd.prefixed_name + func_sql, params = fd.set_function.get_sql_value(fd.get_field(model), val, conn, val_as_param=False, + for_update=False) + val_columns.append(func_sql) + val_columns_params.extend(params) + val_columns = ', '.join(val_columns) + + # Conflict columns + key_columns = ', '.join(['"%s"' % fd.get_field(model).column for fd in key_fds]) + + sql = query % (db_table, columns, val_columns, key_columns, conflict_action) + return sql, val_columns_params + conflict_action_params + + +def _insert_on_conflict_no_validation(model, values, key_fds, upd_fds, using, update): + # type: (Type[Model], TUpdateValues, Tuple[FieldDescriptor], Tuple[FieldDescriptor], Optional[str], bool) -> int + """ + Searches for records, given in values by key_fields. If records are found, updates them from values. + If not found - creates them from values. Note, that all fields without default value must be present in values. + + :param model: Model to update, a subclass of django.db.models.Model + :param values: Data to update. All items must update same fields!!! + Dict of key_values: update_fields_dict + - key_values can be iterable or single object. + - If iterable, key_values length must be equal to key_fields length. + - If single object, key_fields is expected to have 1 element + :param key_fields: Field names, by which items would be selected (tuple) + :param using: Database alias to make query to. + :param set_functions: Functions to set values. + Should be a dict of field name as key, function as value. 
+        Default function is eq.
+        Functions: [eq, =; incr, +; concat, ||]
+        Example: {'name': 'eq', 'int_fields': 'incr'}
+    :param update: If this flag is not set, existing records will not be updated
+    :return: Number of records created or updated
+    """
+    conn = connection if using is None else connections[using]
+
+    default_fds = _get_default_fds(model, tuple(chain(key_fds, upd_fds)))
+    val_sql, val_params = _with_values_query_part(model, values, conn, key_fds, upd_fds, default_fds)
+    upd_sql, upd_params = _insert_on_conflict_query_part(model, conn, key_fds, upd_fds, default_fds, update)
+
+    # Execute query
+    logger.debug('EXECUTING STATEMENT:\n %sWITH PARAMETERS [%s]\n'
+                 % (val_sql + upd_sql, ', '.join(str(v) for v in val_params + upd_params)))
+    cursor = conn.cursor()
+    cursor.execute(val_sql + upd_sql, params=val_params + upd_params)
+    return cursor.rowcount
+
+
+def bulk_update_or_create(model, values, key_fields='id', using=None, set_functions=None, update=True,
+                          key_is_unique=True, batch_size=None, batch_delay=0):
+    # type: (Type[Model], TUpdateValues, TFieldNames, Optional[str], TSetFunctions, bool, bool, Optional[int], float) -> int
     """
     Searches for records, given in values by key_fields. If records are found, updates them from values.
     If not found - creates them from values. Note, that all fields without default value must be present in values.
@@ -541,28 +657,48 @@ def bulk_update_or_create(model, values, key_fields='id', using=None, set_functi
         Functions: [eq, =; incr, +; concat, ||]
         Example: {'name': 'eq', 'int_fields': 'incr'}
     :param update: If this flag is not set, existing records will not be updated
+    :param key_is_unique: Setting this flag to False forces library to use 3-query transactional update,
+        not INSERT ... ON CONFLICT.
     :param batch_size: Optional. If given, data is split it into batches of given size.
         Each batch is queried independently. 
:param batch_delay: Delay in seconds between batches execution, if batch_size is not None. - :return: A tuple (number of records created, number of records updated) + :return: Number of records created or updated """ # Validate data - assert inspect.isclass(model), "model must be django.db.models.Model subclass" - assert issubclass(model, Model), "model must be django.db.models.Model subclass" - assert using is None or isinstance(using, six.string_types) and using in connections, \ - "using parameter must be None or existing database alias" - assert type(update) is bool, "update parameter must be boolean" - - key_fields = _validate_field_names("key_fields", key_fields) - upd_keys_tuple, values = _validate_update_values(key_fields, values) + if not inspect.isclass(model): + raise TypeError("model must be django.db.models.Model subclass") + if not issubclass(model, Model): + raise TypeError("model must be django.db.models.Model subclass") + if using is not None and not isinstance(using, six.string_types): + raise TypeError("using parameter must be None or existing database alias") + if using is not None and using not in connections: + raise ValueError("using parameter must be None or existing database alias") + if type(update) is not bool: + raise TypeError("update parameter must be boolean") + if type(key_is_unique) is not bool: + raise TypeError("key_is_unique must be boolean") + + key_fds = _validate_field_names(key_fields) + + # Add prefix to all descriptors + for i, f in enumerate(key_fds): + f.set_prefix('key', index=i) + + upd_fds, values = _validate_update_values(key_fds, values) if len(values) == 0: - return 0, 0 + return 0 - set_functions = _validate_set_functions(model, upd_keys_tuple, set_functions) + upd_fds = _validate_set_functions(model, upd_fds, set_functions) - batched_result = batched_operation(_bulk_update_or_create_no_validation, values, - args=(model, None, key_fields, using, set_functions, update), + # Insert on conflict is supported in 
PostgreSQL 9.5 and only with constraint + if get_postgres_version(using=using) > (9, 4) and key_is_unique: + batch_func = _insert_on_conflict_no_validation + else: + batch_func = _bulk_update_or_create_no_validation + + batched_result = batched_operation(batch_func, values, + args=(model, None, key_fds, upd_fds, using, update), data_arg_index=1, batch_size=batch_size, batch_delay=batch_delay) - return tuple(int(sum(item)) for item in zip(*batched_result)) + return sum(batched_result) diff --git a/src/django_pg_bulk_update/set_functions.py b/src/django_pg_bulk_update/set_functions.py index 4fc6d9c..17729ed 100644 --- a/src/django_pg_bulk_update/set_functions.py +++ b/src/django_pg_bulk_update/set_functions.py @@ -81,17 +81,18 @@ class AbstractSetFunction(object): # Otherwise a set of class names supported supported_field_classes = None - def format_field_value(self, field, val, connection, **kwargs): - # type: (Field, Any, DefaultConnectionProxy, **Any) -> Tuple[str, Tuple[Any]] + def format_field_value(self, field, val, connection, cast_type=False, **kwargs): + # type: (Field, Any, DefaultConnectionProxy, bool, **Any) -> Tuple[str, Tuple[Any]] """ Formats value, according to field rules :param field: Django field to take format from :param val: Value to format :param connection: Connection used to update data + :param cast_type: Adds type casting to sql if flag is True :param kwargs: Additional arguments, if needed :return: A tuple: sql, replacing value in update and a tuple of parameters to pass to cursor """ - return format_field_value(field, val, connection) + return format_field_value(field, val, connection, cast_type=cast_type) def modify_create_params(self, model, key, kwargs): # type: (Type[Model], str, Dict[str, Any]) -> Dict[str, Any] @@ -111,23 +112,25 @@ def modify_create_params(self, model, key, kwargs): return kwargs - def get_sql_value(self, field, val, connection, val_as_param=True, **kwargs): - # type: (Field, Any, DefaultConnectionProxy, bool, 
**Any) -> Tuple[str, Tuple[Any]] + def get_sql_value(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs): + # type: (Field, Any, DefaultConnectionProxy, bool, bool, bool, **Any) -> Tuple[str, Tuple[Any]] """ Returns value sql to set into field and parameters for query execution This method is called from get_sql() by default. + :param with_table: If flag is set, column name in sql is prefixed by table name :param field: Django field to take format from :param val: Value to format :param connection: Connection used to update data :param val_as_param: If flag is not set, value should be converted to string and inserted into query directly. Otherwise a placeholder and query parameter will be used + :param for_update: If flag is set, returns update sql. Otherwise - insert SQL :param kwargs: Additional arguments, if needed :return: A tuple: sql, replacing value in update and a tuple of parameters to pass to cursor """ raise NotImplementedError("'%s' must define get_sql method" % self.__class__.__name__) - def get_sql(self, field, val, connection, val_as_param=True, **kwargs): - # type: (Field, Any, DefaultConnectionProxy, bool, **Any) -> Tuple[str, Tuple[Any]] + def get_sql(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs): + # type: (Field, Any, DefaultConnectionProxy, bool, bool, bool, **Any) -> Tuple[str, Tuple[Any]] """ Returns function sql and parameters for query execution :param field: Django field to take format from @@ -135,10 +138,13 @@ def get_sql(self, field, val, connection, val_as_param=True, **kwargs): :param connection: Connection used to update data :param val_as_param: If flag is not set, value should be converted to string and inserted into query directly. Otherwise a placeholder and query parameter will be used + :param with_table: If flag is set, column name in sql is prefixed by table name + :param for_update: If flag is set, returns update sql. 
Otherwise - insert SQL :param kwargs: Additional arguments, if needed :return: A tuple: sql, replacing value in update and a tuple of parameters to pass to cursor """ - val, params = self.get_sql_value(field, val, connection, val_as_param=val_as_param, **kwargs) + val, params = self.get_sql_value(field, val, connection, val_as_param=val_as_param, with_table=with_table, + for_update=for_update, **kwargs) return '"%s" = %s' % (field.column, val), params @classmethod @@ -151,7 +157,7 @@ def get_function_by_name(cls, name): # type: (str) -> Optional[Type[AbstractSet try: return next(sub_cls for sub_cls in get_subclasses(cls, recursive=True) if name in sub_cls.names) except StopIteration: - raise AssertionError("Operation with name '%s' doesn't exist" % name) + raise ValueError("Operation with name '%s' doesn't exist" % name) def field_is_supported(self, field): # type: (Field) -> bool """ @@ -184,11 +190,22 @@ def _parse_null_default(self, field, connection, **kwargs): return self.format_field_value(field, null_default, connection) + def _get_field_column(self, field, with_table=False): + # type: (Field, bool) -> str + """ + Returns quoted field column, prefixed with table name if needed + :param field: Field instance + :param with_table: Boolean flag - add table or not + :return: String name + """ + table = '"%s".' 
% field.model._meta.db_table if with_table else '' + return '%s"%s"' % (table, field.column) + class EqualSetFunction(AbstractSetFunction): names = {'eq', '='} - def get_sql_value(self, field, val, connection, val_as_param=True, **kwargs): + def get_sql_value(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs): if val_as_param: return self.format_field_value(field, val, connection) else: @@ -204,13 +221,17 @@ def modify_create_params(self, model, key, kwargs): return kwargs - def get_sql_value(self, field, val, connection, val_as_param=True, **kwargs): - tpl = 'COALESCE(%s, "%s")' + def get_sql_value(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs): + tpl = 'COALESCE(%s, %s)' + if for_update: + default_value, default_params = self._get_field_column(field, with_table=with_table), [] + else: + default_value, default_params = self.format_field_value(field, field.get_default(), connection) if val_as_param: sql, params = self.format_field_value(field, val, connection) - return tpl % (sql, field.column), params + return tpl % (sql, default_value), params + default_params else: - return tpl % (str(val), field.column), [] + return tpl % (str(val), default_value), default_params class PlusSetFunction(AbstractSetFunction): @@ -221,15 +242,20 @@ class PlusSetFunction(AbstractSetFunction): 'IntegerRangeField', 'BigIntegerRangeField', 'FloatRangeField', 'DateTimeRangeField', 'DateRangeField'} - def get_sql_value(self, field, val, connection, val_as_param=True, **kwargs): + def get_sql_value(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs): null_default, null_default_params = self._parse_null_default(field, connection, **kwargs) - tpl = 'COALESCE("%s", %s) + %s' if val_as_param: sql, params = self.format_field_value(field, val, connection) - return tpl % (field.column, null_default, sql), null_default_params + params else: - return tpl % 
(field.column, null_default, str(val)), null_default_params + sql, params = str(val), [] + + if for_update: + tpl = 'COALESCE(%s, %s) + %s' + return tpl % (self._get_field_column(field, with_table=with_table), null_default, sql),\ + null_default_params + params + else: + return sql, params class ConcatSetFunction(AbstractSetFunction): @@ -239,22 +265,28 @@ class ConcatSetFunction(AbstractSetFunction): 'URLField', 'BinaryField', 'JSONField', 'ArrayField', 'CITextField', 'CICharField', 'CIEmailField'} - def get_sql_value(self, field, val, connection, val_as_param=True, **kwargs): + def get_sql_value(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs): null_default, null_default_params = self._parse_null_default(field, connection, **kwargs) # Postgres 9.4 has JSONB support, but doesn't support concat operator (||) # So I've taken function to solve the problem from # Note, that function should be created before using this operator - if get_postgres_version(as_tuple=False) < 90500 and isinstance(field, JSONField): - tpl = '{0}(COALESCE("%s", %s), %s)'.format(Postgres94MergeJSONBMigration.FUNCTION_NAME) + if not for_update: + tpl = '%s' + elif get_postgres_version(as_tuple=False) < 90500 and isinstance(field, JSONField): + tpl = '{0}(COALESCE(%s, %s), %s)'.format(Postgres94MergeJSONBMigration.FUNCTION_NAME) else: - tpl = 'COALESCE("%s", %s) || %s' + tpl = 'COALESCE(%s, %s) || %s' if val_as_param: - sql, params = self.format_field_value(field, val, connection) - return tpl % (field.column, null_default, sql), null_default_params + params + val_sql, params = self.format_field_value(field, val, connection) else: - return tpl % (field.column, null_default, str(val)), null_default_params + val_sql, params = str(val), [] + + if not for_update: + return tpl % val_sql, params + else: + return tpl % (self._get_field_column(field, with_table=with_table), null_default, val_sql), null_default_params + params class 
UnionSetFunction(AbstractSetFunction): @@ -262,8 +294,13 @@ class UnionSetFunction(AbstractSetFunction): supported_field_classes = {'ArrayField'} - def get_sql_value(self, field, val, connection, val_as_param=True, **kwargs): - sub_func = ConcatSetFunction() - sql, params = sub_func.get_sql_value(field, val, connection, val_as_param=val_as_param, **kwargs) - sql = 'ARRAY(SELECT DISTINCT UNNEST(%s))' % sql + def get_sql_value(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs): + if for_update: + sub_func = ConcatSetFunction() + sql, params = sub_func.get_sql_value(field, val, connection, val_as_param=val_as_param, + with_table=with_table, **kwargs) + sql = 'ARRAY(SELECT DISTINCT UNNEST(%s))' % sql + else: + sql, params = val, [] + return sql, params diff --git a/src/django_pg_bulk_update/types.py b/src/django_pg_bulk_update/types.py index 974b3ab..99d3a03 100644 --- a/src/django_pg_bulk_update/types.py +++ b/src/django_pg_bulk_update/types.py @@ -1,14 +1,117 @@ -from typing import Iterable, Union, Dict, Tuple, Any, Optional +from collections import defaultdict +from typing import Iterable, Union, Dict, Tuple, Any, Optional, Type -from .clause_operators import AbstractClauseOperator -from .set_functions import AbstractSetFunction +import six +from django.db.models import Model, Field + +from .clause_operators import AbstractClauseOperator, EqualClauseOperator +from .set_functions import AbstractSetFunction, EqualSetFunction TFieldNames = Union[str, Iterable[str]] -TOperatorsValid = Tuple[Tuple[str, AbstractClauseOperator]] +TOperator = Union[str, AbstractClauseOperator] +TOperatorsValid = Tuple['FieldDescriptor'] -TOperators = Union[Dict[str, Union[str, AbstractClauseOperator]], Iterable[Union[str, AbstractClauseOperator]]] +TOperators = Union[Dict[str, TOperator], Iterable[TOperator]] TUpdateValuesValid = Dict[Tuple[Any], Dict[str, Any]] TUpdateValues = Union[Union[TUpdateValuesValid, Dict[Any, Dict[str, Any]]], 
Iterable[Dict[str, Any]]] -TSetFunctions = Optional[Dict[str, Union[str, AbstractSetFunction]]] -TSetFunctionsValid = Dict[str, AbstractSetFunction] +TSetFunction = Union[str, AbstractSetFunction] +TSetFunctions = Optional[Dict[str, TSetFunction]] +TSetFunctionsValid = Tuple['FieldDescriptor'] + + +class FieldDescriptor(object): + """ + This class is added in order to make passing parameters in queries easier + """ + __slots__ = ['name', '_set_function', '_key_operator', '_prefix'] + + def __init__(self, name, set_function=None, key_operator=None): + # type: (str, TSetFunction, TOperator) -> None + self.name = name + self.set_function = set_function + self.key_operator = key_operator + self._prefix = '' + + def get_field(self, model): + # type: (Type[Model]) -> Field + """ + Returns model field, described by this descriptor + :param model: django.db.models.Model subclass + :return: django.db.fields.Field instance + """ + return model._meta.get_field(self.name) + + @property + def set_function(self): + # type: () -> AbstractSetFunction + """ + Returns set_function for this field descriptor. + :return: AbstractSetFunction instance + """ + return self._set_function + + @set_function.setter + def set_function(self, val): + # type: (Union[None, str, AbstractSetFunction]) -> None + """ + Changes set_function for this field_descriptor. + :param val: Set function name or instance. 
Defaults to EqualSetFunction() if None is passed + :return: + """ + if val is None: + self._set_function = EqualSetFunction() + elif isinstance(val, six.string_types): + self._set_function = AbstractSetFunction.get_function_by_name(val)() + elif isinstance(val, AbstractSetFunction): + self._set_function = val + else: + raise TypeError("Invalid set function type: %s" % str(type(val))) + + @property + def key_operator(self): + # type: () -> AbstractClauseOperator + """ + Returns operator to use in comparison + :return: AbstractKeyOperator instance + """ + return self._key_operator + + @key_operator.setter + def key_operator(self, val): + # type: (Union[None, str, AbstractClauseOperator]) -> None + """ + Sets comparison operator for this field descriptor + :param val: String name of operator or AbstractClauseOperator instance. + Defaults to EqualClauseOperator if None is passed + :return: None + """ + if val is None: + self._key_operator = EqualClauseOperator() + elif isinstance(val, six.string_types): + self._key_operator = AbstractClauseOperator.get_operator_by_name(val)() + elif isinstance(val, AbstractClauseOperator): + self._key_operator = val + else: + raise TypeError("Invalid key operator type: %s" % str(type(val))) + + def set_prefix(self, prefix, index=None): # type: (str, Optional[int]) -> None + """ + Sets prefix to use in values query part. It is used to divide key fields from update and default fields + :param prefix: Prefix to use + :param index: field can be used more than once in conditions. Set this index to prevent duplicates. 
+ :return: + """ + self._prefix = prefix + if index is not None: + self._prefix += '_%d' % index + + @property + def prefixed_name(self): # type: () -> str + """ + Returns prefixed name of the field + :return: + """ + if self._prefix is None: + raise ValueError('prefix has not been set yet') + return "%s__%s" % (self._prefix, self.name) diff --git a/src/django_pg_bulk_update/utils.py b/src/django_pg_bulk_update/utils.py index ed8e88d..0c8fd3e 100644 --- a/src/django_pg_bulk_update/utils.py +++ b/src/django_pg_bulk_update/utils.py @@ -1,14 +1,17 @@ """ Contains some project unbind helpers """ +import logging from time import sleep from django.core.exceptions import FieldError from django.db import DefaultConnectionProxy from django.db.models import Field from django.db.models.sql.subqueries import UpdateQuery -from typing import TypeVar, Set, Any, Tuple, Iterable, Callable, Optional, List, Dict -from .compatibility import hstore_serialize, hstore_available, jsonb_available +from typing import TypeVar, Set, Any, Tuple, Iterable, Callable, Optional, List +from .compatibility import hstore_serialize, hstore_available, jsonb_available, get_field_db_type + +logger = logging.getLogger('django-pg-bulk-update') # JSONField is available in django 1.9+ only # I create fake class for previous version in order to just skip isinstance(item, JSONField) if branch @@ -46,13 +49,14 @@ def get_subclasses(cls, recursive=False): # type: (T, bool) -> Set[T] return subclasses -def format_field_value(field, val, conn): - # type: (Field, Any, DefaultConnectionProxy, **Any) -> Tuple[str, Tuple[Any]] +def format_field_value(field, val, conn, cast_type=False): + # type: (Field, Any, DefaultConnectionProxy, bool) -> Tuple[str, Tuple[Any]] """ Formats value, according to field rules :param field: Django field to take format from :param val: Value to format :param conn: Connection used to update data + :param cast_type: Adds type casting to sql if flag is True :return: A tuple: sql, replacing 
value in update and a tuple of parameters to pass to cursor """ # This content is a part, taken from django.db.models.sql.compiler.SQLUpdateCompiler.as_sql() @@ -116,6 +120,9 @@ def format_field_value(field, val, conn): else: value, update_params = 'NULL', tuple() + if cast_type: + value = 'CAST(%s AS %s)' % (value, get_field_db_type(field, conn)) + return value, update_params @@ -135,11 +142,18 @@ def batched_operation(handler, data, batch_size=None, batch_delay=0, args=(), kw Note, that args must contain any placeholder value, which will be replaced by batch data :return: A list of results for each batch """ - assert batch_size is None or type(batch_size) is int and batch_size > 0,\ - "batch_size must be positive integer if given" - assert type(batch_delay) in {int, float} and batch_delay >= 0, "batch_delay must be non negative float" - assert type(data_arg_index) is int and 0 <= data_arg_index < len(args),\ - "data_arg_num must be integer between 0 and len(args)" + if batch_size is not None and (type(batch_size) is not int): + raise TypeError("batch_size must be positive integer if given") + if batch_size is not None and batch_size <= 0: + raise ValueError("batch_size must be positive integer if given") + if type(batch_delay) not in {int, float}: + raise TypeError("batch_delay must be non negative float") + if batch_delay < 0: + raise ValueError("batch_delay must be non negative float") + if type(data_arg_index) is not int: + raise TypeError("data_arg_num must be integer between 0 and len(args)") + if not 0 <= data_arg_index < len(args): + raise ValueError("data_arg_num must be integer between 0 and len(args)") def _batches_iterator(): if batch_size is None: @@ -155,7 +169,8 @@ def _batches_iterator(): results = [] args = list(args) kwargs = kwargs or {} - for batch in _batches_iterator(): + for j, batch in enumerate(_batches_iterator()): + logger.debug('Processing batch %d with size %d' % (j + 1, len(batch))) args[data_arg_index] = batch r = handler(*args, 
**kwargs) results.append(r) diff --git a/tests/migrations/0001_initial.py b/tests/migrations/0001_initial.py index 02e2a7c..b89f03e 100644 --- a/tests/migrations/0001_initial.py +++ b/tests/migrations/0001_initial.py @@ -37,7 +37,8 @@ class Migration(migrations.Migration): options={ 'abstract': False, } - ) + ), + migrations.AlterUniqueTogether(name='TestModel', unique_together=[('id', 'name')]) ] diff --git a/tests/models.py b/tests/models.py index 9a903bb..34e55f2 100644 --- a/tests/models.py +++ b/tests/models.py @@ -6,11 +6,15 @@ from django_pg_bulk_update.manager import BulkUpdateManager from django_pg_bulk_update.compatibility import jsonb_available, hstore_available, array_available +class Meta: + unique_together = ['id', 'name'] + # Not all fields are available in different django and postgres versions model_attrs = { 'name': models.CharField(max_length=50, null=True, blank=True, default=''), 'int_field': models.IntegerField(null=True, blank=True), 'objects': BulkUpdateManager(), + 'Meta': Meta, '__module__': __name__ } diff --git a/tests/test_settings.py b/tests/settings.py similarity index 76% rename from tests/test_settings.py rename to tests/settings.py index 2c10ebf..903bf67 100644 --- a/tests/test_settings.py +++ b/tests/settings.py @@ -22,6 +22,21 @@ } } +LOGGING = { + 'version': 1, + 'handlers': { + 'console': { + 'class': 'logging.StreamHandler', + }, + }, + 'loggers': { + 'django-pg-bulk-update': { + 'handlers': ['console'], + 'level': 'DEBUG' + } + } +} + # DATABASES should be defined before this call from django_pg_bulk_update.compatibility import jsonb_available, array_available, hstore_available diff --git a/tests/test_bulk_update.py b/tests/test_bulk_update.py index ece5986..96c876b 100644 --- a/tests/test_bulk_update.py +++ b/tests/test_bulk_update.py @@ -12,35 +12,35 @@ class TestInputFormats(TestCase): fixtures = ['test_model'] def test_model(self): - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): 
bulk_update(123, []) - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update('123', []) def test_values(self): - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update(TestModel, 123) - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update(TestModel, [123]) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update(TestModel, {(1, 2): {'id': 10}}) - with self.assertRaises(AssertionError): - bulk_update(TestModel, {1: {'id': 10}}, key_fields=('id', 'name')) + with self.assertRaises(ValueError): + bulk_update(TestModel, {1: {'id': 10}}, key_fds=('id', 'name')) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update(TestModel, [{'name': 'test'}]) self.assertEqual(1, bulk_update(TestModel, [{'id': 1, 'name': 'abc'}])) self.assertEqual(1, bulk_update(TestModel, [{'id': 1, 'name': 'abc', 'int_field': 2}], - key_fields=('id', 'name'))) + key_fds=('id', 'name'))) self.assertEqual(1, bulk_update(TestModel, {1: {'name': 'abc'}})) self.assertEqual(1, bulk_update(TestModel, {(1,): {'name': 'abc'}})) - self.assertEqual(1, bulk_update(TestModel, {(2, 'test2'): {'int_field': 2}}, key_fields=('id', 'name'))) - self.assertEqual(1, bulk_update(TestModel, {('test3',): {'int_field': 2}}, key_fields='name')) + self.assertEqual(1, bulk_update(TestModel, {(2, 'test2'): {'int_field': 2}}, key_fds=('id', 'name'))) + self.assertEqual(1, bulk_update(TestModel, {('test3',): {'int_field': 2}}, key_fds='name')) def test_key_fields(self): values = [{ @@ -49,11 +49,11 @@ def test_key_fields(self): }] self.assertEqual(1, bulk_update(TestModel, values)) - self.assertEqual(1, bulk_update(TestModel, values, key_fields='id')) - self.assertEqual(1, bulk_update(TestModel, values, key_fields=['id'])) - self.assertEqual(1, bulk_update(TestModel, values, key_fields=['id', 'name'])) - self.assertEqual(1, bulk_update(TestModel, values, 
key_fields='name')) - self.assertEqual(1, bulk_update(TestModel, values, key_fields=['name'])) + self.assertEqual(1, bulk_update(TestModel, values, key_fds='id')) + self.assertEqual(1, bulk_update(TestModel, values, key_fds=['id'])) + self.assertEqual(1, bulk_update(TestModel, values, key_fds=['id', 'name'])) + self.assertEqual(1, bulk_update(TestModel, values, key_fds='name')) + self.assertEqual(1, bulk_update(TestModel, values, key_fds=['name'])) def test_using(self): values = [{ @@ -64,34 +64,30 @@ def test_using(self): self.assertEqual(1, bulk_update(TestModel, values)) self.assertEqual(1, bulk_update(TestModel, values, using='default')) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update(TestModel, values, using='invalid') - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update(TestModel, values, using=123) def test_set_functions(self): - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update(TestModel, [{'id': 1, 'name': 'test1'}], set_functions=123) - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update(TestModel, [{'id': 1, 'name': 'test1'}], set_functions=[123]) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update(TestModel, [{'id': 1, 'name': 'test1'}], set_functions={1: 'test'}) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update(TestModel, [{'id': 1, 'name': 'test1'}], set_functions={'id': 1}) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update(TestModel, [{'id': 1, 'name': 'test1'}], set_functions={'invalid': 1}) - with self.assertRaises(AssertionError): - bulk_update(TestModel, [{'id': 1, 'name': 'test1'}], set_functions={'int_field': 'invalid'}) - - # int_field is not in update keys here, set_function - with self.assertRaises(AssertionError): - self.assertEqual(1, bulk_update(TestModel, 
[{'id': 2, 'name': 'test1'}], set_functions={'int_field': '+'})) + with self.assertRaises(ValueError): + bulk_update(TestModel, [{'id': 1, 'int_field': 1}], set_functions={'int_field': 'invalid'}) # I don't test all set functions here, as there is another TestCase for this: TestSetFunctions self.assertEqual(1, bulk_update(TestModel, [{'id': 2, 'name': 'test1'}], @@ -99,29 +95,29 @@ def test_set_functions(self): self.assertEqual(1, bulk_update(TestModel, [{'id': 2, 'name': 'test1'}], set_functions={'name': '||'})) def test_key_fields_ops(self): - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update(TestModel, [{'id': 1, 'name': 'test1'}], key_fields_ops=123) - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update(TestModel, [{'id': 1, 'name': 'test1'}], key_fields_ops=[123]) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update(TestModel, [{'id': 1, 'name': 'test1'}], key_fields_ops={123: 'test'}) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update(TestModel, [{'id': 1, 'name': 'test1'}], key_fields_ops={'id': 'invalid'}) # name is not in key_fields - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update(TestModel, [{'id': 1, 'name': ['test1']}], key_fields_ops={'name': 'in'}) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update(TestModel, [{'id': 1, 'name': ['test1']}], key_fields_ops={'name': 123}) self.assertEqual(1, bulk_update(TestModel, [{'id': [1], 'name': 'test1'}], key_fields_ops={'id': 'in'})) - self.assertEqual(1, bulk_update(TestModel, [{'id': 1, 'name': ['test1']}], key_fields='name', + self.assertEqual(1, bulk_update(TestModel, [{'id': 1, 'name': ['test1']}], key_fds='name', key_fields_ops={'name': 'in'})) - self.assertEqual(1, bulk_update(TestModel, [{'id': 1, 'name': ['test1']}], key_fields='name', + self.assertEqual(1, 
bulk_update(TestModel, [{'id': 1, 'name': ['test1']}], key_fds='name', key_fields_ops=['in'])) self.assertEqual(1, bulk_update(TestModel, [{'id': [1], 'name': 'test1'}], key_fields_ops=['in'])) self.assertEqual(1, bulk_update(TestModel, [{'id': [1], 'name': 'test1'}], key_fields_ops=[InClauseOperator()])) @@ -129,19 +125,19 @@ def test_key_fields_ops(self): key_fields_ops={'id': InClauseOperator()})) def test_batch(self): - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update(TestModel, [{'id': 1, 'name': 'test1'}], batch_size='abc') - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update(TestModel, [{'id': 1, 'name': 'test1'}], batch_size=-2) - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update(TestModel, [{'id': 1, 'name': 'test1'}], batch_size=2.5) - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update(TestModel, [{'id': 1, 'name': 'test1'}], batch_size=1, batch_delay='abc') - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update(TestModel, [{'id': 1, 'name': 'test1'}], batch_size=1, batch_delay=-2) @@ -207,7 +203,7 @@ def test_key_update(self): 'id': 8, 'name': 'bulk_update_8' } - }, key_fields='name') + }, key_fds='name') self.assertEqual(3, res) for pk, name, int_field in TestModel.objects.all().order_by('id').values_list('id', 'name', 'int_field'): if pk in {1, 5, 8}: @@ -272,7 +268,7 @@ def test_same_key_fields(self): (6, 8): { "name": "second" } - }, key_fields=('id', 'id'), key_fields_ops=('>=', '<')) + }, key_fds=('id', 'id'), key_fields_ops=('>=', '<')) self.assertEqual(4, res) for pk, name, int_field in TestModel.objects.all().order_by('id').values_list('id', 'name', 'int_field'): if pk in {1, 2}: @@ -305,7 +301,7 @@ def test_example(self): "updated2": { "int_field": 3 } - }, key_fields="name") + }, key_fds="name") self.assertEqual(2, updated) self.assertListEqual([ {"id": 
1, "name": "updated1", "int_field": 2}, @@ -318,7 +314,7 @@ def test_example(self): "int_field": 3, "name": "incr" } - }, key_fields=['id', 'int_field'], key_fields_ops={'int_field': '<', 'id': 'gte'}, + }, key_fds=['id', 'int_field'], key_fields_ops={'int_field': '<', 'id': 'gte'}, set_functions={'int_field': '+'}) self.assertEqual(1, updated) self.assertListEqual([ @@ -493,7 +489,7 @@ def test_in(self): }, { 'int_field': 2, 'name': ['2'] - }], key_fields='name', key_fields_ops=['in']) + }], key_fds='name', key_fields_ops=['in']) self.assertEqual(6, res) for pk, name, int_field in TestModel.objects.all().order_by('id').values_list('id', 'name', 'int_field'): if pk in {1, 2, 3}: diff --git a/tests/test_bulk_update_or_create.py b/tests/test_bulk_update_or_create.py index dd4a57a..e780ed6 100644 --- a/tests/test_bulk_update_or_create.py +++ b/tests/test_bulk_update_or_create.py @@ -12,42 +12,43 @@ class TestInputFormats(TestCase): fixtures = ['test_model'] def test_model(self): - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update_or_create(123, []) - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update_or_create('123', []) def test_values(self): - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update_or_create(TestModel, 123) - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update_or_create(TestModel, [123]) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update_or_create(TestModel, {(1, 2): {'id': 10}}) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update_or_create(TestModel, {1: {'id': 10}}, key_fields=('id', 'name')) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update_or_create(TestModel, [{'name': 'test'}]) - self.assertTupleEqual((1, 1), bulk_update_or_create( + self.assertEqual(2, bulk_update_or_create( 
TestModel, [{'id': 1, 'name': 'abc'}, {'id': 21, 'name': 'create'}])) - self.assertTupleEqual((1, 1), bulk_update_or_create( + self.assertEqual(2, bulk_update_or_create( TestModel, [{'id': 1, 'name': 'abc', 'int_field': 2}, {'id': 20, 'name': 'abc', 'int_field': 3}], key_fields=('id', 'name'))) - self.assertTupleEqual((1, 1), bulk_update_or_create( + self.assertEqual(2, bulk_update_or_create( TestModel, {1: {'name': 'abc'}, 19: {'name': 'created'}})) - self.assertTupleEqual((1, 1), bulk_update_or_create( + self.assertEqual(2, bulk_update_or_create( TestModel, {(1,): {'name': 'abc'}, (18,): {'name': 'created'}})) - self.assertTupleEqual((1, 1), bulk_update_or_create( + self.assertEqual(2, bulk_update_or_create( TestModel, {(2, 'test2'): {'int_field': 2}, (17, 'test2'): {'int_field': 4}}, key_fields=('id', 'name'))) - self.assertTupleEqual((1, 1), bulk_update_or_create( - TestModel, {('test33',): {'int_field': 2}, ('test3',): {'int_field': 2}}, key_fields='name')) + self.assertEqual(2, bulk_update_or_create( + TestModel, {('test33',): {'int_field': 2}, ('test3',): {'int_field': 2}}, key_fields='name', + key_is_unique=False)) def test_key_fields(self): values = [{ @@ -58,19 +59,19 @@ def test_key_fields(self): 'name': 'bulk_update_or_create_2' }] - self.assertTupleEqual((1, 1), bulk_update_or_create(TestModel, values)) + self.assertEqual(2, bulk_update_or_create(TestModel, values)) values[1]['id'] += 1 - self.assertTupleEqual((1, 1), bulk_update_or_create(TestModel, values, key_fields='id')) + self.assertEqual(2, bulk_update_or_create(TestModel, values, key_fields='id')) values[1]['id'] += 1 - self.assertTupleEqual((1, 1), bulk_update_or_create(TestModel, values, key_fields=['id'])) + self.assertEqual(2, bulk_update_or_create(TestModel, values, key_fields=['id'])) values[1]['id'] += 1 - self.assertTupleEqual((1, 1), bulk_update_or_create(TestModel, values, key_fields=['id', 'name'])) + self.assertEqual(1, bulk_update_or_create(TestModel, values, key_fields=['id', 
'name'])) values[1]['id'] += 1 values[1]['name'] += '1' - self.assertTupleEqual((1, 1), bulk_update_or_create(TestModel, values, key_fields='name')) + self.assertEqual(2, bulk_update_or_create(TestModel, values, key_fields='name', key_is_unique=False)) values[1]['id'] += 1 values[1]['name'] += '1' - self.assertTupleEqual((1, 1), bulk_update_or_create(TestModel, values, key_fields=['name'])) + self.assertEqual(2, bulk_update_or_create(TestModel, values, key_fields=['name'], key_is_unique=False)) def test_using(self): values = [{ @@ -81,74 +82,69 @@ def test_using(self): 'name': 'bulk_update_or_create_2' }] - self.assertTupleEqual((1, 1), bulk_update_or_create(TestModel, values)) + self.assertEqual(2, bulk_update_or_create(TestModel, values)) values[1]['id'] += 1 - self.assertTupleEqual((1, 1), bulk_update_or_create(TestModel, values, using='default')) + self.assertEqual(2, bulk_update_or_create(TestModel, values, using='default')) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update_or_create(TestModel, values, using='invalid') - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update_or_create(TestModel, values, using=123) def test_set_functions(self): - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update_or_create(TestModel, [{'id': 1, 'name': 'test1'}], set_functions=123) - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update_or_create(TestModel, [{'id': 1, 'name': 'test1'}], set_functions=[123]) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update_or_create(TestModel, [{'id': 1, 'name': 'test1'}], set_functions={1: 'test'}) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update_or_create(TestModel, [{'id': 1, 'name': 'test1'}], set_functions={'id': 1}) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): 
bulk_update_or_create(TestModel, [{'id': 1, 'name': 'test1'}], set_functions={'invalid': 1}) - with self.assertRaises(AssertionError): - bulk_update_or_create(TestModel, [{'id': 1, 'name': 'test1'}], set_functions={'int_field': 'invalid'}) - - # int_field is not in update keys here, set_function - with self.assertRaises(AssertionError): - self.assertEqual(1, bulk_update_or_create(TestModel, [{'id': 2, 'name': 'test1'}], - set_functions={'int_field': '+'})) + with self.assertRaises(ValueError): + bulk_update_or_create(TestModel, [{'id': 1, 'int_field': 1}], set_functions={'int_field': 'invalid'}) # I don't test all set functions here, as there is another TestCase for this: TestSetFunctions - self.assertTupleEqual((1, 1), bulk_update_or_create( + self.assertEqual(2, bulk_update_or_create( TestModel, [{'id': 2, 'name': 'test1'}, {'id': 10, 'name': 'test1'}], set_functions={'name': ConcatSetFunction()} )) - self.assertTupleEqual((1, 1), bulk_update_or_create( + self.assertEqual(2, bulk_update_or_create( TestModel, [{'id': 2, 'name': 'test1'}, {'id': 11, 'name': 'test1'}], set_functions={'name': '||'})) def test_update(self): - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update_or_create(TestModel, [{'id': 1, 'name': 'test1'}], update=123) - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update_or_create(TestModel, [{'id': 1, 'name': ['test1']}], update='123') - self.assertTupleEqual((1, 0), bulk_update_or_create( + self.assertEqual(1, bulk_update_or_create( TestModel, [{'id': 1, 'name': 'test30'}, {'id': 20, 'name': 'test30'}], update=False)) - self.assertTupleEqual((1, 1), bulk_update_or_create( + self.assertEqual(2, bulk_update_or_create( TestModel, [{'id': 1, 'name': 'test30'}, {'id': 19, 'name': 'test30'}], update=True)) def test_batch(self): - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update_or_create(TestModel, [{'id': 1, 'name': 'test1'}], 
batch_size='abc') - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update_or_create(TestModel, [{'id': 1, 'name': 'test1'}], batch_size=-2) - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update_or_create(TestModel, [{'id': 1, 'name': 'test1'}], batch_size=2.5) - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): bulk_update_or_create(TestModel, [{'id': 1, 'name': 'test1'}], batch_size=1, batch_delay='abc') - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): bulk_update_or_create(TestModel, [{'id': 1, 'name': 'test1'}], batch_size=1, batch_delay=-2) @@ -167,7 +163,7 @@ def test_update(self): 'id': 11, 'name': 'bulk_update_11' }]) - self.assertTupleEqual((1, 2), res) + self.assertEqual(3, res) # 9 from fixture + 1 created self.assertEqual(10, TestModel.objects.all().count()) @@ -185,7 +181,7 @@ def test_update(self): def test_empty(self): res = bulk_update_or_create(TestModel, []) - self.assertTupleEqual((0, 0), res) + self.assertEqual(0, res) for pk, name, int_field in TestModel.objects.all().order_by('id').values_list('id', 'name', 'int_field'): self.assertEqual('test%d' % pk, name) self.assertEqual(pk, int_field) @@ -198,7 +194,7 @@ def test_quotes(self): 'id': 11, 'name': '"' }]) - self.assertTupleEqual((1, 1), res) + self.assertEqual(2, res) for pk, name, int_field in TestModel.objects.all().order_by('id').values_list('id', 'name', 'int_field'): if pk == 1: self.assertEqual('\'', name) @@ -214,22 +210,23 @@ def test_quotes(self): def test_key_update(self): res = bulk_update_or_create(TestModel, { - ('test1',): { + (1, 'test1'): { 'id': 1, 'name': 'bulk_update_1' }, - ('test5',): { + (5, 'test5'): { 'id': 5, 'name': 'bulk_update_5' }, - ('bulk_update_11',): { + (11, 'test11'): { 'id': 11, 'name': 'bulk_update_11' } - }, key_fields='name') - self.assertTupleEqual((1, 2), res) + }, key_fields=('id', 'name')) + self.assertEqual(3, res) for 
pk, name, int_field in TestModel.objects.all().order_by('id').values_list('id', 'name', 'int_field'): - if pk in {1, 5, 11}: + if pk in {1, 5}: + # Note due to insert on conflict restrictions key fields will be prior to update ones on insert. self.assertEqual('bulk_update_%d' % pk, name) else: self.assertEqual('test%d' % pk, name) @@ -250,7 +247,7 @@ def test_using(self): 'id': 11, 'name': 'bulk_update_11' }], using='secondary') - self.assertTupleEqual((1, 2), res) + self.assertEqual(3, res) # 9 from fixture + 1 created self.assertEqual(10, TestModel.objects.all().using('secondary').count()) @@ -283,7 +280,7 @@ def test_batch(self): 'id': 11, 'name': 'bulk_update_11' }], batch_size=1) - self.assertTupleEqual((1, 2), res) + self.assertEqual(3, res) # 9 from fixture + 1 created self.assertEqual(10, TestModel.objects.all().count()) @@ -301,7 +298,7 @@ def test_batch(self): # Test for empty values correct res = bulk_update_or_create(TestModel, [], batch_size=10) - self.assertTupleEqual((0, 0), res) + self.assertEqual(0, res) class TestReadmeExample(TestCase): @@ -313,7 +310,7 @@ def test_example(self): TestModel(pk=3, name="incr", int_field=4), ]) - inserted, updated = bulk_update_or_create(TestModel, [{ + res = bulk_update_or_create(TestModel, [{ "id": 3, "name": "_concat1", "int_field": 4 @@ -322,8 +319,7 @@ def test_example(self): "name": "concat2", "int_field": 5 }], set_functions={'name': '||'}) - self.assertEqual(1, inserted) - self.assertEqual(1, updated) + self.assertEqual(2, res) self.assertListEqual([ {"id": 1, "name": "updated1", "int_field": 2}, @@ -347,7 +343,7 @@ def test_incr(self): 'id': 11, 'int_field': 11 }], set_functions={'int_field': '+'}) - self.assertTupleEqual((1, 2), res) + self.assertEqual(3, res) for pk, name, int_field in TestModel.objects.all().order_by('id').values_list('id', 'name', 'int_field'): if pk in {1, 5}: self.assertEqual(2 * pk, int_field) @@ -370,7 +366,7 @@ def test_concat_str(self): 'id': 11, 'name': 'bulk_update_11' }], 
set_functions={'name': '||'}) - self.assertTupleEqual((1, 2), res) + self.assertEqual(3, res) for pk, name, int_field in TestModel.objects.all().order_by('id').values_list('id', 'name', 'int_field'): if pk in {1, 5}: self.assertEqual('test%dbulk_update_%d' % (pk, pk), name) @@ -385,7 +381,7 @@ def test_concat_str(self): self.assertIsNone(int_field) def _test_concat_array(self, iteration, res): - self.assertTupleEqual((1, 3) if iteration == 1 else (0, 4), res) + self.assertEqual(4, res) for pk, name, array_field in TestModel.objects.all().order_by('id').values_list('id', 'name', 'array_field'): if pk in {1, 2, 11}: self.assertListEqual([pk] * iteration, array_field) @@ -409,7 +405,7 @@ def test_concat_array(self): self._test_concat_array(i, res) def _test_union_array(self, iteration, res): - self.assertTupleEqual((1, 3) if iteration == 1 else (0, 4), res) + self.assertEqual(4, res) for pk, name, array_field in TestModel.objects.all().order_by('id').values_list('id', 'name', 'array_field'): if pk == 1: self.assertListEqual([pk], array_field) @@ -437,7 +433,7 @@ def test_union_array(self): self._test_union_array(i, res) def _test_concat_dict(self, iteration, res, field_name, val_as_str=False): - self.assertTupleEqual((1, 3) if iteration == 1 else (0, 4), res) + self.assertEqual(4, res) for pk, name, dict_field in TestModel.objects.all().order_by('id').values_list('id', 'name', field_name): if pk in {1, 2, 11}: # Note that JSON standard uses only strings as keys. 
So json.dumps will convert it @@ -483,7 +479,7 @@ def test_eq_not_null(self): 'id': 11, 'name': None }], set_functions={'name': 'eq_not_null'}) - self.assertTupleEqual((1, 2), res) + self.assertEqual(3, res) # 9 from fixture + 1 created self.assertEqual(10, TestModel.objects.all().count()) @@ -518,7 +514,7 @@ def test_bulk_update_or_create(self): 'id': 11, 'name': 'bulk_update_11' }]) - self.assertTupleEqual((1, 2), res) + self.assertEqual(3, res) # 9 from fixture + 1 created self.assertEqual(10, TestModel.objects.all().count()) @@ -545,7 +541,7 @@ def test_using(self): 'id': 11, 'name': 'bulk_update_11' }]) - self.assertTupleEqual((1, 2), res) + self.assertEqual(3, res) # 9 from fixture + 1 created self.assertEqual(10, TestModel.objects.all().using('secondary').count()) @@ -577,7 +573,7 @@ def test_array(self): {'id': 2, 'array_field': [2]}, {'id': 11, 'array_field': [11]}, {'id': 4, 'array_field': []}]) - self.assertTupleEqual((1, 3), res) + self.assertEqual(4, res) for pk, name, array_field in TestModel.objects.all().order_by('id').values_list('id', 'name', 'array_field'): if pk in {1, 2, 11}: self.assertListEqual([pk], array_field) @@ -598,7 +594,7 @@ def test_jsonb(self): {'id': 11, 'json_field': {'test': '11'}}, {'id': 4, 'json_field': {}}, {'id': 5, 'json_field': {'single': "'", "multi": '"'}}]) - self.assertTupleEqual((1, 4), res) + self.assertEqual(5, res) for pk, name, json_field in TestModel.objects.all().order_by('id').values_list('id', 'name', 'json_field'): if pk in {1, 2, 11}: self.assertDictEqual({'test': str(pk)}, json_field) @@ -621,7 +617,7 @@ def test_hstore(self): {'id': 11, 'hstore_field': {'test': '11'}}, {'id': 4, 'hstore_field': {}}, {'id': 5, 'hstore_field': {'single': "'", "multi": '"'}}]) - self.assertTupleEqual((1, 4), res) + self.assertEqual(5, res) for item in TestModel.objects.all().order_by('id'): if item.pk in {1, 2, 11}: self.assertDictEqual({'test': str(item.pk)}, item.hstore_field) diff --git a/tests/test_pdnf_clause.py 
b/tests/test_pdnf_clause.py index d07ec49..904374a 100644 --- a/tests/test_pdnf_clause.py +++ b/tests/test_pdnf_clause.py @@ -10,33 +10,33 @@ class PDNFClauseTest(TestCase): def test_assertions(self): # field_names - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): pdnf_clause(123, []) - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): pdnf_clause([123], []) # field_values - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): pdnf_clause(['id'], 123) - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): pdnf_clause(['id'], [123]) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): pdnf_clause(['id'], [{'invalid': 123}]) - # Operations - with self.assertRaises(AssertionError): + # Operators + with self.assertRaises(TypeError): pdnf_clause(['id'], [], key_fields_ops=123) - with self.assertRaises(AssertionError): + with self.assertRaises(TypeError): pdnf_clause(['id'], [], key_fields_ops=[123]) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): pdnf_clause(['id'], [], key_fields_ops=["invalid"]) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): pdnf_clause(['id'], [], key_fields_ops={"id": "invalid"}) def _test_filter(self, expected_res, field_names, field_values, operations=()): From 4c817e01148820fd4bc5f7123bf1081c081bd597 Mon Sep 17 00:00:00 2001 From: M1ha Date: Sat, 8 Sep 2018 12:44:11 +0500 Subject: [PATCH 3/8] Fixed readme --- README.md | 47 ++++++++++++++-------- src/django_pg_bulk_update/set_functions.py | 2 +- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 9c3f8f4..3f49fb8 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ Django extension to update multiple table records with similar (but not equal) c * PostgreSQL 9.2+ Previous versions may also work, but haven't been tested. 
JSONB operations are available for PostgreSQL 9.4+. + INSERT .. ON CONFLICT is used for PostgreSQL 9.5+. ## Installation Install via pip: @@ -34,17 +35,21 @@ There are 3 query helpers in this library. There parameters are unified and desc Functions forms raw sql query for PostgreSQL. It's work is not guaranteed on other databases. Function returns number of updated records. -* `bulk_update_or_create(model, values, key_fields='id', using=None, set_functions=None, update=True, batch_size=None, batch_delay=0)` +* `bulk_update_or_create(model, values, key_fields='id', using=None, set_functions=None, update=True, key_is_unique=True, batch_size=None, batch_delay=0)` This function finds records by key_fields. It creates not existing records with data, given in values. If `update` flag is set, it updates existing records with data, given in values. - Update is performed with bulk_udpate function above, so function work is not guaranteed on PostgreSQL only. - Function is done in transaction in 3 queries: - + Search for existing records - + Create not existing records (if values have any) - + Update existing records (if values have any and `update` flag is set) + There are two ways, this function may work: + 1) Use INSERT ... ON CONFLICT statement. It is safe, but requires PostgreSQL 9.5+ and unique index on key fields. + This behavior is used by default. + 2) 3-query transaction: + + Search for existing records + + Create not existing records (if values have any) + + Update existing records (if values have any and `update` flag is set) + This behavior is used by default on PostgreSQL before 9.5 and if key_is_unique parameter is set to False. + Note that transactional update has a known [race condition issue](https://github.com/M1hacka/django-pg-bulk-update/issues/14) that can't be fixed. - Function returns a tuple, containing number of records inserted and records updated. + Function returns number of records inserted or updated by query. 
* `pdnf_clause(key_fields, field_values, key_fields_ops=())` Pure django implementation of principal disjunctive normal form. It is base on combining Q() objects. @@ -137,6 +142,9 @@ There are 3 query helpers in this library. There parameters are unified and desc * `update: bool` If flag is not set, bulk_update_or_create function will not update existing records, only creating not existing. +* `key_is_unique: bool` + Defaults to True. Settings this flag to False forces library to use 3-query transactional update_or_create. + * `field_values: Iterable[Union[Iterable[Any], dict]]` Field values to use in `pdnf_clause` function. They have simpler format than update functions. It can come in 2 formats: @@ -208,7 +216,7 @@ print(list(TestModel.objects.all().order_by("id").values("id", "name", "int_fiel # ] -inserted, updated = bulk_update_or_create(TestModel, [{ +res = bulk_update_or_create(TestModel, [{ "id": 3, "name": "_concat1", "int_field": 4 @@ -218,8 +226,8 @@ inserted, updated = bulk_update_or_create(TestModel, [{ "int_field": 5 }], set_functions={'name': '||'}) -print(inserted, updated) -# Outputs: 1, 1 +print(res) +# Outputs: 2 print(list(TestModel.objects.all().order_by("id").values("id", "name", "int_field"))) # Outputs: [ @@ -294,7 +302,7 @@ You can define your own clause operator, creating `AbstractClauseOperator` subcl In order to simplify method usage of simple `field value` operators, by default `get_sql()` forms this condition, calling `get_sql_operator()` method, which returns . 
-Optionally, you can change `def format_field_value(self, field, val, connection, **kwargs)` method, +Optionally, you can change `def format_field_value(self, field, val, connection, cast_type=True, **kwargs)` method, which formats value according to field rules Example: @@ -336,16 +344,17 @@ You can define your own set function, creating `AbstractSetFunction` subclass an * `names` attribute * `supported_field_classes` attribute * One of: - - `def get_sql_value(self, field, val, connection, val_as_param=True, **kwargs)` method + - `def get_sql_value(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs)` method This method defines new value to set for parameter. It is called from `get_sql(...)` method by default. - - `def get_sql(self, field, val, connection, val_as_param=True, **kwargs)` method + - `def get_sql(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs)` method This method sets full sql and it params to use in set section of update query. By default it returns: `"%s" = self.get_sql_value(...)`, params Optionally, you can change: -* `def format_field_value(self, field, val, connection, **kwargs)` method, if input data needs special formatting. +* `def format_field_value(self, field, val, connection, cast_type=False, **kwargs)` method, if input data needs special formatting. * `def modify_create_params(self, model, key, kwargs)` method, to change data before passing them to model constructor -in `bulk_update_or_create()` +in `bulk_update_or_create()`. This method is used in 3-query transactional update only. INSERT ... ON CONFLICT +uses for_update flag of `get_sql()` and `get_sql_value()` functions Example: @@ -360,7 +369,7 @@ class CustomSetFunction(AbstractSetFunction): # Names of django field classes, this function supports. You can set None (default) to support any field. 
supported_field_classes = {'IntegerField', 'FloatField', 'AutoField', 'BigAutoField'} - def get_sql_value(self, field, val, connection, val_as_param=True, **kwargs): + def get_sql_value(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs): """ Returns value sql to set into field and parameters for query execution This method is called from get_sql() by default. @@ -369,6 +378,8 @@ class CustomSetFunction(AbstractSetFunction): :param connection: Connection used to update data :param val_as_param: If flag is not set, value should be converted to string and inserted into query directly. Otherwise a placeholder and query parameter will be used + :param with_table: If flag is set, column name in sql is prefixed by table name + :param for_update: If flag is set, returns update sql. Otherwise - insert SQL :param kwargs: Additional arguments, if needed :return: A tuple: sql, replacing value in update and a tuple of parameters to pass to cursor """ @@ -403,7 +414,7 @@ Library supports django.contrib.postgres.fields: + HStoreField Note that ArrayField and HStoreField are available since django 1.8, JSONField - since django 1.9. -Also PostgreSQL before 9.4 doesn't support jsonb, and so - JSONField. +PostgreSQL before 9.4 doesn't support jsonb, and so - JSONField. PostgreSQL 9.4 supports JSONB, but doesn't support concatenation operator (||). In order to support this set function a special function for postgres 9.4 was written. Add a migration to create it: @@ -419,6 +430,8 @@ class Migration(migrations.Migration): ] ``` +PostgreSQL before 9.5 doesn't support INSERT ... ON CONFLICT statement. So 3-query transactional update will be used. 
+ ## Performance Test background: - Django 2.0.2 diff --git a/src/django_pg_bulk_update/set_functions.py b/src/django_pg_bulk_update/set_functions.py index 17729ed..ea58499 100644 --- a/src/django_pg_bulk_update/set_functions.py +++ b/src/django_pg_bulk_update/set_functions.py @@ -117,12 +117,12 @@ def get_sql_value(self, field, val, connection, val_as_param=True, with_table=Fa """ Returns value sql to set into field and parameters for query execution This method is called from get_sql() by default. - :param with_table: If flag is set, column name in sql is prefixed by table name :param field: Django field to take format from :param val: Value to format :param connection: Connection used to update data :param val_as_param: If flag is not set, value should be converted to string and inserted into query directly. Otherwise a placeholder and query parameter will be used + :param with_table: If flag is set, column name in sql is prefixed by table name :param for_update: If flag is set, returns update sql. 
Otherwise - insert SQL :param kwargs: Additional arguments, if needed :return: A tuple: sql, replacing value in update and a tuple of parameters to pass to cursor From b63b450531b65a52a74f66ba700b02b5210037ee Mon Sep 17 00:00:00 2001 From: M1ha Date: Sat, 8 Sep 2018 12:44:58 +0500 Subject: [PATCH 4/8] Changed version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f69f9d7..f88fd4c 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ setup( name='django-pg-bulk-update', - version='1.1.0', + version='2.0.0', packages=['django_pg_bulk_update'], package_dir={'': 'src'}, url='https://github.com/M1hacka/django-pg-bulk-update', From 62ec48067978accaae1ef85940ff4585b2570603 Mon Sep 17 00:00:00 2001 From: M1ha Date: Sat, 8 Sep 2018 12:52:14 +0500 Subject: [PATCH 5/8] 1) Fixed tests 2) Added django 2.1 to travis --- .travis.yml | 25 +++++++++++++++++++++++++ runtests.py | 2 +- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 6b977a6..445439b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -43,6 +43,11 @@ env: - DJANGO=2.0 PG=9.4 - DJANGO=2.0 PG=9.5 - DJANGO=2.0 PG=9.6 + - DJANGO=2.1 PG=9.2 + - DJANGO=2.1 PG=9.3 + - DJANGO=2.2 PG=9.4 + - DJANGO=2.1 PG=9.5 + - DJANGO=2.1 PG=9.6 matrix: exclude: @@ -57,6 +62,16 @@ matrix: env: DJANGO=2.0 PG=9.5 - python: 2.7 env: DJANGO=2.0 PG=9.6 + - python: 2.7 + env: DJANGO=2.1 PG=9.2 + - python: 2.7 + env: DJANGO=2.1 PG=9.3 + - python: 2.7 + env: DJANGO=2.1 PG=9.4 + - python: 2.7 + env: DJANGO=2.1 PG=9.5 + - python: 2.7 + env: DJANGO=2.1 PG=9.6 # Django 1.9+ doesn't support python 3.3 - python: 3.3 @@ -99,6 +114,16 @@ matrix: env: DJANGO=2.0 PG=9.5 - python: 3.3 env: DJANGO=2.0 PG=9.6 + - python: 3.3 + env: DJANGO=2.1 PG=9.2 + - python: 3.3 + env: DJANGO=2.1 PG=9.3 + - python: 3.3 + env: DJANGO=2.1 PG=9.4 + - python: 3.3 + env: DJANGO=2.1 PG=9.5 + - python: 3.3 + env: DJANGO=2.1 PG=9.6 # Django 1.7 doesn't support python 3.5+ - python: 3.5 
diff --git a/runtests.py b/runtests.py index 45a349d..674b064 100644 --- a/runtests.py +++ b/runtests.py @@ -15,7 +15,7 @@ if __name__ == "__main__": print('Django: ', django.VERSION) print('Python: ', sys.version) - os.environ['DJANGO_SETTINGS_MODULE'] = 'tests.test_settings' + os.environ['DJANGO_SETTINGS_MODULE'] = 'tests.settings' django.setup() TestRunner = get_runner(settings) test_runner = TestRunner() From 82080142e3ec5e1c71ca371bbac4b1898ddfbc1c Mon Sep 17 00:00:00 2001 From: M1ha Date: Sat, 8 Sep 2018 13:17:23 +0500 Subject: [PATCH 6/8] Fixed python 2.7 compatibility --- src/django_pg_bulk_update/compatibility.py | 18 +++++++++++++++++- src/django_pg_bulk_update/query.py | 5 +++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/django_pg_bulk_update/compatibility.py b/src/django_pg_bulk_update/compatibility.py index cdd6f0c..19c4c6c 100644 --- a/src/django_pg_bulk_update/compatibility.py +++ b/src/django_pg_bulk_update/compatibility.py @@ -2,7 +2,9 @@ This file contains number of functions to handle different software versions compatibility """ import json -from typing import Dict, Any, Optional, Union, Tuple + +from django.db.models import Model, Field +from typing import Dict, Any, Optional, Union, Tuple, List, Type import django from django.db import connection, connections, models, DefaultConnectionProxy, migrations @@ -92,6 +94,20 @@ def get_field_db_type(field, conn): return db_type.replace('serial', 'integer') +def get_model_fields(model): + # type: (Type[Model]) -> List[Field] + """ + Returns all model fields. 
+ :param model: Model to get fields for + :return: A list of fields + """ + if hasattr(model._meta, 'get_fields'): + # Django 1.8+ + return model._meta.get_fields() + else: + return [f[0] for f in model._meta.get_fields_with_model()] + + # Postgres 9.4 has JSONB support, but doesn't support concat operator (||) # So I've taken function to solve the problem from # https://stackoverflow.com/questions/30101603/merging-concatenating-jsonb-columns-in-query diff --git a/src/django_pg_bulk_update/query.py b/src/django_pg_bulk_update/query.py index 9e7904f..62a8b02 100644 --- a/src/django_pg_bulk_update/query.py +++ b/src/django_pg_bulk_update/query.py @@ -13,7 +13,7 @@ from django.db.models import Model, Q from typing import Any, Type, Iterable as TIterable, Union, Optional, List, Tuple -from .compatibility import get_postgres_version +from .compatibility import get_postgres_version, get_model_fields from .set_functions import AbstractSetFunction from .types import TOperators, TFieldNames, TUpdateValues, TSetFunctions, TOperatorsValid, TUpdateValuesValid, \ TSetFunctionsValid, FieldDescriptor @@ -254,7 +254,8 @@ def _get_default_fds(model, existing_fds): """ existing_fields = {fd.get_field(model) for fd in existing_fds} result = [] - for f in model._meta.get_fields(): + + for f in get_model_fields(model): if f not in existing_fields: desc = FieldDescriptor(f.attname) desc.set_prefix('def') From 0270c9b0ce7c212ed6c610105ab7246e72b29e37 Mon Sep 17 00:00:00 2001 From: M1ha Date: Sat, 8 Sep 2018 13:45:03 +0500 Subject: [PATCH 7/8] 1) Some refactoring 2) Fixed bug in get_postgresql_version 3) Fixed postgres before 9.5 compatibility --- src/django_pg_bulk_update/compatibility.py | 4 ++-- src/django_pg_bulk_update/query.py | 4 ++-- src/django_pg_bulk_update/set_functions.py | 2 +- tests/migrations/0001_initial.py | 2 +- tests/test_bulk_update_or_create.py | 1 + 5 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/django_pg_bulk_update/compatibility.py 
b/src/django_pg_bulk_update/compatibility.py index 19c4c6c..a23c1e9 100644 --- a/src/django_pg_bulk_update/compatibility.py +++ b/src/django_pg_bulk_update/compatibility.py @@ -30,7 +30,7 @@ def jsonb_available(): # type: () -> bool It is available since django 1.9 and doesn't support Postgres < 9.4 :return: Bool """ - return get_postgres_version(as_tuple=False) >= 90400 and (django.VERSION[0] > 1 or django.VERSION[1] > 8) + return get_postgres_version() >= (9, 4) and (django.VERSION[0] > 1 or django.VERSION[1] > 8) def hstore_available(): # type: () -> bool @@ -76,7 +76,7 @@ def get_postgres_version(using=None, as_tuple=True): """ conn = connection if using is None else connections[using] num = conn.cursor().connection.server_version - return (num / 10000, num % 10000 / 100, num % 100) if as_tuple else num + return (int(num / 10000), int(num % 10000 / 100), num % 100) if as_tuple else num def get_field_db_type(field, conn): diff --git a/src/django_pg_bulk_update/query.py b/src/django_pg_bulk_update/query.py index 62a8b02..86fb5f9 100644 --- a/src/django_pg_bulk_update/query.py +++ b/src/django_pg_bulk_update/query.py @@ -499,7 +499,7 @@ def _bulk_update_or_create_no_validation(model, values, key_fds, upd_fds, using, for key, updates in values.items(): if key in existing_values_dict: # Form a list of updates, if they are enabled - if update: + if update and upd_fds: update_items[key] = updates else: # Form a list of model objects for bulk_create() method @@ -693,7 +693,7 @@ def bulk_update_or_create(model, values, key_fields='id', using=None, set_functi upd_fds = _validate_set_functions(model, upd_fds, set_functions) # Insert on conflict is supported in PostgreSQL 9.5 and only with constraint - if get_postgres_version(using=using) > (9, 4) and key_is_unique: + if get_postgres_version(using=using) >= (9, 5) and key_is_unique: batch_func = _insert_on_conflict_no_validation else: batch_func = _bulk_update_or_create_no_validation diff --git 
a/src/django_pg_bulk_update/set_functions.py b/src/django_pg_bulk_update/set_functions.py index ea58499..e339494 100644 --- a/src/django_pg_bulk_update/set_functions.py +++ b/src/django_pg_bulk_update/set_functions.py @@ -273,7 +273,7 @@ def get_sql_value(self, field, val, connection, val_as_param=True, with_table=Fa # Note, that function should be created before using this operator if not for_update: tpl = '%s' - elif get_postgres_version(as_tuple=False) < 90500 and isinstance(field, JSONField): + elif get_postgres_version() < (9, 5) and isinstance(field, JSONField): tpl = '{0}(COALESCE(%s, %s), %s)'.format(Postgres94MergeJSONBMigration.FUNCTION_NAME) else: tpl = 'COALESCE(%s, %s) || %s' diff --git a/tests/migrations/0001_initial.py b/tests/migrations/0001_initial.py index b89f03e..8c2af4f 100644 --- a/tests/migrations/0001_initial.py +++ b/tests/migrations/0001_initial.py @@ -46,5 +46,5 @@ class Migration(migrations.Migration): from django.contrib.postgres.operations import HStoreExtension Migration.operations = [HStoreExtension()] + Migration.operations -if jsonb_available() and get_postgres_version(as_tuple=False) < 90500: +if jsonb_available() and get_postgres_version() < (9, 5): Migration.operations += [Postgres94MergeJSONBMigration()] diff --git a/tests/test_bulk_update_or_create.py b/tests/test_bulk_update_or_create.py index e780ed6..72b5b3f 100644 --- a/tests/test_bulk_update_or_create.py +++ b/tests/test_bulk_update_or_create.py @@ -65,6 +65,7 @@ def test_key_fields(self): values[1]['id'] += 1 self.assertEqual(2, bulk_update_or_create(TestModel, values, key_fields=['id'])) values[1]['id'] += 1 + # All fields to update are in key_fields. 
So we can skip update self.assertEqual(1, bulk_update_or_create(TestModel, values, key_fields=['id', 'name'])) values[1]['id'] += 1 values[1]['name'] += '1' From 805d5ec6d29c04fdb52b555de0701546037cd24c Mon Sep 17 00:00:00 2001 From: M1ha Date: Sat, 8 Sep 2018 14:29:12 +0500 Subject: [PATCH 8/8] Django 2.1 doesn't support python 3.4 --- .travis.yml | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 445439b..a86ae5f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -45,7 +45,7 @@ env: - DJANGO=2.0 PG=9.6 - DJANGO=2.1 PG=9.2 - DJANGO=2.1 PG=9.3 - - DJANGO=2.2 PG=9.4 + - DJANGO=2.1 PG=9.4 - DJANGO=2.1 PG=9.5 - DJANGO=2.1 PG=9.6 @@ -125,6 +125,19 @@ matrix: - python: 3.3 env: DJANGO=2.1 PG=9.6 + # Django 2.1 doesn't support python 3.4 + - python: 3.4 + env: DJANGO=2.1 PG=9.2 + - python: 3.4 + env: DJANGO=2.1 PG=9.3 + - python: 3.4 + env: DJANGO=2.1 PG=9.4 + - python: 3.4 + env: DJANGO=2.1 PG=9.5 + - python: 3.4 + env: DJANGO=2.1 PG=9.6 + + # Django 1.7 doesn't support python 3.5+ - python: 3.5 env: DJANGO=1.7 PG=9.2