Skip to content

Commit

Permalink
refact
Browse files Browse the repository at this point in the history
  • Loading branch information
jacquesfize committed May 27, 2024
1 parent 8ef252c commit 9930b3a
Showing 1 changed file with 10 additions and 33 deletions.
43 changes: 10 additions & 33 deletions backend/geonature/utils/celery.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,6 @@
from celery import Celery, Task
from geonature.utils.env import db
from geonature.utils.config import config
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, scoped_session

from celery import Celery
import flask
from flask_sqlalchemy import SQLAlchemy


class SQLASessionTask(Task):
def __init__(self):
self.sessions = {}

def before_start(self, task_id, args, kwargs):
engine = create_engine(
config["SQLALCHEMY_DATABASE_URI"],
)
session_factory = sessionmaker(bind=engine)
self.sessions[task_id] = scoped_session(session_factory)
super().before_start(task_id, args, kwargs)

def after_return(self, status, retval, task_id, args, kwargs, einfo):
session = self.sessions.pop(task_id)
session.close()
super().after_return(status, retval, task_id, args, kwargs, einfo)
from geonature.utils.env import db


class FlaskCelery(Celery):
Expand All @@ -37,20 +14,20 @@ def __init__(self, *args, **kwargs):
self.init_app(kwargs["app"])

def patch_task(self):
TaskBase = self.Task
_celery = self

class ContextTask(SQLASessionTask):
class ContextTask(TaskBase):
abstract = True

def __call__(self, *args, **kwargs):
if hasattr(self, "app"):
with self.app.app_context():
return SQLASessionTask.__call__(self, *args, **kwargs)
if flask.has_app_context():
return SQLASessionTask.__call__(self, *args, **kwargs)
else:
if hasattr(_celery, "app"):
with _celery.app.app_context():
return SQLASessionTask.__call__(self, *args, **kwargs)
# No need for db.session.remove() since it is automatically closed
# by flask-sqlalchemy when exit the app context created
return TaskBase.__call__(self, *args, **kwargs)
else:
return TaskBase.__call__(self, *args, **kwargs)

self.Task = ContextTask

Expand Down

0 comments on commit 9930b3a

Please sign in to comment.