Skip to content

Commit

Permalink
feat: added aio_save method (#252)
Browse files Browse the repository at this point in the history
* feat: added aio_save method
  • Loading branch information
kalombos authored Jun 15, 2024
1 parent fa356b8 commit 56a1afc
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 7 deletions.
54 changes: 47 additions & 7 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,49 @@ async def aio_delete_instance(self, recursive=False, delete_nullable=False):
await model.delete().where(query).aio_execute()
return await type(self).delete().where(self._pk_expr()).aio_execute()

async def aio_save(self, force_insert=False, only=None):
field_dict = self.__data__.copy()
if self._meta.primary_key is not False:
pk_field = self._meta.primary_key
pk_value = self._pk
else:
pk_field = pk_value = None
if only is not None:
field_dict = self._prune_fields(field_dict, only)
elif self._meta.only_save_dirty and not force_insert:
field_dict = self._prune_fields(field_dict, self.dirty_fields)
if not field_dict:
self._dirty.clear()
return False

self._populate_unsaved_relations(field_dict)
rows = 1

if self._meta.auto_increment and pk_value is None:
field_dict.pop(pk_field.name, None)

if pk_value is not None and not force_insert:
if self._meta.composite_key:
for pk_part_name in pk_field.field_names:
field_dict.pop(pk_part_name, None)
else:
field_dict.pop(pk_field.name, None)
if not field_dict:
raise ValueError('no data to save!')
rows = await self.update(**field_dict).where(self._pk_expr()).aio_execute()
elif pk_field is not None:
pk = await self.insert(**field_dict).aio_execute()
if pk is not None and (self._meta.auto_increment or
pk_value is None):
self._pk = pk
# Although we set the primary-key, do not mark it as dirty.
self._dirty.discard(pk_field.name)
else:
await self.insert(**field_dict).aio_execute()

self._dirty -= set(field_dict) # Remove any fields we saved.
return rows

@classmethod
async def aio_get(cls, *query, **filters):
"""Async version of **peewee.Model.get**"""
Expand All @@ -835,10 +878,7 @@ async def aio_get_or_none(cls, *query, **filters):
return None

@classmethod
async def aio_create(cls, **data):
"""INSERT new row into table and return corresponding model instance."""
obj = cls(**data)
pk = await cls.insert(**dict(obj.__data__)).aio_execute()
if obj._pk is None:
obj._pk = pk
return obj
async def aio_create(cls, **query):
inst = cls(**query)
await inst.aio_save(force_insert=True)
return inst
4 changes: 4 additions & 0 deletions peewee_async_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,10 @@ async def update(self, obj, only=None):
:param only: (optional) the list/tuple of fields or
field names to update
"""
warnings.warn(
"`update` method is deprecated, use `AioModel.aio_save` method instead.",
DeprecationWarning
)
field_dict = dict(obj.__data__)
pk_field = obj._meta.primary_key

Expand Down
23 changes: 23 additions & 0 deletions tests/aio_model/test_shortcuts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import uuid

import peewee
import pytest
from peewee import fn

Expand Down Expand Up @@ -86,3 +87,25 @@ async def test_aio_delete_instance_with_fk(db):

assert await TestModelAlpha.aio_get_or_none(id=alpha.id) is None
assert await TestModelBeta.aio_get_or_none(id=beta.id) is None


@dbs_all
async def test_aio_save(db):
t = TestModel(text="text", data="data")
rows = await t.aio_save()
assert rows == 1
assert t.id is not None

assert await TestModel.aio_get_or_none(text="text", data="data") is not None


@dbs_all
async def test_aio_save__force_insert(db):
t = await TestModel.aio_create(text="text", data="data")
t.data = "data2"
await t.aio_save()

assert await TestModel.aio_get_or_none(text="text", data="data2") is not None

with pytest.raises(peewee.IntegrityError):
await t.aio_save(force_insert=True)

0 comments on commit 56a1afc

Please sign in to comment.