Skip to content

Commit

Permalink
refactor: use dify_config to replace legacy usage of flask app's conf…
Browse files Browse the repository at this point in the history
…ig (#9089)
  • Loading branch information
bowenliang123 authored Oct 22, 2024
1 parent 8f670f3 commit 4d9160c
Show file tree
Hide file tree
Showing 27 changed files with 221 additions and 207 deletions.
16 changes: 6 additions & 10 deletions api/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

from configs import dify_config

if os.environ.get("DEBUG", "false").lower() != "true":
from gevent import monkey

Expand Down Expand Up @@ -36,33 +38,27 @@
time.tzset()


# -------------
# Configuration
# -------------
config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first


# create app
app = create_app()
celery = app.extensions["celery"]

if app.config.get("TESTING"):
if dify_config.TESTING:
print("App is running in TESTING mode")


@app.after_request
def after_request(response):
"""Add Version headers to the response."""
response.set_cookie("remember_token", "", expires=0)
response.headers.add("X-Version", app.config["CURRENT_VERSION"])
response.headers.add("X-Env", app.config["DEPLOY_ENV"])
response.headers.add("X-Version", dify_config.CURRENT_VERSION)
response.headers.add("X-Env", dify_config.DEPLOY_ENV)
return response


@app.route("/health")
def health():
return Response(
json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}),
json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}),
status=200,
content_type="application/json",
)
Expand Down
6 changes: 3 additions & 3 deletions api/app_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def create_flask_app_with_configs() -> Flask:

def create_app() -> Flask:
app = create_flask_app_with_configs()
app.secret_key = app.config["SECRET_KEY"]
app.secret_key = dify_config.SECRET_KEY
initialize_extensions(app)
register_blueprints(app)
register_commands(app)
Expand Down Expand Up @@ -150,7 +150,7 @@ def register_blueprints(app):

