Skip to content

Commit

Permalink
Merge pull request canonical#175 from canonical/rollback-on-session-e…
Browse files Browse the repository at this point in the history
…rror

Rollback db actions when connection invalidated
  • Loading branch information
samhotep authored Aug 19, 2024
2 parents e8e8d9c + 4fc8a49 commit ca52a41
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 14 deletions.
12 changes: 6 additions & 6 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# Standard library
from contextlib import redirect_stderr
import io
import os
import unittest
import warnings
from contextlib import redirect_stderr

# Packages
from sqlalchemy_utils import database_exists, create_database
import flask_migrate

# Packages
from sqlalchemy_utils import create_database, database_exists

# Local
from tests.fixtures.models import make_models
Expand All @@ -27,14 +27,14 @@
functionality, but I don't know of a good way to do that right now.
"""

from webapp import auth
from tests.helpers import transparent_decorator
from webapp import auth

auth.authorization_required = transparent_decorator
os.environ["DATABASE_URL"] = os.environ["TEST_DATABASE_URL"]

from webapp.app import app, db # noqa: E402

from webapp.app import app # noqa: E402
from webapp.database import db # noqa: E402

# Create database if it doesn't exist
with app.app_context():
Expand Down
9 changes: 3 additions & 6 deletions webapp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@
from apispec.ext.marshmallow import MarshmallowPlugin
from canonicalwebteam.flask_base.app import FlaskBase
from flask import jsonify, make_response
from flask_migrate import Migrate

from webapp.api_spec import WebappFlaskApiSpec
from webapp.commands import register_commands
from webapp.database import db
from webapp.database import init_db
from webapp.views import (
bulk_upsert_cve,
create_notice,
create_release,
delete_cve,
bulk_upsert_cve,
delete_notice,
delete_release,
get_cve,
Expand All @@ -27,7 +26,6 @@
update_release,
)


app = FlaskBase(
__name__,
"ubuntu-com-security-api",
Expand All @@ -48,8 +46,7 @@
}
)

db.init_app(app)
migrate = Migrate(app, db)
init_db(app)

register_commands(app)

Expand Down
14 changes: 13 additions & 1 deletion webapp/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,19 @@
To add the application context
"""

from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy # noqa: E402

from sqlalchemy import exc

db = SQLAlchemy()


def init_db(app):
db.init_app(app)
Migrate(app, db)

@app.errorhandler(exc.PendingRollbackError)
def handle_db_exceptions(error):
# log the error:
app.logger.error(error)
db.session.rollback()
2 changes: 1 addition & 1 deletion webapp/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from sqlalchemy.exc import DataError, IntegrityError
from sqlalchemy.orm import Query, load_only, selectinload

from webapp.app import db
from webapp.auth import authorization_required
from webapp.database import db
from webapp.models import (
CVE,
STATUS_STATUSES,
Expand Down

0 comments on commit ca52a41

Please sign in to comment.