Skip to content

Commit

Permalink
Merge pull request #52 from M1hacka/issue-51
Browse files Browse the repository at this point in the history
Fixed issue #51
  • Loading branch information
M1ha-Shvn authored Dec 21, 2019
2 parents 3524cc9 + e39875f commit a0fbda7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 22 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

setup(
name='django-pg-bulk-update',
version='3.1.0',
version='3.1.1',
packages=['django_pg_bulk_update'],
package_dir={'': 'src'},
url='https://github.com/M1hacka/django-pg-bulk-update',
Expand Down
30 changes: 9 additions & 21 deletions src/django_pg_bulk_update/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,28 +62,21 @@ def format_field_value(field, val, conn, cast_type=False):
"""
# This content is a part, taken from django.db.models.sql.compiler.SQLUpdateCompiler.as_sql()
# And modified for our needs
if hasattr(val, 'prepare_database_save'):
if field.remote_field:
val = field.get_db_prep_save(val.prepare_database_save(field), connection=conn)
else:
raise TypeError(
"Tried to update field %s with a model instance, %r. "
"Use a value compatible with %s."
% (field, val, field.__class__.__name__)
)
else:
val = field.get_db_prep_save(val, connection=conn)

# Getting the placeholder for the field.
query = UpdateQuery(field.model)
compiler = query.get_compiler(connection=conn)

if hasattr(val, 'resolve_expression'):
val = val.resolve_expression(query, allow_joins=False, for_save=True)
if val.contains_aggregate:
raise FieldError("Aggregate functions are not allowed in this query")
raise FieldError(
'Aggregate functions are not allowed in this query '
'(%s=%r).' % (field.name, val)
)
if val.contains_over_clause:
raise FieldError('Window expressions are not allowed in this query.')
raise FieldError(
'Window expressions are not allowed in this query '
'(%s=%r).' % (field.name, val)
)
elif hasattr(val, 'prepare_database_save'):
if field.remote_field:
val = field.get_db_prep_save(val.prepare_database_save(field), connection=conn)
Expand All @@ -93,17 +86,12 @@ def format_field_value(field, val, conn, cast_type=False):
"Use a value compatible with %s."
% (field, val, field.__class__.__name__)
)
elif isinstance(field, JSONField):
# JSON field should be passed to execute() method as dict.
# If get_db_prep_save is called, it wraps it in JSONAdapter object
# When execute is done it tries wrapping it into JSONAdapter again and fails
pass
elif isinstance(field, HStoreField):
# Django before 1.10 doesn't convert HStoreField values to string automatically
# Which causes a bug in cursor.execute(). Let's do it here
if isinstance(val, dict):
val = hstore_serialize(val)
val = field.get_db_prep_save(val, connection=conn)
val = field.get_db_prep_save(val, connection=conn)
else:
val = field.get_db_prep_save(val, connection=conn)

Expand Down

0 comments on commit a0fbda7

Please sign in to comment.