diff --git a/peewee_async.py b/peewee_async.py index d426d1a..70239ef 100644 --- a/peewee_async.py +++ b/peewee_async.py @@ -811,6 +811,49 @@ async def aio_delete_instance(self, recursive=False, delete_nullable=False): await model.delete().where(query).aio_execute() return await type(self).delete().where(self._pk_expr()).aio_execute() + async def aio_save(self, force_insert=False, only=None): + field_dict = self.__data__.copy() + if self._meta.primary_key is not False: + pk_field = self._meta.primary_key + pk_value = self._pk + else: + pk_field = pk_value = None + if only is not None: + field_dict = self._prune_fields(field_dict, only) + elif self._meta.only_save_dirty and not force_insert: + field_dict = self._prune_fields(field_dict, self.dirty_fields) + if not field_dict: + self._dirty.clear() + return False + + self._populate_unsaved_relations(field_dict) + rows = 1 + + if self._meta.auto_increment and pk_value is None: + field_dict.pop(pk_field.name, None) + + if pk_value is not None and not force_insert: + if self._meta.composite_key: + for pk_part_name in pk_field.field_names: + field_dict.pop(pk_part_name, None) + else: + field_dict.pop(pk_field.name, None) + if not field_dict: + raise ValueError('no data to save!') + rows = await self.update(**field_dict).where(self._pk_expr()).aio_execute() + elif pk_field is not None: + pk = await self.insert(**field_dict).aio_execute() + if pk is not None and (self._meta.auto_increment or + pk_value is None): + self._pk = pk + # Although we set the primary-key, do not mark it as dirty. + self._dirty.discard(pk_field.name) + else: + await self.insert(**field_dict).aio_execute() + + self._dirty -= set(field_dict) # Remove any fields we saved. + return rows + @classmethod async def aio_get(cls, *query, **filters): """Async version of **peewee.Model.get**""" @@ -835,10 +878,7 @@ async def aio_get_or_none(cls, *query, **filters): return None @classmethod - async def aio_create(cls, **data): - """INSERT new row into table and return corresponding model instance.""" - obj = cls(**data) - pk = await cls.insert(**dict(obj.__data__)).aio_execute() - if obj._pk is None: - obj._pk = pk - return obj + async def aio_create(cls, **query): + inst = cls(**query) + await inst.aio_save(force_insert=True) + return inst diff --git a/peewee_async_compat.py b/peewee_async_compat.py index adaa143..ecd4e3e 100644 --- a/peewee_async_compat.py +++ b/peewee_async_compat.py @@ -291,6 +291,10 @@ async def update(self, obj, only=None): :param only: (optional) the list/tuple of fields or field names to update """ + warnings.warn( + "`update` method is deprecated, use `AioModel.aio_save` method instead.", + DeprecationWarning + ) field_dict = dict(obj.__data__) pk_field = obj._meta.primary_key diff --git a/tests/aio_model/test_shortcuts.py b/tests/aio_model/test_shortcuts.py index 8c86b78..9816e2f 100644 --- a/tests/aio_model/test_shortcuts.py +++ b/tests/aio_model/test_shortcuts.py @@ -1,5 +1,6 @@ import uuid +import peewee import pytest from peewee import fn @@ -86,3 +87,25 @@ async def test_aio_delete_instance_with_fk(db): assert await TestModelAlpha.aio_get_or_none(id=alpha.id) is None assert await TestModelBeta.aio_get_or_none(id=beta.id) is None + + +@dbs_all +async def test_aio_save(db): + t = TestModel(text="text", data="data") + rows = await t.aio_save() + assert rows == 1 + assert t.id is not None + + assert await TestModel.aio_get_or_none(text="text", data="data") is not None + + +@dbs_all +async def test_aio_save__force_insert(db): + t = await TestModel.aio_create(text="text", data="data") + t.data = "data2" + await t.aio_save() + + assert await TestModel.aio_get_or_none(text="text", data="data2") is not None + + with pytest.raises(peewee.IntegrityError): + await t.aio_save(force_insert=True)