Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: assembling the app features in modular way #9129

Merged
merged 11 commits into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/db-migration-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ jobs:
cp .env.example .env
- name: Run DB Migration
env:
DEBUG: true
run: |
cd api
poetry run python -m flask upgrade-db
108 changes: 4 additions & 104 deletions api/app.py
Original file line number Diff line number Diff line change
@@ -1,113 +1,13 @@
import os
import sys

python_version = sys.version_info
if not ((3, 11) <= python_version < (3, 13)):
print(f"Python 3.11 or 3.12 is required, current version is {python_version.major}.{python_version.minor}")
raise SystemExit(1)

from configs import dify_config

if not dify_config.DEBUG:
from gevent import monkey

monkey.patch_all()

import grpc.experimental.gevent

grpc.experimental.gevent.init_gevent()

import json
import threading
import time
import warnings

from flask import Response

from app_factory import create_app
from libs import threadings_utils, version_utils

# DO NOT REMOVE BELOW
from events import event_handlers # noqa: F401
from extensions.ext_database import db

# TODO: Find a way to avoid importing models here
from models import account, dataset, model, source, task, tool, tools, web # noqa: F401

# DO NOT REMOVE ABOVE


warnings.simplefilter("ignore", ResourceWarning)

os.environ["TZ"] = "UTC"
# windows platform not support tzset
if hasattr(time, "tzset"):
time.tzset()

# preparation before creating app
version_utils.check_supported_python_version()
threadings_utils.apply_gevent_threading_patch()

# create app
app = create_app()
celery = app.extensions["celery"]

if dify_config.TESTING:
print("App is running in TESTING mode")


@app.after_request
def after_request(response):
"""Add Version headers to the response."""
response.headers.add("X-Version", dify_config.CURRENT_VERSION)
response.headers.add("X-Env", dify_config.DEPLOY_ENV)
return response


@app.route("/health")
def health():
return Response(
json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}),
status=200,
content_type="application/json",
)


@app.route("/threads")
def threads():
num_threads = threading.active_count()
threads = threading.enumerate()

thread_list = []
for thread in threads:
thread_name = thread.name
thread_id = thread.ident
is_alive = thread.is_alive()

thread_list.append(
{
"name": thread_name,
"id": thread_id,
"is_alive": is_alive,
}
)

return {
"pid": os.getpid(),
"thread_num": num_threads,
"threads": thread_list,
}


@app.route("/db-pool-stat")
def pool_stat():
engine = db.engine
return {
"pid": os.getpid(),
"pool_size": engine.pool.size(),
"checked_in_connections": engine.pool.checkedin(),
"checked_out_connections": engine.pool.checkedout(),
"overflow_connections": engine.pool.overflow(),
"connection_timeout": engine.pool.timeout(),
"recycle_time": db.engine.pool._recycle,
}


if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)
217 changes: 66 additions & 151 deletions api/app_factory.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,15 @@
import logging
import os
import time

from configs import dify_config

if not dify_config.DEBUG:
from gevent import monkey

monkey.patch_all()

import grpc.experimental.gevent

grpc.experimental.gevent.init_gevent()

import json

from flask import Flask, Response, request
from flask_cors import CORS
from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import Unauthorized

import contexts
from commands import register_commands
from configs import dify_config
from extensions import (
ext_celery,
ext_code_based_extension,
ext_compress,
ext_database,
ext_hosting_provider,
ext_logging,
ext_login,
ext_mail,
ext_migrate,
ext_proxy_fix,
ext_redis,
ext_sentry,
ext_storage,
)
from extensions.ext_database import db
from extensions.ext_login import login_manager
from libs.passport import PassportService
from services.account_service import AccountService


class DifyApp(Flask):
pass
from dify_app import DifyApp