CORS(
web_bp,
resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
Expand All @@ -161,7 +161,7 @@ def register_blueprints(app):

CORS(
console_app_bp,
resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
Expand Down
15 changes: 15 additions & 0 deletions api/configs/feature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,21 @@ class SecurityConfig(BaseSettings):
default=5,
)

LOGIN_DISABLED: bool = Field(
description="Whether to disable login checks",
default=False,
)

ADMIN_API_KEY_ENABLE: bool = Field(
description="Whether to enable admin api key for authentication",
default=False,
)

ADMIN_API_KEY: Optional[str] = Field(
description="admin api key for authentication",
default=None,
)


class AppExecutionConfig(BaseSettings):
"""
Expand Down
6 changes: 3 additions & 3 deletions api/controllers/console/admin.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
from functools import wraps

from flask import request
from flask_restful import Resource, reqparse
from werkzeug.exceptions import NotFound, Unauthorized

from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
from controllers.console.wraps import only_edition_cloud
Expand All @@ -15,7 +15,7 @@
def admin_required(view):
@wraps(view)
def decorated(*args, **kwargs):
if not os.getenv("ADMIN_API_KEY"):
if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")

auth_header = request.headers.get("Authorization")
Expand All @@ -31,7 +31,7 @@ def decorated(*args, **kwargs):
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")

if os.getenv("ADMIN_API_KEY") != auth_token:
if dify_config.ADMIN_API_KEY != auth_token:
raise Unauthorized("API key is invalid.")

return view(*args, **kwargs)
Expand Down
91 changes: 44 additions & 47 deletions api/core/hosting_configuration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Optional

from flask import Config, Flask
from flask import Flask
from pydantic import BaseModel

from configs import dify_config
from core.entities.provider_entities import QuotaUnit, RestrictModel
from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType
Expand Down Expand Up @@ -44,32 +45,30 @@ class HostingConfiguration:
moderation_config: HostedModerationConfig = None

def init_app(self, app: Flask) -> None:
config = app.config

if config.get("EDITION") != "CLOUD":
if dify_config.EDITION != "CLOUD":
return

self.provider_map["azure_openai"] = self.init_azure_openai(config)
self.provider_map["openai"] = self.init_openai(config)
self.provider_map["anthropic"] = self.init_anthropic(config)
self.provider_map["minimax"] = self.init_minimax(config)
self.provider_map["spark"] = self.init_spark(config)
self.provider_map["zhipuai"] = self.init_zhipuai(config)
self.provider_map["azure_openai"] = self.init_azure_openai()
self.provider_map["openai"] = self.init_openai()
self.provider_map["anthropic"] = self.init_anthropic()
self.provider_map["minimax"] = self.init_minimax()
self.provider_map["spark"] = self.init_spark()
self.provider_map["zhipuai"] = self.init_zhipuai()

self.moderation_config = self.init_moderation_config(config)
self.moderation_config = self.init_moderation_config()

@staticmethod
def init_azure_openai(app_config: Config) -> HostingProvider:
def init_azure_openai() -> HostingProvider:
quota_unit = QuotaUnit.TIMES
if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"):
if dify_config.HOSTED_AZURE_OPENAI_ENABLED:
credentials = {
"openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"),
"openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"),
"openai_api_key": dify_config.HOSTED_AZURE_OPENAI_API_KEY,
"openai_api_base": dify_config.HOSTED_AZURE_OPENAI_API_BASE,
"base_model_name": "gpt-35-turbo",
}

quotas = []
hosted_quota_limit = int(app_config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))
hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit,
restrict_models=[
Expand Down Expand Up @@ -122,31 +121,31 @@ def init_azure_openai(app_config: Config) -> HostingProvider:
quota_unit=quota_unit,
)

def init_openai(self, app_config: Config) -> HostingProvider:
def init_openai(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas = []

if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota)

if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS")
if dify_config.HOSTED_OPENAI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)

if len(quotas) > 0:
credentials = {
"openai_api_key": app_config.get("HOSTED_OPENAI_API_KEY"),
"openai_api_key": dify_config.HOSTED_OPENAI_API_KEY,
}

if app_config.get("HOSTED_OPENAI_API_BASE"):
credentials["openai_api_base"] = app_config.get("HOSTED_OPENAI_API_BASE")
if dify_config.HOSTED_OPENAI_API_BASE:
credentials["openai_api_base"] = dify_config.HOSTED_OPENAI_API_BASE

if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"):
credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION")
if dify_config.HOSTED_OPENAI_API_ORGANIZATION:
credentials["openai_organization"] = dify_config.HOSTED_OPENAI_API_ORGANIZATION

return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)

Expand All @@ -156,26 +155,26 @@ def init_openai(self, app_config: Config) -> HostingProvider:
)

@staticmethod
def init_anthropic(app_config: Config) -> HostingProvider:
def init_anthropic() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
quotas = []

if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"):
hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0"))
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
quotas.append(trial_quota)

if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"):
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
paid_quota = PaidHostingQuota()
quotas.append(paid_quota)

if len(quotas) > 0:
credentials = {
"anthropic_api_key": app_config.get("HOSTED_ANTHROPIC_API_KEY"),
"anthropic_api_key": dify_config.HOSTED_ANTHROPIC_API_KEY,
}

if app_config.get("HOSTED_ANTHROPIC_API_BASE"):
credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE")
if dify_config.HOSTED_ANTHROPIC_API_BASE:
credentials["anthropic_api_url"] = dify_config.HOSTED_ANTHROPIC_API_BASE

return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)

Expand All @@ -185,9 +184,9 @@ def init_anthropic(app_config: Config) -> HostingProvider:
)

@staticmethod
def init_minimax(app_config: Config) -> HostingProvider:
def init_minimax() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if app_config.get("HOSTED_MINIMAX_ENABLED"):
if dify_config.HOSTED_MINIMAX_ENABLED:
quotas = [FreeHostingQuota()]

return HostingProvider(
Expand All @@ -203,9 +202,9 @@ def init_minimax(app_config: Config) -> HostingProvider:
)

@staticmethod
def init_spark(app_config: Config) -> HostingProvider:
def init_spark() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if app_config.get("HOSTED_SPARK_ENABLED"):
if dify_config.HOSTED_SPARK_ENABLED:
quotas = [FreeHostingQuota()]

return HostingProvider(
Expand All @@ -221,9 +220,9 @@ def init_spark(app_config: Config) -> HostingProvider:
)

@staticmethod
def init_zhipuai(app_config: Config) -> HostingProvider:
def init_zhipuai() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if app_config.get("HOSTED_ZHIPUAI_ENABLED"):
if dify_config.HOSTED_ZHIPUAI_ENABLED:
quotas = [FreeHostingQuota()]

return HostingProvider(
Expand All @@ -239,17 +238,15 @@ def init_zhipuai(app_config: Config) -> HostingProvider:
)

@staticmethod
def init_moderation_config(app_config: Config) -> HostedModerationConfig:
if app_config.get("HOSTED_MODERATION_ENABLED") and app_config.get("HOSTED_MODERATION_PROVIDERS"):
return HostedModerationConfig(
enabled=True, providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(",")
)
def init_moderation_config() -> HostedModerationConfig:
if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS:
return HostedModerationConfig(enabled=True, providers=dify_config.HOSTED_MODERATION_PROVIDERS.split(","))

return HostedModerationConfig(enabled=False)

@staticmethod
def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]:
models_str = app_config.get(env_var)
def parse_restrict_models_from_env(env_var: str) -> list[RestrictModel]:
models_str = dify_config.model_dump().get(env_var)
models_list = models_str.split(",") if models_str else []
return [
RestrictModel(model=model_name.strip(), model_type=ModelType.LLM)
Expand Down
3 changes: 1 addition & 2 deletions api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,14 +428,13 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
if not dataset.index_struct_dict:
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name))

config = current_app.config
return QdrantVector(
collection_name=collection_name,
group_id=dataset.id,
config=QdrantConfig(
endpoint=dify_config.QDRANT_URL,
api_key=dify_config.QDRANT_API_KEY,
root_path=config.root_path,
root_path=current_app.config.root_path,
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
Expand Down
18 changes: 10 additions & 8 deletions api/extensions/ext_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from celery import Celery, Task
from flask import Flask

from configs import dify_config


def init_app(app: Flask) -> Celery:
class FlaskTask(Task):
Expand All @@ -12,19 +14,19 @@ def __call__(self, *args: object, **kwargs: object) -> object:

broker_transport_options = {}

if app.config.get("CELERY_USE_SENTINEL"):
if dify_config.CELERY_USE_SENTINEL:
broker_transport_options = {
"master_name": app.config.get("CELERY_SENTINEL_MASTER_NAME"),
"master_name": dify_config.CELERY_SENTINEL_MASTER_NAME,
"sentinel_kwargs": {
"socket_timeout": app.config.get("CELERY_SENTINEL_SOCKET_TIMEOUT", 0.1),
"socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
},
}

celery_app = Celery(
app.name,
task_cls=FlaskTask,
broker=app.config.get("CELERY_BROKER_URL"),
backend=app.config.get("CELERY_BACKEND"),
broker=dify_config.CELERY_BROKER_URL,
backend=dify_config.CELERY_BACKEND,
task_ignore_result=True,
)

Expand All @@ -37,12 +39,12 @@ def __call__(self, *args: object, **kwargs: object) -> object:
}

celery_app.conf.update(
result_backend=app.config.get("CELERY_RESULT_BACKEND"),
result_backend=dify_config.CELERY_RESULT_BACKEND,
broker_transport_options=broker_transport_options,
broker_connection_retry_on_startup=True,
)

if app.config.get("BROKER_USE_SSL"):
if dify_config.BROKER_USE_SSL:
celery_app.conf.update(
broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration
)
Expand All @@ -54,7 +56,7 @@ def __call__(self, *args: object, **kwargs: object) -> object:
"schedule.clean_embedding_cache_task",
"schedule.clean_unused_datasets_task",
]
day = app.config.get("CELERY_BEAT_SCHEDULER_TIME")
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
beat_schedule = {
"clean_embedding_cache_task": {
"task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task",
Expand Down
Loading

0 comments on commit 4d9160c

Please sign in to comment.