Skip to content

Commit

Permalink
Auto now fix/issue 55 (#62)
Browse files Browse the repository at this point in the history
1. Added the ability to create set functions which don't require a value (`needs_value` attribute)
2. Added `NowSetFunction`
3. Fixed [issue #55](#55)
4. Fixed tricky issue with checking default connection during library import, which caused wrong default connection settings (timezone, for instance)
5. Fixed django 4.0 JSONField deprecation warning
  • Loading branch information
M1ha-Shvn authored Nov 7, 2020
1 parent 681d875 commit 9a1693a
Show file tree
Hide file tree
Showing 16 changed files with 236 additions and 68 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ There are 4 query helpers in this library. Their parameters are unified and desc
This function combines ArrayField value with previous one, removing duplicates.
- 'array_remove'
This function deletes value from ArrayField field using array_remove PSQL Function.
- 'now', 'NOW'
This function sets the field value to the result of the `NOW()` database function. It doesn't require any value in the `values` parameter.
- You can define your own set function. See section below.

Increment, union and concatenate functions treat NULL as the default value.
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Django (>=1.7)
django>=1.7
psycopg2-binary
pytz
six
typing
psycopg2
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

setup(
name='django-pg-bulk-update',
version='3.3.0',
version='3.4.0',
packages=['django_pg_bulk_update'],
package_dir={'': 'src'},
url='https://github.com/M1hacka/django-pg-bulk-update',
Expand All @@ -23,5 +23,5 @@
description='Django extension, executing bulk update operations for PostgreSQL',
long_description=long_description,
long_description_content_type="text/markdown",
requires=requires
install_requires=requires
)
32 changes: 29 additions & 3 deletions src/django_pg_bulk_update/compatibility.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""
This file contains number of functions to handle different software versions compatibility
"""
import importlib
import json

from django.db.models import Model, Field, BigIntegerField, IntegerField
from typing import Dict, Any, Optional, Union, Tuple, List, Type
import sys
from typing import Dict, Any, Optional, Union, Tuple, List, Type, Callable

import django
from django.db import connection, connections, models, migrations
from django.db.models import Model, Field, BigIntegerField, IntegerField

from .types import TDatabase

Expand Down Expand Up @@ -53,6 +54,31 @@ def array_available(): # type: () -> bool
return django.VERSION >= (1, 8)


def import_pg_field_or_dummy(field_name, available_func):  # type: (str, Callable) -> Any
    """
    Imports a PostgreSQL-specific field class, if it is available. Otherwise returns a dummy class.
    This is used to simplify isinstance(f, PGField) checks
    :param field_name: Field name. It should have the same case as the class name
    :param available_func: Function to check if the field is available. Should return boolean
    :return: Field class or dummy class
    """
    if sys.version_info < (3,):
        # NOTE(review): presumably type() needs a native (byte) string as class name on Python 2 - confirm
        field_name = field_name.encode()

    if available_func():
        # Since django 3.1 JSONField is moved to django.db.models, so that module is looked up first.
        # django.contrib.postgres.fields is imported only if the basic module lacks the field.
        for module_name in ('django.db.models', 'django.contrib.postgres.fields'):
            module = importlib.import_module(module_name)
            try:
                return getattr(module, field_name)
            except AttributeError:
                continue

    # Field is unavailable (or not found in either module): build an empty class with the same name
    return type(field_name, (), {})


def returning_available(raise_exception=False):
# type: (bool) -> bool
"""
Expand Down
58 changes: 41 additions & 17 deletions src/django_pg_bulk_update/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@
from collections import Iterable
from itertools import chain
from logging import getLogger
from typing import Any, Type, Iterable as TIterable, Union, Optional, List, Tuple

import six
from django.db import transaction, connection, connections
from django.db.models import Model, Q, AutoField, Field
from typing import Any, Type, Iterable as TIterable, Union, Optional, List, Tuple

from django.db.models.sql import UpdateQuery
from django.db.models.sql.where import WhereNode

from .compatibility import get_postgres_version, get_model_fields, returning_available
from .set_functions import AbstractSetFunction
from .set_functions import AbstractSetFunction, NowSetFunction
from .types import TOperators, TFieldNames, TUpdateValues, TSetFunctions, TOperatorsValid, TUpdateValuesValid, \
TSetFunctionsValid, TDatabase, FieldDescriptor, AbstractFieldFormatter
from .utils import batched_operation
from .utils import batched_operation, is_auto_set_field


__all__ = ['pdnf_clause', 'bulk_update', 'bulk_update_or_create', 'bulk_create']
logger = getLogger('django-pg-bulk-update')
Expand Down Expand Up @@ -102,8 +102,8 @@ def _validate_operators(key_fds, operators):
return key_fds


def _validate_update_values(key_fds, values):
# type: (Tuple[FieldDescriptor], TUpdateValues) -> Tuple[Tuple[FieldDescriptor], TUpdateValuesValid]
def _validate_update_values(model, key_fds, values):
# type: (Type[Model], Tuple[FieldDescriptor], TUpdateValues) -> Tuple[Tuple[FieldDescriptor], TUpdateValuesValid]
"""
Parses and validates input data for bulk_update and bulk_update_or_create.
It can come in 2 forms:
Expand Down Expand Up @@ -182,6 +182,14 @@ def _validate_update_values(key_fds, values):
raise TypeError("'values' parameter must be dict or Iterable")

descriptors = tuple(FieldDescriptor(name) for name in upd_keys_tuple)
fd_names = {fd.name for fd in descriptors}

# Add field names which are added automatically
descriptors += tuple(
FieldDescriptor(f.name)
for f in get_model_fields(model)
if is_auto_set_field(f) and f.name not in fd_names
)

# Add prefix to all descriptors
for name in descriptors:
Expand All @@ -190,14 +198,14 @@ def _validate_update_values(key_fds, values):
return descriptors, result


def _validate_set_functions(model, upd_fds, functions):
def _validate_set_functions(model, fds, functions):
# type: (Type[Model], Tuple[FieldDescriptor], TSetFunctions) -> TSetFunctionsValid
"""
Validates set functions.
It should be a dict with field name as key and function name or AbstractSetFunction instance as value
Default set function is EqualSetFunction
:param model: Model updated
:param upd_fds: A tuple of FieldDescriptors to update. It will be modified.
:param fds: A tuple of FieldDescriptors to update. It will be modified.
:param functions: Functions to validate
:return: A tuple of FieldDescriptor objects with set functions.
"""
Expand All @@ -212,12 +220,28 @@ def _validate_set_functions(model, upd_fds, functions):
if not isinstance(v, (six.string_types, AbstractSetFunction)):
raise ValueError("'set_functions' values must be string or AbstractSetFunction instance")

for f in upd_fds:
f.set_function = functions.get(f.name)
if not f.set_function.field_is_supported(f.get_field(model)):
for f in fds:
field = f.get_field(model)
if getattr(field, 'auto_now', False):
f.set_function = NowSetFunction(if_null=False)
elif getattr(field, 'auto_now_add', False):
f.set_function = NowSetFunction(if_null=True)
else:
f.set_function = functions.get(f.name)

if not f.set_function.field_is_supported(field):
raise ValueError("'%s' doesn't support '%s' field" % (f.set_function.__class__.__name__, f.name))

return upd_fds
# Add functions which don't require values
fd_names = {fd.name for fd in fds}
no_value_fds = []
for k, v in functions.items():
if k not in fd_names:
fd = FieldDescriptor(k, set_function=v)
if not fd.set_function.needs_value:
no_value_fds.append(fd)

return fds + tuple(no_value_fds)


def _validate_where(model, where, using):
Expand Down Expand Up @@ -379,7 +403,7 @@ def _with_values_query_part(model, values, conn, key_fds, upd_fds, default_fds=(
upd_format_bases = tuple(fd.set_function for fd in upd_fds)
for keys, updates in values.items():
# For field sql and params
upd_values = [updates[fd.name] for fd in upd_fds]
upd_values = [updates[fd.name] for fd in upd_fds if fd.set_function.needs_value]
upd_sql_items, upd_params = _generate_fds_sql(conn, upd_fields, upd_format_bases, upd_values, first)
key_sql_items, key_params = _generate_fds_sql(conn, key_fields, key_format_bases, keys, first)

Expand All @@ -395,7 +419,7 @@ def _with_values_query_part(model, values, conn, key_fds, upd_fds, default_fds=(
)

sel_sql = ', '.join(
'"%s"' % fd.prefixed_name for fd in chain(key_fds, upd_fds)
'"%s"' % fd.prefixed_name for fd in chain(key_fds, upd_fds) if fd.set_function.needs_value
)

return tpl % (sel_sql, values_sql, defaults_sql), values_update_params + list(defaults_params)
Expand Down Expand Up @@ -586,7 +610,7 @@ def bulk_update(model, values, key_fields='id', using=None, set_functions=None,
raise ValueError("using parameter must be existing database alias")

key_fields = _validate_field_names(key_fields)
upd_fds, values = _validate_update_values(key_fields, values)
upd_fds, values = _validate_update_values(model, key_fields, values)
ret_fds = _validate_returning(model, returning)
where = _validate_where(model, where, using)

Expand Down Expand Up @@ -714,7 +738,7 @@ def bulk_create(model, values, using=None, set_functions=None, returning=None, b
if using is not None and using not in connections:
raise ValueError("using parameter must be None or existing database alias")

insert_fds, values = _validate_update_values(tuple(), values)
insert_fds, values = _validate_update_values(model, tuple(), values)
ret_fds = _validate_returning(model, returning)

if len(values) == 0:
Expand Down Expand Up @@ -930,7 +954,7 @@ def bulk_update_or_create(model, values, key_fields='id', using=None, set_functi
for i, f in enumerate(key_fds):
f.set_prefix('key', index=i)

upd_fds, values = _validate_update_values(key_fds, values)
upd_fds, values = _validate_update_values(model, key_fds, values)
ret_fds = _validate_returning(model, returning)

if len(values) == 0:
Expand Down
34 changes: 23 additions & 11 deletions src/django_pg_bulk_update/set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import pytz
from django.db.models import Field, Model

from .compatibility import get_postgres_version, jsonb_available, Postgres94MergeJSONBMigration, hstore_serialize,\
hstore_available
from .compatibility import get_postgres_version, jsonb_available, Postgres94MergeJSONBMigration, hstore_serialize, \
hstore_available, import_pg_field_or_dummy
from .types import TDatabase, AbstractFieldFormatter
from .utils import get_subclasses, format_field_value

Expand Down Expand Up @@ -65,22 +65,16 @@
}


# JSONField is available in django 1.9+ only
# I create fake class for previous version in order to just skip isinstance(item, JSONField) if branch
if jsonb_available():
from django.contrib.postgres.fields import JSONField
else:
class JSONField:
pass


class AbstractSetFunction(AbstractFieldFormatter):
names = set()

# If set function supports any field class, this should be None.
# Otherwise a set of class names supported
supported_field_classes = None

# If the set function doesn't need a value from input, set this to False.
needs_value = True

def modify_create_params(self, model, key, kwargs):
# type: (Type[Model], str, Dict[str, Any]) -> Dict[str, Any]
"""
Expand Down Expand Up @@ -254,6 +248,7 @@ class ConcatSetFunction(AbstractSetFunction):

def get_sql_value(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs):
null_default, null_default_params = self._parse_null_default(field, connection, **kwargs)
JSONField = import_pg_field_or_dummy('JSONField', jsonb_available)

# Postgres 9.4 has JSONB support, but doesn't support concat operator (||)
# So I've taken function to solve the problem from
Expand Down Expand Up @@ -324,3 +319,20 @@ def get_sql_value(self, field, val, connection, val_as_param=True, with_table=Fa
sql, params = self.format_field_value(field, field.get_default(), connection)

return sql, params


class NowSetFunction(AbstractSetFunction):
    """
    Set function which writes the database NOW() value into a field.
    It takes no value from input (needs_value is False).
    """
    names = {'now', 'NOW'}
    supported_field_classes = {'DateField', 'DateTimeField'}
    needs_value = False

    def __init__(self, if_null=False):  # type: (bool) -> None
        """
        :param if_null: If True, an update keeps the current column value and uses NOW() only in place of NULL
        """
        self._if_null = if_null
        super(NowSetFunction, self).__init__()

    def get_sql_value(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs):
        """
        Builds the SQL expression for the field. The val parameter is not used.
        :return: A tuple: sql and an empty tuple of query parameters
        """
        keep_existing = for_update and self._if_null
        if not keep_existing:
            return 'NOW()', tuple()

        # Keep the value already stored in the column, falling back to NOW() for NULL
        column_sql = self._get_field_column(field, with_table=with_table)
        return "COALESCE(%s, NOW())" % column_sql, tuple()
5 changes: 3 additions & 2 deletions src/django_pg_bulk_update/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ class FieldDescriptor(object):
"""
This class is added in order to make passing parameters in queries easier
"""
__slots__ = ['name', '_set_function', '_key_operator', '_prefix']
__slots__ = ['name', 'auto_set', '_set_function', '_key_operator', '_prefix']

def __init__(self, name, set_function=None, key_operator=None):
# type: (str, TSetFunction, TOperator) -> None
self.name = name
self.set_function = set_function
self.key_operator = key_operator
self.auto_set = False
self._prefix = ''

def get_field(self, model):
Expand Down Expand Up @@ -138,4 +139,4 @@ def format_field_value(self, field, val, connection, cast_type=False, **kwargs):
:return: A tuple: sql, replacing value in update and a tuple of parameters to pass to cursor
"""
from .utils import format_field_value
return format_field_value(field, val, connection, cast_type=cast_type)
return format_field_value(field, val, connection, cast_type=cast_type)
27 changes: 10 additions & 17 deletions src/django_pg_bulk_update/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,11 @@
from django.db.models.sql.subqueries import UpdateQuery
from typing import TypeVar, Set, Any, Tuple, Iterable, Callable, Optional, List

from .compatibility import hstore_serialize, hstore_available, jsonb_available, get_field_db_type
from .compatibility import hstore_serialize, hstore_available, get_field_db_type, import_pg_field_or_dummy
from .types import TDatabase

logger = logging.getLogger('django-pg-bulk-update')

# JSONField is available in django 1.9+ only
# I create fake class for previous version in order to just skip isinstance(item, JSONField) if branch
if jsonb_available():
from django.contrib.postgres.fields import JSONField
else:
class JSONField:
pass

# django.contrib.postgres is available in django 1.8+ only
# I create fake class for previous version in order to just skip isinstance(item, HStoreField) if branch
if hstore_available():
from django.contrib.postgres.fields import HStoreField
else:
class HStoreField:
pass

T = TypeVar('T')


Expand Down Expand Up @@ -64,6 +48,7 @@ def format_field_value(field, val, conn, cast_type=False):
# And modified for our needs
query = UpdateQuery(field.model)
compiler = query.get_compiler(connection=conn)
HStoreField = import_pg_field_or_dummy('HStoreField', hstore_available)

if hasattr(val, 'resolve_expression'):
val = val.resolve_expression(query, allow_joins=False, for_save=True)
Expand Down Expand Up @@ -170,3 +155,11 @@ def _batches_iterator():

return results


def is_auto_set_field(field):  # type: (Field) -> bool
    """
    Checks if a model field should be set automatically if it is absent in values.
    A field qualifies when its auto_now or auto_now_add attribute is truthy.
    :param field: Model field instance
    :return: Boolean
    """
    # bool() guarantees the declared return type even if the attributes hold non-boolean values (None, for instance)
    return bool(getattr(field, 'auto_now', False) or getattr(field, 'auto_now_add', False))
10 changes: 10 additions & 0 deletions tests/fixtures/auto_now_model.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[
{
"model": "tests.AutoNowModel",
"pk": 1,
"fields": {
"created": "2019-01-01T00:00:00+0000",
"updated": "2019-01-01"
}
}
]
4 changes: 2 additions & 2 deletions tests/migrations/0001_initial.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from django.db import migrations, models

from django_pg_bulk_update.compatibility import jsonb_available, hstore_available, array_available, \
get_postgres_version, Postgres94MergeJSONBMigration
get_postgres_version, Postgres94MergeJSONBMigration, import_pg_field_or_dummy

test_model_fields = [
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
Expand All @@ -14,7 +14,7 @@
]

if jsonb_available():
from django.contrib.postgres.fields import JSONField
JSONField = import_pg_field_or_dummy('JSONField', jsonb_available)
test_model_fields.append(('json_field', JSONField(null=True, blank=True)))

if array_available():
Expand Down
Loading

0 comments on commit 9a1693a

Please sign in to comment.