diff --git a/api/app.py b/api/app.py index a251ef5f0f72c3..eb98b7499106a8 100644 --- a/api/app.py +++ b/api/app.py @@ -1,312 +1,45 @@ -import os +from flask import Flask -if os.environ.get("DEBUG", "false").lower() != "true": - from gevent import monkey - - monkey.patch_all() - - import grpc.experimental.gevent - - grpc.experimental.gevent.init_gevent() - -import json -import logging -import sys -import threading -import time -import warnings -from logging.handlers import RotatingFileHandler - -from flask import Flask, Response, request -from flask_cors import CORS -from werkzeug.exceptions import Unauthorized - -import contexts -from commands import register_commands from configs import dify_config - -# DO NOT REMOVE BELOW -from events import event_handlers # noqa: F401 -from extensions import ( - ext_celery, - ext_code_based_extension, - ext_compress, - ext_database, - ext_hosting_provider, - ext_login, - ext_mail, - ext_migrate, - ext_proxy_fix, - ext_redis, - ext_sentry, - ext_storage, -) -from extensions.ext_database import db -from extensions.ext_login import login_manager -from libs.passport import PassportService - -# TODO: Find a way to avoid importing models here -from models import account, dataset, model, source, task, tool, tools, web # noqa: F401 -from services.account_service import AccountService - -# DO NOT REMOVE ABOVE - - -warnings.simplefilter("ignore", ResourceWarning) - -os.environ["TZ"] = "UTC" -# windows platform not support tzset -if hasattr(time, "tzset"): - time.tzset() - - -class DifyApp(Flask): - pass - - -# ------------- -# Configuration -# ------------- - - -config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first - - -# ---------------------------- -# Application Factory Function -# ---------------------------- - - -def create_flask_app_with_configs() -> Flask: - """ - create a raw flask app - with configs loaded from .env file - """ - dify_app = DifyApp(__name__) - dify_app.config.from_mapping(dify_config.model_dump()) - - # populate configs into system environment variables - for key, value in dify_app.config.items(): - if isinstance(value, str): - os.environ[key] = value - elif isinstance(value, int | float | bool): - os.environ[key] = str(value) - elif value is None: - os.environ[key] = "" - - return dify_app +from libs.threading import apply_gevent_threading_patch +from server.blueprints_assembly import BluePrintsAssembly +from server.commands_assembly import CommandsAssembly +from server.config_assembly import ConfigAssembly +from server.extensions_assembly import ExtensionsAssembly +from server.logger_assembly import LoggerAssembly +from server.module_assembly import PreloadModuleAssembly +from server.security_assembly import SecurityAssembly +from server.timezone_assembly import TimezoneAssembly def create_app() -> Flask: - app = create_flask_app_with_configs() - - app.secret_key = app.config["SECRET_KEY"] - - log_handlers = None - log_file = app.config.get("LOG_FILE") - if log_file: - log_dir = os.path.dirname(log_file) - os.makedirs(log_dir, exist_ok=True) - log_handlers = [ - RotatingFileHandler( - filename=log_file, - maxBytes=1024 * 1024 * 1024, - backupCount=5, - ), - logging.StreamHandler(sys.stdout), - ] - - logging.basicConfig( - level=app.config.get("LOG_LEVEL"), - format=app.config.get("LOG_FORMAT"), - datefmt=app.config.get("LOG_DATEFORMAT"), - handlers=log_handlers, - force=True, - ) - log_tz = app.config.get("LOG_TZ") - if log_tz: - from datetime import datetime - - import pytz - - timezone = pytz.timezone(log_tz) - - def time_converter(seconds): - return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() - - for handler in logging.root.handlers: - handler.formatter.converter = time_converter - initialize_extensions(app) - register_blueprints(app) - register_commands(app) + dify_app = Flask(__name__) + + assemblies = [ + ConfigAssembly, + TimezoneAssembly, + LoggerAssembly, + SecurityAssembly, + PreloadModuleAssembly, + ExtensionsAssembly, + CommandsAssembly, + BluePrintsAssembly, + ] + for assem in assemblies: + assem().prepare_app(dify_app) - return app - - -def initialize_extensions(app): - # Since the application instance is now created, pass it to each Flask - # extension instance to bind it to the Flask application instance (app) - ext_compress.init_app(app) - ext_code_based_extension.init() - ext_database.init_app(app) - ext_migrate.init(app, db) - ext_redis.init_app(app) - ext_storage.init_app(app) - ext_celery.init_app(app) - ext_login.init_app(app) - ext_mail.init_app(app) - ext_hosting_provider.init_app(app) - ext_sentry.init_app(app) - ext_proxy_fix.init_app(app) - - -# Flask-Login configuration -@login_manager.request_loader -def load_user_from_request(request_from_flask_login): - """Load user based on the request.""" - if request.blueprint not in {"console", "inner_api"}: - return None - # Check if the user_id contains a dot, indicating the old format - auth_header = request.headers.get("Authorization", "") - if not auth_header: - auth_token = request.args.get("_token") - if not auth_token: - raise Unauthorized("Invalid Authorization token.") - else: - if " " not in auth_header: - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - auth_scheme, auth_token = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - - decoded = PassportService().verify(auth_token) - user_id = decoded.get("user_id") - - logged_in_account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token) - if logged_in_account: - contexts.tenant_id.set(logged_in_account.current_tenant_id) - return logged_in_account - - -@login_manager.unauthorized_handler -def unauthorized_handler(): - """Handle unauthorized requests.""" - return Response( - json.dumps({"code": "unauthorized", "message": "Unauthorized."}), - status=401, - content_type="application/json", - ) - - -# register blueprint routers -def register_blueprints(app): - from controllers.console import bp as console_app_bp - from controllers.files import bp as files_bp - from controllers.inner_api import bp as inner_api_bp - from controllers.service_api import bp as service_api_bp - from controllers.web import bp as web_bp - - CORS( - service_api_bp, - allow_headers=["Content-Type", "Authorization", "X-App-Code"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - ) - app.register_blueprint(service_api_bp) - - CORS( - web_bp, - resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, - supports_credentials=True, - allow_headers=["Content-Type", "Authorization", "X-App-Code"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - expose_headers=["X-Version", "X-Env"], - ) - - app.register_blueprint(web_bp) - - CORS( - console_app_bp, - resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, - supports_credentials=True, - allow_headers=["Content-Type", "Authorization"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - expose_headers=["X-Version", "X-Env"], - ) - - app.register_blueprint(console_app_bp) - - CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) - app.register_blueprint(files_bp) - - app.register_blueprint(inner_api_bp) + return dify_app # create app app = create_app() celery = app.extensions["celery"] -if app.config.get("TESTING"): - print("App is running in TESTING mode") - - -@app.after_request -def after_request(response): - """Add Version headers to the response.""" - response.set_cookie("remember_token", "", expires=0) - response.headers.add("X-Version", app.config["CURRENT_VERSION"]) - response.headers.add("X-Env", app.config["DEPLOY_ENV"]) - return response - - -@app.route("/health") -def health(): - return Response( - json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}), - status=200, - content_type="application/json", - ) - - -@app.route("/threads") -def threads(): - num_threads = threading.active_count() - threads = threading.enumerate() - - thread_list = [] - for thread in threads: - thread_name = thread.name - thread_id = thread.ident - is_alive = thread.is_alive() - - thread_list.append( - { - "name": thread_name, - "id": thread_id, - "is_alive": is_alive, - } - ) - - return { - "pid": os.getpid(), - "thread_num": num_threads, - "threads": thread_list, - } - - -@app.route("/db-pool-stat") -def pool_stat(): - engine = db.engine - return { - "pid": os.getpid(), - "pool_size": engine.pool.size(), - "checked_in_connections": engine.pool.checkedin(), - "checked_out_connections": engine.pool.checkedout(), - "overflow_connections": engine.pool.overflow(), - "connection_timeout": engine.pool.timeout(), - "recycle_time": db.engine.pool._recycle, - } +if __name__ == "__main__": + if dify_config.DEBUG: + apply_gevent_threading_patch() + if dify_config.TESTING: + print("App is running in TESTING mode") -if __name__ == "__main__": app.run(host="0.0.0.0", port=5001) diff --git a/api/extensions/__init__.py b/api/extensions/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/extensions/ext_app_metrics.py b/api/extensions/ext_app_metrics.py new file mode 100644 index 00000000000000..02f14aa1a10259 --- /dev/null +++ b/api/extensions/ext_app_metrics.py @@ -0,0 +1,63 @@ +import json +import os +import threading + +from flask import Flask, Response + +from extensions.ext_database import db + + +def init_app(app: Flask): + @app.after_request + def after_request(response): + """Add Version headers to the response.""" + response.set_cookie("remember_token", "", expires=0) + response.headers.add("X-Version", app.config["CURRENT_VERSION"]) + response.headers.add("X-Env", app.config["DEPLOY_ENV"]) + return response + + @app.route("/health") + def health(): + return Response( + json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}), + status=200, + content_type="application/json", + ) + + @app.route("/threads") + def threads(): + num_threads = threading.active_count() + threads = threading.enumerate() + + thread_list = [] + for thread in threads: + thread_name = thread.name + thread_id = thread.ident + is_alive = thread.is_alive() + + thread_list.append( + { + "name": thread_name, + "id": thread_id, + "is_alive": is_alive, + } + ) + + return { + "pid": os.getpid(), + "thread_num": num_threads, + "threads": thread_list, + } + + @app.route("/db-pool-stat") + def pool_stat(): + engine = db.engine + return { + "pid": os.getpid(), + "pool_size": engine.pool.size(), + "checked_in_connections": engine.pool.checkedin(), + "checked_out_connections": engine.pool.checkedout(), + "overflow_connections": engine.pool.overflow(), + "connection_timeout": engine.pool.timeout(), + "recycle_time": db.engine.pool._recycle, + } diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index f7d5cffddadb18..ab12a101d87d6d 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -1,7 +1,51 @@ +import json + import flask_login +from flask import Response, request +from werkzeug.exceptions import Unauthorized -login_manager = flask_login.LoginManager() +import contexts +from libs.passport import PassportService +from services.account_service import AccountService def init_app(app): + login_manager = flask_login.LoginManager() login_manager.init_app(app) + + # Flask-Login configuration + @login_manager.request_loader + def load_user_from_request(request_from_flask_login): + """Load user based on the request.""" + if request.blueprint not in {"console", "inner_api"}: + return None + # Check if the user_id contains a dot, indicating the old format + auth_header = request.headers.get("Authorization", "") + if not auth_header: + auth_token = request.args.get("_token") + if not auth_token: + raise Unauthorized("Invalid Authorization token.") + else: + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + auth_scheme, auth_token = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + + decoded = PassportService().verify(auth_token) + user_id = decoded.get("user_id") + + logged_in_account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token) + if logged_in_account: + contexts.tenant_id.set(logged_in_account.current_tenant_id) + return logged_in_account + + @login_manager.unauthorized_handler + def unauthorized_handler(): + """Handle unauthorized requests.""" + return Response( + json.dumps({"code": "unauthorized", "message": "Unauthorized."}), + status=401, + content_type="application/json", + ) diff --git a/api/libs/threading.py b/api/libs/threading.py new file mode 100644 index 00000000000000..df2a21cb021330 --- /dev/null +++ b/api/libs/threading.py @@ -0,0 +1,16 @@ +def apply_gevent_threading_patch(): + """ + Run threading patch by gevent + to make standard library threading compatible. + Patching should be done as early as possible in the lifecycle of the program. + :return: + """ + # gevent + from gevent import monkey + + monkey.patch_all() + + # grpc gevent + from grpc.experimental import gevent as grpc_gevent + + grpc_gevent.init_gevent() diff --git a/api/server/basic_assembly.py b/api/server/basic_assembly.py new file mode 100644 index 00000000000000..f780cbde1b2fe8 --- /dev/null +++ b/api/server/basic_assembly.py @@ -0,0 +1,10 @@ +from abc import ABC, abstractmethod + +from flask import Flask +from pydantic import BaseModel + + +class BasicAssembly(ABC, BaseModel): + @abstractmethod + def prepare_app(self, app: Flask): + raise NotImplementedError diff --git a/api/server/blueprints_assembly.py b/api/server/blueprints_assembly.py new file mode 100644 index 00000000000000..0c19825a61940f --- /dev/null +++ b/api/server/blueprints_assembly.py @@ -0,0 +1,53 @@ +from flask import Flask + +from server.basic_assembly import BasicAssembly + + +class BluePrintsAssembly(BasicAssembly): + def prepare_app(self, app: Flask): + register_blueprints(app) + + +# register blueprint routers +def register_blueprints(app: Flask): + from flask_cors import CORS + + from controllers.console import bp as console_app_bp + from controllers.files import bp as files_bp + from controllers.inner_api import bp as inner_api_bp + from controllers.service_api import bp as service_api_bp + from controllers.web import bp as web_bp + + CORS( + service_api_bp, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + ) + app.register_blueprint(service_api_bp) + + CORS( + web_bp, + resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) + + app.register_blueprint(web_bp) + + CORS( + console_app_bp, + resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) + + app.register_blueprint(console_app_bp) + + CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) + app.register_blueprint(files_bp) + + app.register_blueprint(inner_api_bp) diff --git a/api/server/commands_assembly.py b/api/server/commands_assembly.py new file mode 100644 index 00000000000000..2643fd762aba2e --- /dev/null +++ b/api/server/commands_assembly.py @@ -0,0 +1,9 @@ +from flask import Flask + +from commands import register_commands +from server.basic_assembly import BasicAssembly + + +class CommandsAssembly(BasicAssembly): + def prepare_app(self, app: Flask): + register_commands(app) diff --git a/api/server/config_assembly.py b/api/server/config_assembly.py new file mode 100644 index 00000000000000..bc42c20b87685f --- /dev/null +++ b/api/server/config_assembly.py @@ -0,0 +1,25 @@ +import os + +from flask import Flask + +from server.basic_assembly import BasicAssembly + + +class ConfigAssembly(BasicAssembly): + def prepare_app(self, app: Flask): + prepare_flask_configs(app) + + +def prepare_flask_configs(app: Flask): + from configs import dify_config + + app.config.from_mapping(dify_config.model_dump()) + + # populate configs into system environment variables + for key, value in app.config.items(): + if isinstance(value, str): + os.environ[key] = value + elif isinstance(value, int | float | bool): + os.environ[key] = str(value) + elif value is None: + os.environ[key] = "" diff --git a/api/server/extensions_assembly.py b/api/server/extensions_assembly.py new file mode 100644 index 00000000000000..46d5c899ec799f --- /dev/null +++ b/api/server/extensions_assembly.py @@ -0,0 +1,43 @@ +from flask import Flask + +from server.basic_assembly import BasicAssembly + + +class ExtensionsAssembly(BasicAssembly): + def prepare_app(self, app: Flask): + initialize_extensions(app) + + +def initialize_extensions(app: Flask): + from extensions import ( + ext_app_metrics, + ext_celery, + ext_code_based_extension, + ext_compress, + ext_database, + ext_hosting_provider, + ext_login, + ext_mail, + ext_migrate, + ext_proxy_fix, + ext_redis, + ext_sentry, + ext_storage, + ) + from extensions.ext_database import db + + # Since the application instance is now created, pass it to each Flask + # extension instance to bind it to the Flask application instance (app) + ext_compress.init_app(app) + ext_code_based_extension.init() + ext_database.init_app(app) + ext_app_metrics.init_app(app) + ext_migrate.init(app, db) + ext_redis.init_app(app) + ext_storage.init_app(app) + ext_celery.init_app(app) + ext_login.init_app(app) + ext_mail.init_app(app) + ext_hosting_provider.init_app(app) + ext_sentry.init_app(app) + ext_proxy_fix.init_app(app) diff --git a/api/server/logger_assembly.py b/api/server/logger_assembly.py new file mode 100644 index 00000000000000..fa383659d7a116 --- /dev/null +++ b/api/server/logger_assembly.py @@ -0,0 +1,56 @@ +import logging +import os +import sys +import warnings + +from flask import Flask + +from server.basic_assembly import BasicAssembly + + +class LoggerAssembly(BasicAssembly): + def prepare_app(self, app: Flask): + prepare_warnings() + prepare_logging(app) + + +def prepare_warnings(): + warnings.simplefilter("ignore", ResourceWarning) + + +def prepare_logging(app: Flask): + from logging.handlers import RotatingFileHandler + + log_handlers = None + log_file = app.config.get("LOG_FILE") + if log_file: + log_dir = os.path.dirname(log_file) + os.makedirs(log_dir, exist_ok=True) + log_handlers = [ + RotatingFileHandler( + filename=log_file, + maxBytes=1024 * 1024 * 1024, + backupCount=5, + ), + logging.StreamHandler(sys.stdout), + ] + logging.basicConfig( + level=app.config.get("LOG_LEVEL"), + format=app.config.get("LOG_FORMAT"), + datefmt=app.config.get("LOG_DATEFORMAT"), + handlers=log_handlers, + force=True, + ) + log_tz = app.config.get("LOG_TZ") + if log_tz: + from datetime import datetime + + import pytz + + timezone = pytz.timezone(log_tz) + + def time_converter(seconds): + return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() + + for handler in logging.root.handlers: + handler.formatter.converter = time_converter diff --git a/api/server/module_assembly.py b/api/server/module_assembly.py new file mode 100644 index 00000000000000..6f35c2c8925585 --- /dev/null +++ b/api/server/module_assembly.py @@ -0,0 +1,13 @@ +from flask import Flask + +from server.basic_assembly import BasicAssembly + + +class PreloadModuleAssembly(BasicAssembly): + def prepare_app(self, app: Flask): + preload_modules() + + +def preload_modules(): + from events import event_handlers # noqa: F401 + from models import account, dataset, model, source, task, tool, tools, web # noqa: F401 diff --git a/api/server/security_assembly.py b/api/server/security_assembly.py new file mode 100644 index 00000000000000..2ed28e2a66a968 --- /dev/null +++ b/api/server/security_assembly.py @@ -0,0 +1,14 @@ +from flask import Flask + +from server.basic_assembly import BasicAssembly + + +class SecurityAssembly(BasicAssembly): + def prepare_app(self, app: Flask): + prepare_secret_key(app) + + +def prepare_secret_key(app: Flask): + from configs import dify_config + + app.secret_key = dify_config.SECRET_KEY diff --git a/api/server/timezone_assembly.py b/api/server/timezone_assembly.py new file mode 100644 index 00000000000000..754061b373fa58 --- /dev/null +++ b/api/server/timezone_assembly.py @@ -0,0 +1,18 @@ +import os +import time + +from flask import Flask + +from server.basic_assembly import BasicAssembly + + +class TimezoneAssembly(BasicAssembly): + def prepare_app(self, app: Flask): + prepare_timezone() + + +def prepare_timezone(): + os.environ["TZ"] = "UTC" + # windows platform not support tzset + if hasattr(time, "tzset"): + time.tzset()