diff --git a/flask_resty/view.py b/flask_resty/view.py index b3baf1b..9351559 100644 --- a/flask_resty/view.py +++ b/flask_resty/view.py @@ -3,6 +3,7 @@ import flask from flask.views import MethodView from marshmallow import ValidationError, fields +from sqlalchemy import inspect from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Load from sqlalchemy.orm.exc import NoResultFound @@ -616,7 +617,7 @@ def paginate_list_query(self, query): return self.pagination.get_page(query, self) def get_item_or_404(self, id, **kwargs): - """Get an item by ID; raise a 404 if it not found. + """Get an item by ID; raise a 404 if not found. This will get an item by ID per `get_item` below. If no item is found, it will rethrow the `NoResultFound` exception as an HTTP 404. @@ -632,6 +633,23 @@ def get_item_or_404(self, id, **kwargs): return item + def get_item_or_none(self, id, **kwargs): + """Get an item by ID; return `None` if not found. + + This will get an item by ID per `get_item` below. If no item is found, + it will return `None`. + + :param id: The item ID. + :return: The item corresponding to the ID, if there is one. + :rtype: object + """ + try: + item = self.get_item(id, **kwargs) + except NoResultFound: + return None + + return item + def get_item( self, id, *, with_for_update=False, create_transient_stub=False, ): @@ -841,6 +859,14 @@ def update_item_raw(self, item, data): for key, value in data.items(): setattr(item, key, value) + def upsert_item(self, item, data): + if item is not None: + item = self.update_item(item, data) or item + else: + item = self.create_and_add_item(data) + + return item + def delete_item(self, item): """Delete an existing item. @@ -1082,19 +1108,17 @@ def upsert(self, id, *, with_for_update=False): :rtype: :py:class:`flask.Response` """ data_in = self.get_request_data(expected_id=id) + item = self.get_item_or_none(id, with_for_update=with_for_update) - try: - item = self.get_item(id, with_for_update=with_for_update) - except NoResultFound: - item = self.create_and_add_item(data_in) - self.commit() - - return self.make_created_response(item) - else: - item = self.update_item(item, data_in) or item - self.commit() + item = self.upsert_item(item, data_in) or item + updated = inspect(item).persistent + self.commit() - return self.make_item_response(item) + return ( + self.make_item_response(item) + if updated + else self.make_created_response(item) + ) def destroy(self, id): """Delete the item for the specified ID.