From bc2a546eaab282cf035faadfd46d0b0b02393f84 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 2 Jan 2025 18:07:46 +0800 Subject: [PATCH 1/6] add knowledge rate limit --- .../console/datasets/datasets_segments.py | 1 + api/controllers/console/wraps.py | 33 ++++++++++++++++++- api/services/billing_service.py | 8 +++++ api/services/feature_service.py | 13 ++++++++ 4 files changed, 54 insertions(+), 1 deletion(-) diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 96654c09fd0223..24f76790df4a22 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -19,6 +19,7 @@ from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_knowledge_limit_check, + cloud_edition_billing_knowledge_rate_limit_check, cloud_edition_billing_resource_check, setup_required, ) diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 111db7ccf2da04..e22dc8ae6944c4 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -1,6 +1,8 @@ +from datetime import datetime import json import os from functools import wraps +import time from flask import abort, request from flask_login import current_user # type: ignore @@ -10,6 +12,7 @@ from models.model import DifySetup from services.feature_service import FeatureService, LicenseStatus from services.operation_service import OperationService +from extensions.ext_redis import redis_client from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout @@ -66,7 +69,7 @@ def decorated(*args, **kwargs): elif resource == "apps" and 0 < apps.limit <= apps.size: abort(403, "The number of apps has reached the limit of your subscription.") elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: - abort(403, "The capacity of the vector space has reached the limit of your subscription.") + abort(403, "The capacity of the knowledge storage space has reached the limit of your subscription.") elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: # The api of file upload is used in the multiple places, # so we need to check the source of the request from datasets @@ -111,6 +114,34 @@ def decorated(*args, **kwargs): return interceptor +def cloud_edition_billing_knowledge_rate_limit_check(): + def interceptor(view): + @wraps(view) + def decorated(*args, **kwargs): + knowledge_rate_limit = FeatureService.get_knowledge_rate_limit( + current_user.current_tenant_id + ) + if knowledge_rate_limit.enabled: + current_time = int(time.time() * 1000) + key = f"rate_limit_{current_user.current_tenant_id}" + + redis_client.zadd(key, {current_time: current_time}) + + redis_client.zremrangebyscore(key, 0, current_time - 60000) + + request_count = redis_client.zcard(key) + + if request_count > knowledge_rate_limit.limit: + abort(403, "The number of requests has reached the limit of your subscription.") + + return view(*args, **kwargs) + return view(*args, **kwargs) + + return decorated + + return interceptor + + def cloud_utm_record(view): @wraps(view) def decorated(*args, **kwargs): diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 3a13c10102fab8..c2866102fefb54 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -19,6 +19,14 @@ def get_info(cls, tenant_id: str): billing_info = cls._send_request("GET", "/subscription/info", params=params) return billing_info + + @classmethod + def get_knowledge_rate_limit(cls, tenant_id: str): + params = {"tenant_id": tenant_id} + + knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params) + + return knowledge_rate_limit.get("limit", 10) @classmethod def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""): diff --git a/api/services/feature_service.py b/api/services/feature_service.py index b9261d19d7930e..3cc13d9c009f48 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -52,6 +52,11 @@ class FeatureModel(BaseModel): model_config = ConfigDict(protected_namespaces=()) +class KnowledgeRateLimitModel(BaseModel): + enabled: bool = False + limit: int = 10 + + class SystemFeatureModel(BaseModel): sso_enforced_for_signin: bool = False sso_enforced_for_signin_protocol: str = "" @@ -78,6 +83,14 @@ def get_features(cls, tenant_id: str) -> FeatureModel: cls._fulfill_params_from_billing_api(features, tenant_id) return features + + @classmethod + def get_knowledge_rate_limit(cls, tenant_id: str): + knowledge_rate_limit = KnowledgeRateLimitModel() + if dify_config.BILLING_ENABLED and tenant_id: + knowledge_rate_limit.enabled = True + knowledge_rate_limit.limit = BillingService.get_knowledge_rate_limit(tenant_id) + return knowledge_rate_limit @classmethod def get_system_features(cls) -> SystemFeatureModel: From cbd4223deab772cbf94bfb0b2db9e21cc4e320d0 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Mon, 6 Jan 2025 14:48:05 +0800 Subject: [PATCH 2/6] new SaaS Billing --- api/controllers/console/datasets/datasets.py | 10 +++++- .../console/datasets/datasets_document.py | 10 ++++++ .../console/datasets/datasets_segments.py | 12 ++++++- api/controllers/console/wraps.py | 30 +++++++++--------- api/controllers/service_api/wraps.py | 31 +++++++++++++++++++ 5 files changed, 75 insertions(+), 18 deletions(-) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 0c0d2e20035b43..67741d5316af12 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -10,7 +10,12 @@ from controllers.console.apikey import api_key_fields, api_key_list from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError -from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_rate_limit_check, + enterprise_license_required, + setup_required, +) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.indexing_runner import IndexingRunner from core.model_runtime.entities.model_entities import ModelType @@ -93,6 +98,7 @@ def get(self): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def post(self): parser = reqparse.RequestParser() parser.add_argument( @@ -207,6 +213,7 @@ def get(self, dataset_id): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -310,6 +317,7 @@ def patch(self, dataset_id): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id): dataset_id_str = str(dataset_id) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 5a3c6f843290b8..1f28d3ffae47b3 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -27,6 +27,7 @@ ) from controllers.console.wraps import ( account_initialization_required, + cloud_edition_billing_rate_limit_check, cloud_edition_billing_resource_check, setup_required, ) @@ -230,6 +231,7 @@ def get(self, dataset_id): @account_initialization_required @marshal_with(documents_and_batch_fields) @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id): dataset_id = str(dataset_id) @@ -284,6 +286,7 @@ def post(self, dataset_id): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -307,6 +310,7 @@ class DatasetInitApi(Resource): @account_initialization_required @marshal_with(dataset_and_document_fields) @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_rate_limit_check("knowledge") def post(self): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: @@ -679,6 +683,7 @@ class DocumentProcessingApi(DocumentResource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, action): dataset_id = str(dataset_id) document_id = str(document_id) @@ -715,6 +720,7 @@ class DocumentDeleteApi(DocumentResource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id, document_id): dataset_id = str(dataset_id) document_id = str(document_id) @@ -783,6 +789,7 @@ class DocumentStatusApi(DocumentResource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, action): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -878,6 +885,7 @@ class DocumentPauseApi(DocumentResource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id): """pause document.""" dataset_id = str(dataset_id) @@ -910,6 +918,7 @@ class DocumentRecoverApi(DocumentResource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id): """recover document.""" dataset_id = str(dataset_id) @@ -939,6 +948,7 @@ class DocumentRetryApi(DocumentResource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id): """retry document.""" diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 24f76790df4a22..034fe9cfe208f3 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -19,7 +19,7 @@ from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_knowledge_limit_check, - cloud_edition_billing_knowledge_rate_limit_check, + cloud_edition_billing_rate_limit_check, cloud_edition_billing_resource_check, setup_required, ) @@ -107,6 +107,7 @@ def get(self, dataset_id, document_id): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) @@ -138,6 +139,7 @@ class DatasetDocumentSegmentApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, action): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -193,6 +195,7 @@ class DatasetDocumentSegmentAddApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_knowledge_limit_check("add_segment") + @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) @@ -243,6 +246,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) @@ -303,6 +307,7 @@ def patch(self, dataset_id, document_id, segment_id): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) @@ -340,6 +345,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_knowledge_limit_check("add_segment") + @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) @@ -406,6 +412,7 @@ class ChildChunkAddApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_knowledge_limit_check("add_segment") + @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) @@ -504,6 +511,7 @@ def get(self, dataset_id, document_id, segment_id): @login_required @account_initialization_required @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) @@ -547,6 +555,7 @@ class ChildChunkUpdateApi(Resource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id, document_id, segment_id, child_chunk_id): # check dataset dataset_id = str(dataset_id) @@ -591,6 +600,7 @@ def delete(self, dataset_id, document_id, segment_id, child_chunk_id): @login_required @account_initialization_required @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, segment_id, child_chunk_id): # check dataset dataset_id = str(dataset_id) diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index e22dc8ae6944c4..f416af5b117ac0 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -1,18 +1,17 @@ -from datetime import datetime import json import os -from functools import wraps import time +from functools import wraps from flask import abort, request from flask_login import current_user # type: ignore from configs import dify_config from controllers.console.workspace.error import AccountNotInitializedError +from extensions.ext_redis import redis_client from models.model import DifySetup from services.feature_service import FeatureService, LicenseStatus from services.operation_service import OperationService -from extensions.ext_redis import redis_client from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout @@ -114,27 +113,26 @@ def decorated(*args, **kwargs): return interceptor -def cloud_edition_billing_knowledge_rate_limit_check(): +def cloud_edition_billing_rate_limit_check(resource: str): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): - knowledge_rate_limit = FeatureService.get_knowledge_rate_limit( - current_user.current_tenant_id + if resource == "knowledge": + knowledge_rate_limit = FeatureService.get_knowledge_rate_limit( + current_user.current_tenant_id ) - if knowledge_rate_limit.enabled: - current_time = int(time.time() * 1000) - key = f"rate_limit_{current_user.current_tenant_id}" - - redis_client.zadd(key, {current_time: current_time}) + if knowledge_rate_limit.enabled: + current_time = int(time.time() * 1000) + key = f"rate_limit_{current_user.current_tenant_id}" - redis_client.zremrangebyscore(key, 0, current_time - 60000) + redis_client.zadd(key, {current_time: current_time}) - request_count = redis_client.zcard(key) + redis_client.zremrangebyscore(key, 0, current_time - 60000) - if request_count > knowledge_rate_limit.limit: - abort(403, "The number of requests has reached the limit of your subscription.") + request_count = redis_client.zcard(key) - return view(*args, **kwargs) + if request_count > knowledge_rate_limit.limit: + abort(403, "Sorry, you have reached the rate limit of your subscription.") return view(*args, **kwargs) return decorated diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 740b92ef8e4faf..804271f66eec7d 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,3 +1,4 @@ +import time from collections.abc import Callable from datetime import UTC, datetime from enum import Enum @@ -11,6 +12,7 @@ from werkzeug.exceptions import Forbidden, Unauthorized from extensions.ext_database import db +from extensions.ext_redis import redis_client from libs.login import _get_user from models.account import Account, Tenant, TenantAccountJoin, TenantStatus from models.model import ApiToken, App, EndUser @@ -137,6 +139,35 @@ def decorated(*args, **kwargs): return interceptor +def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): + def interceptor(view): + @wraps(view) + def decorated(*args, **kwargs): + api_token = validate_and_get_api_token(api_token_type) + + if resource == "knowledge": + knowledge_rate_limit = FeatureService.get_knowledge_rate_limit( + api_token.tenant_id + ) + if knowledge_rate_limit.enabled: + current_time = int(time.time() * 1000) + key = f"rate_limit_{api_token.tenant_id}" + + redis_client.zadd(key, {current_time: current_time}) + + redis_client.zremrangebyscore(key, 0, current_time - 60000) + + request_count = redis_client.zcard(key) + + if request_count > knowledge_rate_limit.limit: + raise Forbidden(403, "Sorry, you have reached the rate limit of your subscription.") + return view(*args, **kwargs) + + return decorated + + return interceptor + + def validate_dataset_token(view=None): def decorator(view): @wraps(view) From 014a1081dd54cd17cdfb61455c827c3ea3de834b Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Mon, 6 Jan 2025 16:37:27 +0800 Subject: [PATCH 3/6] SaaS rate limit --- api/services/feature_service.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 3cc13d9c009f48..f7697b3e564bfa 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -41,6 +41,7 @@ class FeatureModel(BaseModel): members: LimitationModel = LimitationModel(size=0, limit=1) apps: LimitationModel = LimitationModel(size=0, limit=10) vector_space: LimitationModel = LimitationModel(size=0, limit=5) + knowledge_rate_limit: int = 10 annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10) documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50) docs_processing: str = "standard" @@ -156,7 +157,10 @@ def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str if "model_load_balancing_enabled" in billing_info: features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"] - + + if "knowledge_rate_limit" in billing_info: + features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"] + @classmethod def _fulfill_params_from_enterprise(cls, features): enterprise_info = EnterpriseService.get_info() From bf070b45fc3a1fe0dd71906252a6bd0d17a39411 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 7 Jan 2025 13:58:59 +0800 Subject: [PATCH 4/6] SaaS rate limit --- api/controllers/console/datasets/hit_testing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 18b746f547287c..17be61093cafbf 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -2,7 +2,8 @@ from controllers.console import api from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, setup_required, \ + cloud_edition_billing_rate_limit_check from libs.login import login_required @@ -10,6 +11,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id): dataset_id_str = str(dataset_id) From d028cf388273e90dfe824093fb3184e6c3350505 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 7 Jan 2025 14:41:11 +0800 Subject: [PATCH 5/6] SaaS rate limit --- api/controllers/console/wraps.py | 14 +++++++------ api/controllers/service_api/wraps.py | 10 +++++----- .../knowledge_retrieval_node.py | 20 +++++++++++++++++++ 3 files changed, 33 insertions(+), 11 deletions(-) diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index f416af5b117ac0..e92c0ae95200ea 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -68,7 +68,9 @@ def decorated(*args, **kwargs): elif resource == "apps" and 0 < apps.limit <= apps.size: abort(403, "The number of apps has reached the limit of your subscription.") elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: - abort(403, "The capacity of the knowledge storage space has reached the limit of your subscription.") + abort( + 403, "The capacity of the knowledge storage space has reached the limit of your subscription." + ) elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: # The api of file upload is used in the multiple places, # so we need to check the source of the request from datasets @@ -117,10 +119,8 @@ def cloud_edition_billing_rate_limit_check(resource: str): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): - if resource == "knowledge": - knowledge_rate_limit = FeatureService.get_knowledge_rate_limit( - current_user.current_tenant_id - ) + if resource == "knowledge": + knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id) if knowledge_rate_limit.enabled: current_time = int(time.time() * 1000) key = f"rate_limit_{current_user.current_tenant_id}" @@ -132,7 +132,9 @@ def decorated(*args, **kwargs): request_count = redis_client.zcard(key) if request_count > knowledge_rate_limit.limit: - abort(403, "Sorry, you have reached the rate limit of your subscription.") + abort( + 403, "Sorry, you have reached the knowledge base request rate limit of your subscription." + ) return view(*args, **kwargs) return decorated diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 804271f66eec7d..77754150699bf0 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -145,10 +145,8 @@ def interceptor(view): def decorated(*args, **kwargs): api_token = validate_and_get_api_token(api_token_type) - if resource == "knowledge": - knowledge_rate_limit = FeatureService.get_knowledge_rate_limit( - api_token.tenant_id - ) + if resource == "knowledge": + knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(api_token.tenant_id) if knowledge_rate_limit.enabled: current_time = int(time.time() * 1000) key = f"rate_limit_{api_token.tenant_id}" @@ -160,7 +158,9 @@ def decorated(*args, **kwargs): request_count = redis_client.zcard(key) if request_count > knowledge_rate_limit.limit: - raise Forbidden(403, "Sorry, you have reached the rate limit of your subscription.") + raise Forbidden( + "Sorry, you have reached the knowledge base request rate limit of your subscription." + ) return view(*args, **kwargs) return decorated diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 0f239af51ae79c..be82ad2a8290ab 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,4 +1,5 @@ import logging +import time from collections.abc import Mapping, Sequence from typing import Any, cast @@ -19,8 +20,10 @@ from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from extensions.ext_database import db +from extensions.ext_redis import redis_client from models.dataset import Dataset, Document from models.workflow import WorkflowNodeExecutionStatus +from services.feature_service import FeatureService from .entities import KnowledgeRetrievalNodeData from .exc import ( @@ -61,6 +64,23 @@ def _run(self) -> NodeRunResult: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required." ) + # check rate limit + if self.tenant_id: + knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id) + if knowledge_rate_limit.enabled: + current_time = int(time.time() * 1000) + key = f"rate_limit_{self.tenant_id}" + redis_client.zadd(key, {current_time: current_time}) + redis_client.zremrangebyscore(key, 0, current_time - 60000) + request_count = redis_client.zcard(key) + if request_count > knowledge_rate_limit.limit: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error="Sorry, you have reached the knowledge base request rate limit of your subscription.", + error_type="RateLimitExceeded", + ) + # retrieve knowledge try: results = self._fetch_dataset_retriever(node_data=self.node_data, query=query) From 1fa0cd5118f955ba8bc72db3018dced147646e3b Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 7 Jan 2025 14:41:26 +0800 Subject: [PATCH 6/6] SaaS rate limit --- api/controllers/console/datasets/hit_testing.py | 7 +++++-- api/services/billing_service.py | 2 +- api/services/feature_service.py | 6 +++--- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 17be61093cafbf..d344e9d1267bfa 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -2,8 +2,11 @@ from controllers.console import api from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase -from controllers.console.wraps import account_initialization_required, setup_required, \ - cloud_edition_billing_rate_limit_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_rate_limit_check, + setup_required, +) from libs.login import login_required diff --git a/api/services/billing_service.py b/api/services/billing_service.py index c2866102fefb54..e5c7821fce2a7e 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -19,7 +19,7 @@ def get_info(cls, tenant_id: str): billing_info = cls._send_request("GET", "/subscription/info", params=params) return billing_info - + @classmethod def get_knowledge_rate_limit(cls, tenant_id: str): params = {"tenant_id": tenant_id} diff --git a/api/services/feature_service.py b/api/services/feature_service.py index f7697b3e564bfa..52cfe4f2cb7e53 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -84,7 +84,7 @@ def get_features(cls, tenant_id: str) -> FeatureModel: cls._fulfill_params_from_billing_api(features, tenant_id) return features - + @classmethod def get_knowledge_rate_limit(cls, tenant_id: str): knowledge_rate_limit = KnowledgeRateLimitModel() @@ -157,10 +157,10 @@ def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str if "model_load_balancing_enabled" in billing_info: features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"] - + if "knowledge_rate_limit" in billing_info: features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"] - + @classmethod def _fulfill_params_from_enterprise(cls, features): enterprise_info = EnterpriseService.get_info()