# ----------------------------
# Application Factory Function
# ----------------------------
def create_flask_app_with_configs() -> Flask:
def create_flask_app_with_configs() -> DifyApp:
"""
create a raw flask app
with configs loaded from .env file
Expand All @@ -69,117 +29,72 @@ def create_flask_app_with_configs() -> Flask:
return dify_app


def create_app() -> Flask:
def create_app() -> DifyApp:
start_time = time.perf_counter()
app = create_flask_app_with_configs()
app.secret_key = dify_config.SECRET_KEY
initialize_extensions(app)
register_blueprints(app)
register_commands(app)

end_time = time.perf_counter()
if dify_config.DEBUG:
logging.info(f"Finished create_app ({round((end_time - start_time) * 1000, 2)} ms)")
return app


def initialize_extensions(app):
# Since the application instance is now created, pass it to each Flask
# extension instance to bind it to the Flask application instance (app)
ext_logging.init_app(app)
ext_compress.init_app(app)
ext_code_based_extension.init()
ext_database.init_app(app)
ext_migrate.init(app, db)
ext_redis.init_app(app)
ext_storage.init_app(app)
ext_celery.init_app(app)
ext_login.init_app(app)
ext_mail.init_app(app)
ext_hosting_provider.init_app(app)
ext_sentry.init_app(app)
ext_proxy_fix.init_app(app)


# Flask-Login configuration
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint not in {"console", "inner_api"}:
return None
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get("Authorization", "")
if not auth_header:
auth_token = request.args.get("_token")
if not auth_token:
raise Unauthorized("Invalid Authorization token.")
else:
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")

decoded = PassportService().verify(auth_token)
user_id = decoded.get("user_id")

logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
return logged_in_account


@user_logged_in.connect
@user_loaded_from_request.connect
def on_user_logged_in(_sender, user):
"""Called when a user logged in."""
if user:
contexts.tenant_id.set(user.current_tenant_id)


@login_manager.unauthorized_handler
def unauthorized_handler():
"""Handle unauthorized requests."""
return Response(
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
status=401,
content_type="application/json",
)


# register blueprint routers
def register_blueprints(app):
from controllers.console import bp as console_app_bp
from controllers.files import bp as files_bp
from controllers.inner_api import bp as inner_api_bp
from controllers.service_api import bp as service_api_bp
from controllers.web import bp as web_bp

CORS(
service_api_bp,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
def initialize_extensions(app: DifyApp):
from extensions import (
ext_app_metrics,
ext_blueprints,
ext_celery,
ext_code_based_extension,
ext_commands,
ext_compress,
ext_database,
ext_hosting_provider,
ext_import_modules,
ext_logging,
ext_login,
ext_mail,
ext_migrate,
ext_proxy_fix,
ext_redis,
ext_sentry,
ext_set_secretkey,
ext_storage,
ext_timezone,
ext_warnings,
)
app.register_blueprint(service_api_bp)

CORS(
web_bp,
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)

app.register_blueprint(web_bp)

CORS(
console_app_bp,
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)

app.register_blueprint(console_app_bp)

CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
app.register_blueprint(files_bp)

app.register_blueprint(inner_api_bp)
extensions = [
ext_timezone,
ext_logging,
ext_warnings,
ext_import_modules,
ext_set_secretkey,
ext_compress,
ext_code_based_extension,
ext_database,
ext_app_metrics,
ext_migrate,
ext_redis,
ext_storage,
ext_celery,
ext_login,
ext_mail,
ext_hosting_provider,
ext_sentry,
ext_proxy_fix,
ext_blueprints,
ext_commands,
]
for ext in extensions:
short_name = ext.__name__.split(".")[-1]
is_enabled = ext.is_enabled() if hasattr(ext, "is_enabled") else True
if not is_enabled:
if dify_config.DEBUG:
logging.info(f"Skipped {short_name}")
continue

start_time = time.perf_counter()
ext.init_app(app)
end_time = time.perf_counter()
if dify_config.DEBUG:
logging.info(f"Loaded {short_name} ({round((end_time - start_time) * 1000, 2)} ms)")
12 changes: 0 additions & 12 deletions api/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,15 +640,3 @@ def fix_app_site_missing():
break

click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))


def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(vdb_migrate)
app.cli.add_command(convert_to_agent_apps)
app.cli.add_command(add_qdrant_doc_id_index)
app.cli.add_command(create_tenant)
app.cli.add_command(upgrade_db)
app.cli.add_command(fix_app_site_missing)
5 changes: 0 additions & 5 deletions api/configs/deploy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@ class DeploymentConfig(BaseSettings):
default=False,
)

TESTING: bool = Field(
bowenliang123 marked this conversation as resolved.
Show resolved Hide resolved
description="Enable testing mode for running automated tests",
default=False,
)

EDITION: str = Field(
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
default="SELF_HOSTED",
Expand Down
Loading