-
Notifications
You must be signed in to change notification settings - Fork 8.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into feat/new-login
* main: (35 commits) fix #9409 (#9433) update dataset clean rule (#9426) add clean 7 days datasets (#9424) fix: resolve overlap issue with API Extension selector and modal (#9407) refactor: update the default values of top-k parameter in vdb to be consistent (#9367) fix: incorrect webapp image displayed (#9401) Fix/economical knowledge retrieval (#9396) feat: add timezone conversion for time tool (#9393) fix: Deprecated gemma2-9b model in Fireworks AI Provider (#9373) feat: storybook (#9324) fix: use gpt-4o-mini for validating credentials (#9387) feat: Enable baiduvector intergration test (#9369) fix: remove the stream option of zhipu and gemini (#9319) fix: add missing vikingdb param in docker .env.example (#9334) feat: add minimax abab6.5t support (#9365) fix: (#9336 followup) skip poetry preperation in style workflow when no change in api folder (#9362) feat: add glm-4-flashx, deprecated chatglm_turbo (#9357) fix: Azure OpenAI o1 max_completion_token and get_num_token_from_messages error (#9326) fix: In the output, the order of 'ta' is sometimes reversed as 'at'. #8015 (#8791) refactor: Add an enumeration type and use the factory pattern to obtain the corresponding class (#9356) ...
- Loading branch information
Showing
250 changed files
with
7,606 additions
and
9,662 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -85,3 +85,4 @@ | |
cd ../ | ||
poetry run -C api bash dev/pytest/pytest_all_tests.sh | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,88 +1,24 @@ | ||
import logging | ||
from flask_restful import Resource | ||
|
||
from flask_login import current_user | ||
from flask_restful import Resource, marshal, reqparse | ||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound | ||
|
||
import services | ||
from controllers.console import api | ||
from controllers.console.app.error import ( | ||
CompletionRequestError, | ||
ProviderModelCurrentlyNotSupportError, | ||
ProviderNotInitializeError, | ||
ProviderQuotaExceededError, | ||
) | ||
from controllers.console.datasets.error import DatasetNotInitializedError | ||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase | ||
from controllers.console.setup import setup_required | ||
from controllers.console.wraps import account_initialization_required | ||
from core.errors.error import ( | ||
LLMBadRequestError, | ||
ModelCurrentlyNotSupportError, | ||
ProviderTokenNotInitError, | ||
QuotaExceededError, | ||
) | ||
from core.model_runtime.errors.invoke import InvokeError | ||
from fields.hit_testing_fields import hit_testing_record_fields | ||
from libs.login import login_required | ||
from services.dataset_service import DatasetService | ||
from services.hit_testing_service import HitTestingService | ||
|
||
|
||
class HitTestingApi(Resource): | ||
class HitTestingApi(Resource, DatasetsHitTestingBase): | ||
@setup_required | ||
@login_required | ||
@account_initialization_required | ||
def post(self, dataset_id): | ||
dataset_id_str = str(dataset_id) | ||
|
||
dataset = DatasetService.get_dataset(dataset_id_str) | ||
if dataset is None: | ||
raise NotFound("Dataset not found.") | ||
|
||
try: | ||
DatasetService.check_dataset_permission(dataset, current_user) | ||
except services.errors.account.NoPermissionError as e: | ||
raise Forbidden(str(e)) | ||
|
||
parser = reqparse.RequestParser() | ||
parser.add_argument("query", type=str, location="json") | ||
parser.add_argument("retrieval_model", type=dict, required=False, location="json") | ||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") | ||
args = parser.parse_args() | ||
|
||
HitTestingService.hit_testing_args_check(args) | ||
|
||
try: | ||
response = HitTestingService.retrieve( | ||
dataset=dataset, | ||
query=args["query"], | ||
account=current_user, | ||
retrieval_model=args["retrieval_model"], | ||
external_retrieval_model=args["external_retrieval_model"], | ||
limit=10, | ||
) | ||
dataset = self.get_and_validate_dataset(dataset_id_str) | ||
args = self.parse_args() | ||
self.hit_testing_args_check(args) | ||
|
||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} | ||
except services.errors.index.IndexNotInitializedError: | ||
raise DatasetNotInitializedError() | ||
except ProviderTokenNotInitError as ex: | ||
raise ProviderNotInitializeError(ex.description) | ||
except QuotaExceededError: | ||
raise ProviderQuotaExceededError() | ||
except ModelCurrentlyNotSupportError: | ||
raise ProviderModelCurrentlyNotSupportError() | ||
except LLMBadRequestError: | ||
raise ProviderNotInitializeError( | ||
"No Embedding Model or Reranking Model available. Please configure a valid provider " | ||
"in the Settings -> Model Provider." | ||
) | ||
except InvokeError as e: | ||
raise CompletionRequestError(e.description) | ||
except ValueError as e: | ||
raise ValueError(str(e)) | ||
except Exception as e: | ||
logging.exception("Hit testing failed.") | ||
raise InternalServerError(str(e)) | ||
return self.perform_hit_testing(dataset, args) | ||
|
||
|
||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import logging | ||
|
||
from flask_login import current_user | ||
from flask_restful import marshal, reqparse | ||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound | ||
|
||
import services.dataset_service | ||
from controllers.console.app.error import ( | ||
CompletionRequestError, | ||
ProviderModelCurrentlyNotSupportError, | ||
ProviderNotInitializeError, | ||
ProviderQuotaExceededError, | ||
) | ||
from controllers.console.datasets.error import DatasetNotInitializedError | ||
from core.errors.error import ( | ||
LLMBadRequestError, | ||
ModelCurrentlyNotSupportError, | ||
ProviderTokenNotInitError, | ||
QuotaExceededError, | ||
) | ||
from core.model_runtime.errors.invoke import InvokeError | ||
from fields.hit_testing_fields import hit_testing_record_fields | ||
from services.dataset_service import DatasetService | ||
from services.hit_testing_service import HitTestingService | ||
|
||
|
||
class DatasetsHitTestingBase: | ||
@staticmethod | ||
def get_and_validate_dataset(dataset_id: str): | ||
dataset = DatasetService.get_dataset(dataset_id) | ||
if dataset is None: | ||
raise NotFound("Dataset not found.") | ||
|
||
try: | ||
DatasetService.check_dataset_permission(dataset, current_user) | ||
except services.errors.account.NoPermissionError as e: | ||
raise Forbidden(str(e)) | ||
|
||
return dataset | ||
|
||
@staticmethod | ||
def hit_testing_args_check(args): | ||
HitTestingService.hit_testing_args_check(args) | ||
|
||
@staticmethod | ||
def parse_args(): | ||
parser = reqparse.RequestParser() | ||
|
||
parser.add_argument("query", type=str, location="json") | ||
parser.add_argument("retrieval_model", type=dict, required=False, location="json") | ||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") | ||
return parser.parse_args() | ||
|
||
@staticmethod | ||
def perform_hit_testing(dataset, args): | ||
try: | ||
response = HitTestingService.retrieve( | ||
dataset=dataset, | ||
query=args["query"], | ||
account=current_user, | ||
retrieval_model=args["retrieval_model"], | ||
external_retrieval_model=args["external_retrieval_model"], | ||
limit=10, | ||
) | ||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} | ||
except services.errors.index.IndexNotInitializedError: | ||
raise DatasetNotInitializedError() | ||
except ProviderTokenNotInitError as ex: | ||
raise ProviderNotInitializeError(ex.description) | ||
except QuotaExceededError: | ||
raise ProviderQuotaExceededError() | ||
except ModelCurrentlyNotSupportError: | ||
raise ProviderModelCurrentlyNotSupportError() | ||
except LLMBadRequestError: | ||
raise ProviderNotInitializeError( | ||
"No Embedding Model or Reranking Model available. Please configure a valid provider " | ||
"in the Settings -> Model Provider." | ||
) | ||
except InvokeError as e: | ||
raise CompletionRequestError(e.description) | ||
except ValueError as e: | ||
raise ValueError(str(e)) | ||
except Exception as e: | ||
logging.exception("Hit testing failed.") | ||
raise InternalServerError(str(e)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.