diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index b0322dd2b2da59..e80f9d30aadc02 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -7,5 +7,6 @@ echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc +echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc -source /home/vscode/.bashrc \ No newline at end of file +source /home/vscode/.bashrc diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 138fb886d6aea0..b4a6eb9adb7704 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -8,16 +8,9 @@ Please include a summary of the change and which issue is fixed. Please also inc # Screenshots - - - - - - - - - -
Before: After:
......
+| Before | After | +|--------|-------| +| ... | ... | # Checklist diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index e1c0bf33a4ff31..fd98db24b961b4 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -50,9 +50,18 @@ jobs: - name: Run ModelRuntime run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh + - name: Run dify config tests + run: poetry run -C api python dev/pytest/pytest_config_tests.py + - name: Run Tool run: poetry run -C api bash dev/pytest/pytest_tools.sh + - name: Run mypy + run: | + pushd api + poetry run python -m mypy --install-types --non-interactive . + popd + - name: Set up dotenvs run: | cp docker/.env.example docker/.env diff --git a/.github/workflows/expose_service_ports.sh b/.github/workflows/expose_service_ports.sh index bc65c19a913fcf..d3146cd90dc02b 100755 --- a/.github/workflows/expose_service_ports.sh +++ b/.github/workflows/expose_service_ports.sh @@ -9,5 +9,6 @@ yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compos yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml +yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/docker-compose.yaml -echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase" +echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase" diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 282afefe74243a..b5e63a8870baa8 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -37,6 +37,7 @@ jobs: - name: Ruff check if: steps.changed-files.outputs.any_changed == 'true' run: | + poetry run -C api ruff --version poetry run -C api ruff check ./api poetry run -C api ruff format --check ./api diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index 73af3700637121..146bee95f21c5f 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -51,7 +51,7 @@ jobs: - name: Expose Service Ports run: sh .github/workflows/expose_service_ports.sh - - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase) + - name: Set up Vector Stores (TiDB, Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase) uses: hoverkraft-tech/compose-action@v2.0.2 with: compose-file: | @@ -67,6 +67,7 @@ jobs: pgvector chroma elasticsearch + tidb - name: Test Vector Stores run: poetry run -C api bash dev/pytest/pytest_vdb.sh diff --git a/api/.env.example b/api/.env.example index 52cdd9ecb28dc3..071a200e680278 100644 --- a/api/.env.example +++ b/api/.env.example @@ -56,20 +56,27 @@ DB_DATABASE=dify # Storage configuration # use for store upload files, private keys... -# storage type: local, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase -STORAGE_TYPE=local -STORAGE_LOCAL_PATH=storage +# storage type: opendal, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase +STORAGE_TYPE=opendal + +# Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal +OPENDAL_SCHEME=fs +OPENDAL_FS_ROOT=storage + +# S3 Storage configuration S3_USE_AWS_MANAGED_IAM=false S3_ENDPOINT=https://your-bucket-name.storage.s3.clooudflare.com S3_BUCKET_NAME=your-bucket-name S3_ACCESS_KEY=your-access-key S3_SECRET_KEY=your-secret-key S3_REGION=your-region + # Azure Blob Storage configuration AZURE_BLOB_ACCOUNT_NAME=your-account-name AZURE_BLOB_ACCOUNT_KEY=your-account-key AZURE_BLOB_CONTAINER_NAME=yout-container-name AZURE_BLOB_ACCOUNT_URL=https://.blob.core.windows.net + # Aliyun oss Storage configuration ALIYUN_OSS_BUCKET_NAME=your-bucket-name ALIYUN_OSS_ACCESS_KEY=your-access-key @@ -79,6 +86,7 @@ ALIYUN_OSS_AUTH_VERSION=v1 ALIYUN_OSS_REGION=your-region # Don't start with '/'. OSS doesn't support leading slash in object names. ALIYUN_OSS_PATH=your-path + # Google Storage configuration GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string @@ -125,8 +133,8 @@ SUPABASE_URL=your-server-url WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* - -# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase +# Vector database configuration +# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase VECTOR_STORE=weaviate # Weaviate configuration @@ -277,6 +285,7 @@ VIKINGDB_SOCKET_TIMEOUT=30 LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070 LINDORM_USERNAME=admin LINDORM_PASSWORD=admin +USING_UGC_INDEX=False # OceanBase Vector configuration OCEANBASE_VECTOR_HOST=127.0.0.1 @@ -295,8 +304,7 @@ UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 # Model configuration -MULTIMODAL_SEND_IMAGE_FORMAT=base64 -MULTIMODAL_SEND_VIDEO_FORMAT=base64 +MULTIMODAL_SEND_FORMAT=base64 PROMPT_GENERATION_MAX_TOKENS=512 CODE_GENERATION_MAX_TOKENS=1024 @@ -381,6 +389,8 @@ LOG_FILE_BACKUP_COUNT=5 LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S # Log Timezone LOG_TZ=UTC +# Log format +LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s # Indexing configuration INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000 @@ -389,6 +399,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000 WORKFLOW_MAX_EXECUTION_STEPS=500 WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_CALL_MAX_DEPTH=5 +WORKFLOW_PARALLEL_DEPTH_LIMIT=3 MAX_VARIABLE_SIZE=204800 # App configuration @@ -413,3 +424,7 @@ RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 CREATE_TIDB_SERVICE_JOB_ENABLED=false +# Maximum number of submitted thread count in a ThreadPool for parallel node execution +MAX_SUBMIT_COUNT=100 +# Lockout duration in seconds +LOGIN_LOCKOUT_DURATION=86400 \ No newline at end of file diff --git a/api/.ruff.toml b/api/.ruff.toml index 0f3185223c1596..26a1b977a9f6ac 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -70,7 +70,6 @@ ignore = [ "SIM113", # eumerate-for-loop "SIM117", # multiple-with-statements "SIM210", # if-expr-with-true-false - "SIM300", # yoda-conditions, ] [lint.per-file-ignores] diff --git a/api/app.py b/api/app.py index 996e2e890fdd10..c6a08290804a65 100644 --- a/api/app.py +++ b/api/app.py @@ -1,13 +1,30 @@ -from app_factory import create_app -from libs import threadings_utils, version_utils +from libs import version_utils # preparation before creating app version_utils.check_supported_python_version() -threadings_utils.apply_gevent_threading_patch() + + +def is_db_command(): + import sys + + if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db": + return True + return False + # create app -app = create_app() -celery = app.extensions["celery"] +if is_db_command(): + from app_factory import create_migrations_app + + app = create_migrations_app() +else: + from app_factory import create_app + from libs import threadings_utils + + threadings_utils.apply_gevent_threading_patch() + + app = create_app() + celery = app.extensions["celery"] if __name__ == "__main__": app.run(host="0.0.0.0", port=5001) diff --git a/api/app_factory.py b/api/app_factory.py index 7dc08c4d93960a..c0714116a3e692 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -1,5 +1,4 @@ import logging -import os import time from configs import dify_config @@ -17,15 +16,6 @@ def create_flask_app_with_configs() -> DifyApp: dify_app = DifyApp(__name__) dify_app.config.from_mapping(dify_config.model_dump()) - # populate configs into system environment variables - for key, value in dify_app.config.items(): - if isinstance(value, str): - os.environ[key] = value - elif isinstance(value, int | float | bool): - os.environ[key] = str(value) - elif value is None: - os.environ[key] = "" - return dify_app @@ -98,3 +88,14 @@ def initialize_extensions(app: DifyApp): end_time = time.perf_counter() if dify_config.DEBUG: logging.info(f"Loaded {short_name} ({round((end_time - start_time) * 1000, 2)} ms)") + + +def create_migrations_app(): + app = create_flask_app_with_configs() + from extensions import ext_database, ext_migrate + + # Initialize only required extensions + ext_database.init_app(app) + ext_migrate.init_app(app) + + return app diff --git a/api/commands.py b/api/commands.py index 09548ac9f338cd..59dfce68e0c92f 100644 --- a/api/commands.py +++ b/api/commands.py @@ -159,8 +159,7 @@ def migrate_annotation_vector_database(): try: # get apps info apps = ( - db.session.query(App) - .filter(App.status == "normal") + App.query.filter(App.status == "normal") .order_by(App.created_at.desc()) .paginate(page=page, per_page=50) ) @@ -285,8 +284,7 @@ def migrate_knowledge_vector_database(): while True: try: datasets = ( - db.session.query(Dataset) - .filter(Dataset.indexing_technique == "high_quality") + Dataset.query.filter(Dataset.indexing_technique == "high_quality") .order_by(Dataset.created_at.desc()) .paginate(page=page, per_page=50) ) @@ -450,7 +448,8 @@ def convert_to_agent_apps(): if app_id not in proceeded_app_ids: proceeded_app_ids.append(app_id) app = db.session.query(App).filter(App.id == app_id).first() - apps.append(app) + if app is not None: + apps.append(app) if len(apps) == 0: break @@ -555,14 +554,20 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str if language not in languages: language = "en-US" - name = name.strip() + # Validates name encoding for non-Latin characters. + name = name.strip().encode("utf-8").decode("utf-8") if name else None # generate random password new_password = secrets.token_urlsafe(16) # register account - account = RegisterService.register(email=email, name=account_name, password=new_password, language=language) - + account = RegisterService.register( + email=email, + name=account_name, + password=new_password, + language=language, + create_workspace_required=False, + ) TenantService.create_owner_tenant_if_not_exist(account, name) click.echo( @@ -620,6 +625,10 @@ def fix_app_site_missing(): try: app = db.session.query(App).filter(App.id == app_id).first() + if not app: + print(f"App {app_id} not found") + continue + tenant = app.tenant if tenant: accounts = tenant.get_accounts() diff --git a/api/configs/app_config.py b/api/configs/app_config.py index 07ef6121cc5040..ac1ce9db100ea4 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -1,11 +1,51 @@ -from pydantic_settings import SettingsConfigDict +import logging +from typing import Any -from configs.deploy import DeploymentConfig -from configs.enterprise import EnterpriseFeatureConfig -from configs.extra import ExtraServiceConfig -from configs.feature import FeatureConfig -from configs.middleware import MiddlewareConfig -from configs.packaging import PackagingInfo +from pydantic.fields import FieldInfo +from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict + +from .deploy import DeploymentConfig +from .enterprise import EnterpriseFeatureConfig +from .extra import ExtraServiceConfig +from .feature import FeatureConfig +from .middleware import MiddlewareConfig +from .packaging import PackagingInfo +from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName +from .remote_settings_sources.apollo import ApolloSettingsSource + +logger = logging.getLogger(__name__) + + +class RemoteSettingsSourceFactory(PydanticBaseSettingsSource): + def __init__(self, settings_cls: type[BaseSettings]): + super().__init__(settings_cls) + + def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: + raise NotImplementedError + + def __call__(self) -> dict[str, Any]: + current_state = self.current_state + remote_source_name = current_state.get("REMOTE_SETTINGS_SOURCE_NAME") + if not remote_source_name: + return {} + + remote_source: RemoteSettingsSource | None = None + match remote_source_name: + case RemoteSettingsSourceName.APOLLO: + remote_source = ApolloSettingsSource(current_state) + case _: + logger.warning(f"Unsupported remote source: {remote_source_name}") + return {} + + d: dict[str, Any] = {} + + for field_name, field in self.settings_cls.model_fields.items(): + field_value, field_key, value_is_complex = remote_source.get_field_value(field, field_name) + field_value = remote_source.prepare_field_value(field_name, field, field_value, value_is_complex) + if field_value is not None: + d[field_key] = field_value + + return d class DifyConfig( @@ -19,6 +59,8 @@ class DifyConfig( MiddlewareConfig, # Extra service configs ExtraServiceConfig, + # Remote source configs + RemoteSettingsSourceConfig, # Enterprise feature configs # **Before using, please contact business@dify.ai by email to inquire about licensing matters.** EnterpriseFeatureConfig, @@ -35,3 +77,20 @@ class DifyConfig( # please consider to arrange it in the proper config group of existed or added # for better readability and maintainability. # Thanks for your concentration and consideration. + + @classmethod + def settings_customise_sources( + cls, + settings_cls: type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> tuple[PydanticBaseSettingsSource, ...]: + return ( + init_settings, + env_settings, + RemoteSettingsSourceFactory(settings_cls), + dotenv_settings, + file_secret_settings, + ) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index f1cb3efda7b3e3..74cdf944865796 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -239,7 +239,6 @@ class HttpConfig(BaseSettings): ) @computed_field - @property def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]: return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",") @@ -250,7 +249,6 @@ def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]: ) @computed_field - @property def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]: return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",") @@ -433,12 +431,28 @@ class WorkflowConfig(BaseSettings): default=5, ) + WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field( + description="Maximum allowed depth for nested parallel executions", + default=3, + ) + MAX_VARIABLE_SIZE: PositiveInt = Field( description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.", default=200 * 1024, ) +class WorkflowNodeExecutionConfig(BaseSettings): + """ + Configuration for workflow node execution + """ + + MAX_SUBMIT_COUNT: PositiveInt = Field( + description="Maximum number of submitted thread count in a ThreadPool for parallel node execution", + default=100, + ) + + class AuthConfig(BaseSettings): """ Configuration for authentication and OAuth @@ -474,6 +488,11 @@ class AuthConfig(BaseSettings): default=60, ) + LOGIN_LOCKOUT_DURATION: PositiveInt = Field( + description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.", + default=86400, + ) + class ModerationConfig(BaseSettings): """ @@ -649,14 +668,9 @@ class IndexingConfig(BaseSettings): ) -class VisionFormatConfig(BaseSettings): - MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field( - description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64", - default="base64", - ) - - MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field( - description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64", +class MultiModalTransferConfig(BaseSettings): + MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field( + description="Format for sending files in multimodal contexts ('base64' or 'url'), default is base64", default="base64", ) @@ -699,27 +713,27 @@ class PositionConfig(BaseSettings): default="", ) - @computed_field + @property def POSITION_PROVIDER_PINS_LIST(self) -> list[str]: return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""] - @computed_field + @property def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]: return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""} - @computed_field + @property def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]: return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""} - @computed_field + @property def POSITION_TOOL_PINS_LIST(self) -> list[str]: return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""] - @computed_field + @property def POSITION_TOOL_INCLUDES_SET(self) -> set[str]: return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""} - @computed_field + @property def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]: return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} @@ -762,19 +776,20 @@ class FeatureConfig( FileAccessConfig, FileUploadConfig, HttpConfig, - VisionFormatConfig, InnerAPIConfig, IndexingConfig, LoggingConfig, MailConfig, ModelLoadBalanceConfig, ModerationConfig, + MultiModalTransferConfig, PositionConfig, RagEtlConfig, SecurityConfig, ToolConfig, UpdateConfig, WorkflowConfig, + WorkflowNodeExecutionConfig, WorkspaceConfig, LoginConfig, # hosted services config diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 57cc805ebf5a59..f6a44eaa471e62 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -1,54 +1,69 @@ -from typing import Any, Optional +from typing import Any, Literal, Optional from urllib.parse import quote_plus from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field from pydantic_settings import BaseSettings -from configs.middleware.cache.redis_config import RedisConfig -from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorageConfig -from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig -from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig -from configs.middleware.storage.baidu_obs_storage_config import BaiduOBSStorageConfig -from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig -from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig -from configs.middleware.storage.oci_storage_config import OCIStorageConfig -from configs.middleware.storage.supabase_storage_config import SupabaseStorageConfig -from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig -from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig -from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig -from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig -from configs.middleware.vdb.chroma_config import ChromaConfig -from configs.middleware.vdb.couchbase_config import CouchbaseConfig -from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig -from configs.middleware.vdb.lindorm_config import LindormConfig -from configs.middleware.vdb.milvus_config import MilvusConfig -from configs.middleware.vdb.myscale_config import MyScaleConfig -from configs.middleware.vdb.oceanbase_config import OceanBaseVectorConfig -from configs.middleware.vdb.opensearch_config import OpenSearchConfig -from configs.middleware.vdb.oracle_config import OracleConfig -from configs.middleware.vdb.pgvector_config import PGVectorConfig -from configs.middleware.vdb.pgvectors_config import PGVectoRSConfig -from configs.middleware.vdb.qdrant_config import QdrantConfig -from configs.middleware.vdb.relyt_config import RelytConfig -from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig -from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig -from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig -from configs.middleware.vdb.upstash_config import UpstashConfig -from configs.middleware.vdb.vikingdb_config import VikingDBConfig -from configs.middleware.vdb.weaviate_config import WeaviateConfig +from .cache.redis_config import RedisConfig +from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig +from .storage.amazon_s3_storage_config import S3StorageConfig +from .storage.azure_blob_storage_config import AzureBlobStorageConfig +from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig +from .storage.google_cloud_storage_config import GoogleCloudStorageConfig +from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig +from .storage.oci_storage_config import OCIStorageConfig +from .storage.opendal_storage_config import OpenDALStorageConfig +from .storage.supabase_storage_config import SupabaseStorageConfig +from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig +from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig +from .vdb.analyticdb_config import AnalyticdbConfig +from .vdb.baidu_vector_config import BaiduVectorDBConfig +from .vdb.chroma_config import ChromaConfig +from .vdb.couchbase_config import CouchbaseConfig +from .vdb.elasticsearch_config import ElasticsearchConfig +from .vdb.lindorm_config import LindormConfig +from .vdb.milvus_config import MilvusConfig +from .vdb.myscale_config import MyScaleConfig +from .vdb.oceanbase_config import OceanBaseVectorConfig +from .vdb.opensearch_config import OpenSearchConfig +from .vdb.oracle_config import OracleConfig +from .vdb.pgvector_config import PGVectorConfig +from .vdb.pgvectors_config import PGVectoRSConfig +from .vdb.qdrant_config import QdrantConfig +from .vdb.relyt_config import RelytConfig +from .vdb.tencent_vector_config import TencentVectorDBConfig +from .vdb.tidb_on_qdrant_config import TidbOnQdrantConfig +from .vdb.tidb_vector_config import TiDBVectorConfig +from .vdb.upstash_config import UpstashConfig +from .vdb.vikingdb_config import VikingDBConfig +from .vdb.weaviate_config import WeaviateConfig class StorageConfig(BaseSettings): - STORAGE_TYPE: str = Field( + STORAGE_TYPE: Literal[ + "opendal", + "s3", + "aliyun-oss", + "azure-blob", + "baidu-obs", + "google-storage", + "huawei-obs", + "oci-storage", + "tencent-cos", + "volcengine-tos", + "supabase", + "local", + ] = Field( description="Type of storage to use." - " Options: 'local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', 'huawei-obs', " - "'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'local'.", - default="local", + " Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', " + "'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.", + default="opendal", ) STORAGE_LOCAL_PATH: str = Field( description="Path for local storage when STORAGE_TYPE is set to 'local'.", default="storage", + deprecated=True, ) @@ -73,7 +88,7 @@ class KeywordStoreConfig(BaseSettings): ) -class DatabaseConfig: +class DatabaseConfig(BaseSettings): DB_HOST: str = Field( description="Hostname or IP address of the database server.", default="localhost", @@ -115,7 +130,6 @@ class DatabaseConfig: ) @computed_field - @property def SQLALCHEMY_DATABASE_URI(self) -> str: db_extras = ( f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS @@ -153,7 +167,6 @@ def SQLALCHEMY_DATABASE_URI(self) -> str: ) @computed_field - @property def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: return { "pool_size": self.SQLALCHEMY_POOL_SIZE, @@ -191,7 +204,6 @@ class CeleryConfig(DatabaseConfig): ) @computed_field - @property def CELERY_RESULT_BACKEND(self) -> str | None: return ( "db+{}".format(self.SQLALCHEMY_DATABASE_URI) @@ -199,7 +211,6 @@ def CELERY_RESULT_BACKEND(self) -> str | None: else self.CELERY_BROKER_URL ) - @computed_field @property def BROKER_USE_SSL(self) -> bool: return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False @@ -235,6 +246,7 @@ class MiddlewareConfig( GoogleCloudStorageConfig, HuaweiCloudOBSStorageConfig, OCIStorageConfig, + OpenDALStorageConfig, S3StorageConfig, SupabaseStorageConfig, TencentCloudCOSStorageConfig, diff --git a/api/configs/middleware/storage/baidu_obs_storage_config.py b/api/configs/middleware/storage/baidu_obs_storage_config.py index c511628a1514a7..e7913b0acc337c 100644 --- a/api/configs/middleware/storage/baidu_obs_storage_config.py +++ b/api/configs/middleware/storage/baidu_obs_storage_config.py @@ -1,9 +1,10 @@ from typing import Optional -from pydantic import BaseModel, Field +from pydantic import Field +from pydantic_settings import BaseSettings -class BaiduOBSStorageConfig(BaseModel): +class BaiduOBSStorageConfig(BaseSettings): """ Configuration settings for Baidu Object Storage Service (OBS) """ diff --git a/api/configs/middleware/storage/huawei_obs_storage_config.py b/api/configs/middleware/storage/huawei_obs_storage_config.py index 3e9e7543ab2bab..be983b5187d271 100644 --- a/api/configs/middleware/storage/huawei_obs_storage_config.py +++ b/api/configs/middleware/storage/huawei_obs_storage_config.py @@ -1,9 +1,10 @@ from typing import Optional -from pydantic import BaseModel, Field +from pydantic import Field +from pydantic_settings import BaseSettings -class HuaweiCloudOBSStorageConfig(BaseModel): +class HuaweiCloudOBSStorageConfig(BaseSettings): """ Configuration settings for Huawei Cloud Object Storage Service (OBS) """ diff --git a/api/configs/middleware/storage/opendal_storage_config.py b/api/configs/middleware/storage/opendal_storage_config.py new file mode 100644 index 00000000000000..ef38070e53bb7c --- /dev/null +++ b/api/configs/middleware/storage/opendal_storage_config.py @@ -0,0 +1,9 @@ +from pydantic import Field +from pydantic_settings import BaseSettings + + +class OpenDALStorageConfig(BaseSettings): + OPENDAL_SCHEME: str = Field( + default="fs", + description="OpenDAL scheme.", + ) diff --git a/api/configs/middleware/storage/supabase_storage_config.py b/api/configs/middleware/storage/supabase_storage_config.py index a3e905b21c63e9..dcf7c20cf9e057 100644 --- a/api/configs/middleware/storage/supabase_storage_config.py +++ b/api/configs/middleware/storage/supabase_storage_config.py @@ -1,9 +1,10 @@ from typing import Optional -from pydantic import BaseModel, Field +from pydantic import Field +from pydantic_settings import BaseSettings -class SupabaseStorageConfig(BaseModel): +class SupabaseStorageConfig(BaseSettings): """ Configuration settings for Supabase Object Storage Service """ diff --git a/api/configs/middleware/storage/volcengine_tos_storage_config.py b/api/configs/middleware/storage/volcengine_tos_storage_config.py index 89ea8850023009..06c3ae4d3e63f8 100644 --- a/api/configs/middleware/storage/volcengine_tos_storage_config.py +++ b/api/configs/middleware/storage/volcengine_tos_storage_config.py @@ -1,9 +1,10 @@ from typing import Optional -from pydantic import BaseModel, Field +from pydantic import Field +from pydantic_settings import BaseSettings -class VolcengineTOSStorageConfig(BaseModel): +class VolcengineTOSStorageConfig(BaseSettings): """ Configuration settings for Volcengine Tinder Object Storage (TOS) """ diff --git a/api/configs/middleware/vdb/analyticdb_config.py b/api/configs/middleware/vdb/analyticdb_config.py index 53cfaae43ef503..cb8dc7d724fff9 100644 --- a/api/configs/middleware/vdb/analyticdb_config.py +++ b/api/configs/middleware/vdb/analyticdb_config.py @@ -1,9 +1,10 @@ from typing import Optional -from pydantic import BaseModel, Field, PositiveInt +from pydantic import Field, PositiveInt +from pydantic_settings import BaseSettings -class AnalyticdbConfig(BaseModel): +class AnalyticdbConfig(BaseSettings): """ Configuration for connecting to Alibaba Cloud AnalyticDB for PostgreSQL. Refer to the following documentation for details on obtaining credentials: diff --git a/api/configs/middleware/vdb/couchbase_config.py b/api/configs/middleware/vdb/couchbase_config.py index 391089ec6e8d00..b81cbf895956ac 100644 --- a/api/configs/middleware/vdb/couchbase_config.py +++ b/api/configs/middleware/vdb/couchbase_config.py @@ -1,9 +1,10 @@ from typing import Optional -from pydantic import BaseModel, Field +from pydantic import Field +from pydantic_settings import BaseSettings -class CouchbaseConfig(BaseModel): +class CouchbaseConfig(BaseSettings): """ Couchbase configs """ diff --git a/api/configs/middleware/vdb/lindorm_config.py b/api/configs/middleware/vdb/lindorm_config.py index 0f6c6528066747..95e1d1cfca4b80 100644 --- a/api/configs/middleware/vdb/lindorm_config.py +++ b/api/configs/middleware/vdb/lindorm_config.py @@ -21,3 +21,14 @@ class LindormConfig(BaseSettings): description="Lindorm password", default=None, ) + DEFAULT_INDEX_TYPE: Optional[str] = Field( + description="Lindorm Vector Index Type, hnsw or flat is available in dify", + default="hnsw", + ) + DEFAULT_DISTANCE_TYPE: Optional[str] = Field( + description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2" + ) + USING_UGC_INDEX: Optional[bool] = Field( + description="Using UGC index will store the same type of Index in a single index but can retrieve separately.", + default=False, + ) diff --git a/api/configs/middleware/vdb/myscale_config.py b/api/configs/middleware/vdb/myscale_config.py index 5896c19d27d117..b5bf98b3aab25f 100644 --- a/api/configs/middleware/vdb/myscale_config.py +++ b/api/configs/middleware/vdb/myscale_config.py @@ -1,7 +1,8 @@ -from pydantic import BaseModel, Field, PositiveInt +from pydantic import Field, PositiveInt +from pydantic_settings import BaseSettings -class MyScaleConfig(BaseModel): +class MyScaleConfig(BaseSettings): """ Configuration settings for MyScale vector database """ diff --git a/api/configs/middleware/vdb/vikingdb_config.py b/api/configs/middleware/vdb/vikingdb_config.py index 3e718481dc7e05..aba49ff6702ed8 100644 --- a/api/configs/middleware/vdb/vikingdb_config.py +++ b/api/configs/middleware/vdb/vikingdb_config.py @@ -1,9 +1,10 @@ from typing import Optional -from pydantic import BaseModel, Field +from pydantic import Field +from pydantic_settings import BaseSettings -class VikingDBConfig(BaseModel): +class VikingDBConfig(BaseSettings): """ Configuration for connecting to Volcengine VikingDB. Refer to the following documentation for details on obtaining credentials: diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index a2703ccb946390..4a168a3fb13947 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description="Dify version", - default="0.13.1", + default="0.14.2", ) COMMIT_SHA: str = Field( diff --git a/api/configs/remote_settings_sources/__init__.py b/api/configs/remote_settings_sources/__init__.py new file mode 100644 index 00000000000000..4f3878d13b65c8 --- /dev/null +++ b/api/configs/remote_settings_sources/__init__.py @@ -0,0 +1,17 @@ +from typing import Optional + +from pydantic import Field + +from .apollo import ApolloSettingsSourceInfo +from .base import RemoteSettingsSource +from .enums import RemoteSettingsSourceName + + +class RemoteSettingsSourceConfig(ApolloSettingsSourceInfo): + REMOTE_SETTINGS_SOURCE_NAME: RemoteSettingsSourceName | str = Field( + description="name of remote config source", + default="", + ) + + +__all__ = ["RemoteSettingsSource", "RemoteSettingsSourceConfig", "RemoteSettingsSourceName"] diff --git a/api/configs/remote_settings_sources/apollo/__init__.py b/api/configs/remote_settings_sources/apollo/__init__.py new file mode 100644 index 00000000000000..f02f7dc9ff6258 --- /dev/null +++ b/api/configs/remote_settings_sources/apollo/__init__.py @@ -0,0 +1,55 @@ +from collections.abc import Mapping +from typing import Any, Optional + +from pydantic import Field +from pydantic.fields import FieldInfo +from pydantic_settings import BaseSettings + +from configs.remote_settings_sources.base import RemoteSettingsSource + +from .client import ApolloClient + + +class ApolloSettingsSourceInfo(BaseSettings): + """ + Packaging build information + """ + + APOLLO_APP_ID: Optional[str] = Field( + description="apollo app_id", + default=None, + ) + + APOLLO_CLUSTER: Optional[str] = Field( + description="apollo cluster", + default=None, + ) + + APOLLO_CONFIG_URL: Optional[str] = Field( + description="apollo config url", + default=None, + ) + + APOLLO_NAMESPACE: Optional[str] = Field( + description="apollo namespace", + default=None, + ) + + +class ApolloSettingsSource(RemoteSettingsSource): + def __init__(self, configs: Mapping[str, Any]): + self.client = ApolloClient( + app_id=configs["APOLLO_APP_ID"], + cluster=configs["APOLLO_CLUSTER"], + config_url=configs["APOLLO_CONFIG_URL"], + start_hot_update=False, + _notification_map={configs["APOLLO_NAMESPACE"]: -1}, + ) + self.namespace = configs["APOLLO_NAMESPACE"] + self.remote_configs = self.client.get_all_dicts(self.namespace) + + def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: + if not isinstance(self.remote_configs, dict): + raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}") + field_value = self.remote_configs.get(field_name) + return field_value, field_name, False diff --git a/api/configs/remote_settings_sources/apollo/client.py b/api/configs/remote_settings_sources/apollo/client.py new file mode 100644 index 00000000000000..03c64ea00f0185 --- /dev/null +++ b/api/configs/remote_settings_sources/apollo/client.py @@ -0,0 +1,304 @@ +import hashlib +import json +import logging +import os +import threading +import time +from collections.abc import Mapping +from pathlib import Path + +from .python_3x import http_request, makedirs_wrapper +from .utils import ( + CONFIGURATIONS, + NAMESPACE_NAME, + NOTIFICATION_ID, + get_value_from_dict, + init_ip, + no_key_cache_key, + signature, + url_encode_wrapper, +) + +logger = logging.getLogger(__name__) + + +class ApolloClient: + def __init__( + self, + config_url, + app_id, + cluster="default", + secret="", + start_hot_update=True, + change_listener=None, + _notification_map=None, + ): + # Core routing parameters + self.config_url = config_url + self.cluster = cluster + self.app_id = app_id + + # Non-core parameters + self.ip = init_ip() + self.secret = secret + + # Check the parameter variables + + # Private control variables + self._cycle_time = 5 + self._stopping = False + self._cache = {} + self._no_key = {} + self._hash = {} + self._pull_timeout = 75 + self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/" + self._long_poll_thread = None + self._change_listener = change_listener # "add" "delete" "update" + if _notification_map is None: + _notification_map = {"application": -1} + self._notification_map = _notification_map + self.last_release_key = None + # Private startup method + self._path_checker() + if start_hot_update: + self._start_hot_update() + + # start the heartbeat thread + heartbeat = threading.Thread(target=self._heart_beat) + heartbeat.daemon = True + heartbeat.start() + + def get_json_from_net(self, namespace="application"): + url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format( + self.config_url, self.app_id, self.cluster, namespace, "", self.ip + ) + try: + code, body = http_request(url, timeout=3, headers=self._sign_headers(url)) + if code == 200: + if not body: + logger.error(f"get_json_from_net load configs failed, body is {body}") + return None + data = json.loads(body) + data = data["configurations"] + return_data = {CONFIGURATIONS: data} + return return_data + else: + return None + except Exception: + logger.exception("an error occurred in get_json_from_net") + return None + + def get_value(self, key, default_val=None, namespace="application"): + try: + # read memory configuration + namespace_cache = self._cache.get(namespace) + val = get_value_from_dict(namespace_cache, key) + if val is not None: + return val + + no_key = no_key_cache_key(namespace, key) + if no_key in self._no_key: + return default_val + + # read the network configuration + namespace_data = self.get_json_from_net(namespace) + val = get_value_from_dict(namespace_data, key) + if val is not None: + self._update_cache_and_file(namespace_data, namespace) + return val + + # read the file configuration + namespace_cache = self._get_local_cache(namespace) + val = get_value_from_dict(namespace_cache, key) + if val is not None: + self._update_cache_and_file(namespace_cache, namespace) + return val + + # If all of them are not obtained, the default value is returned + # and the local cache is set to None + self._set_local_cache_none(namespace, key) + return default_val + except Exception: + logger.exception("get_value has error, [key is %s], [namespace is %s]", key, namespace) + return default_val + + # Set the key of a namespace to none, and do not set default val + # to ensure the real-time correctness of the function call. + # If the user does not have the same default val twice + # and the default val is used here, there may be a problem. + def _set_local_cache_none(self, namespace, key): + no_key = no_key_cache_key(namespace, key) + self._no_key[no_key] = key + + def _start_hot_update(self): + self._long_poll_thread = threading.Thread(target=self._listener) + # When the asynchronous thread is started, the daemon thread will automatically exit + # when the main thread is launched. + self._long_poll_thread.daemon = True + self._long_poll_thread.start() + + def stop(self): + self._stopping = True + logger.info("Stopping listener...") + + # Call the set callback function, and if it is abnormal, try it out + def _call_listener(self, namespace, old_kv, new_kv): + if self._change_listener is None: + return + if old_kv is None: + old_kv = {} + if new_kv is None: + new_kv = {} + try: + for key in old_kv: + new_value = new_kv.get(key) + old_value = old_kv.get(key) + if new_value is None: + # If newValue is empty, it means key, and the value is deleted. + self._change_listener("delete", namespace, key, old_value) + continue + if new_value != old_value: + self._change_listener("update", namespace, key, new_value) + continue + for key in new_kv: + new_value = new_kv.get(key) + old_value = old_kv.get(key) + if old_value is None: + self._change_listener("add", namespace, key, new_value) + except BaseException as e: + logger.warning(str(e)) + + def _path_checker(self): + if not os.path.isdir(self._cache_file_path): + makedirs_wrapper(self._cache_file_path) + + # update the local cache and file cache + def _update_cache_and_file(self, namespace_data, namespace="application"): + # update the local cache + self._cache[namespace] = namespace_data + # update the file cache + new_string = json.dumps(namespace_data) + new_hash = hashlib.md5(new_string.encode("utf-8")).hexdigest() + if self._hash.get(namespace) == new_hash: + pass + else: + file_path = Path(self._cache_file_path) / f"{self.app_id}_configuration_{namespace}.txt" + file_path.write_text(new_string) + self._hash[namespace] = new_hash + + # get the configuration from the local file + def _get_local_cache(self, namespace="application"): + cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt") + if os.path.isfile(cache_file_path): + with open(cache_file_path) as f: + result = json.loads(f.readline()) + return result + return {} + + def _long_poll(self): + notifications = [] + for key in self._cache: + namespace_data = self._cache[key] + notification_id = -1 + if NOTIFICATION_ID in namespace_data: + notification_id = self._cache[key][NOTIFICATION_ID] + notifications.append({NAMESPACE_NAME: key, NOTIFICATION_ID: notification_id}) + try: + # if the length is 0 it is returned directly + if len(notifications) == 0: + return + url = "{}/notifications/v2".format(self.config_url) + params = { + "appId": self.app_id, + "cluster": self.cluster, + "notifications": json.dumps(notifications, ensure_ascii=False), + } + param_str = url_encode_wrapper(params) + url = url + "?" + param_str + code, body = http_request(url, self._pull_timeout, headers=self._sign_headers(url)) + http_code = code + if http_code == 304: + logger.debug("No change, loop...") + return + if http_code == 200: + if not body: + logger.error(f"_long_poll load configs failed,body is {body}") + return + data = json.loads(body) + for entry in data: + namespace = entry[NAMESPACE_NAME] + n_id = entry[NOTIFICATION_ID] + logger.info("%s has changes: notificationId=%d", namespace, n_id) + self._get_net_and_set_local(namespace, n_id, call_change=True) + return + else: + logger.warning("Sleep...") + except Exception as e: + logger.warning(str(e)) + + def _get_net_and_set_local(self, namespace, n_id, call_change=False): + namespace_data = self.get_json_from_net(namespace) + if not namespace_data: + return + namespace_data[NOTIFICATION_ID] = n_id + old_namespace = self._cache.get(namespace) + self._update_cache_and_file(namespace_data, namespace) + if self._change_listener is not None and call_change and old_namespace: + old_kv = old_namespace.get(CONFIGURATIONS) + new_kv = namespace_data.get(CONFIGURATIONS) + self._call_listener(namespace, old_kv, new_kv) + + def _listener(self): + logger.info("start long_poll") + while not self._stopping: + self._long_poll() + time.sleep(self._cycle_time) + logger.info("stopped, long_poll") + + # add the need for endorsement to the header + def _sign_headers(self, url: str) -> Mapping[str, str]: + headers: dict[str, str] = {} + if self.secret == "": + return headers + uri = url[len(self.config_url) : len(url)] + time_unix_now = str(int(round(time.time() * 1000))) + headers["Authorization"] = "Apollo " + self.app_id + ":" + signature(time_unix_now, uri, self.secret) + headers["Timestamp"] = time_unix_now + return headers + + def _heart_beat(self): + while not self._stopping: + for namespace in self._notification_map: + self._do_heart_beat(namespace) + time.sleep(60 * 10) # 10分钟 + + def _do_heart_beat(self, namespace): + url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip) + try: + code, body = http_request(url, timeout=3, headers=self._sign_headers(url)) + if code == 200: + if not body: + logger.error(f"_do_heart_beat load configs failed,body is {body}") + return None + data = json.loads(body) + if self.last_release_key == data["releaseKey"]: + return None + self.last_release_key = data["releaseKey"] + data = data["configurations"] + self._update_cache_and_file(data, namespace) + else: + return None + except Exception: + logger.exception("an error occurred in _do_heart_beat") + return None + + def get_all_dicts(self, namespace): + namespace_data = self._cache.get(namespace) + if namespace_data is None: + net_namespace_data = self.get_json_from_net(namespace) + if not net_namespace_data: + return namespace_data + namespace_data = net_namespace_data.get(CONFIGURATIONS) + if namespace_data: + self._update_cache_and_file(namespace_data, namespace) + return namespace_data diff --git a/api/configs/remote_settings_sources/apollo/python_3x.py b/api/configs/remote_settings_sources/apollo/python_3x.py new file mode 100644 index 00000000000000..6a5f3819912206 --- /dev/null +++ b/api/configs/remote_settings_sources/apollo/python_3x.py @@ -0,0 +1,41 @@ +import logging +import os +import ssl +import urllib.request +from urllib import parse +from urllib.error import HTTPError + +# Create an SSL context that allows for a lower level of security +ssl_context = ssl.create_default_context() +ssl_context.set_ciphers("HIGH:!DH:!aNULL") +ssl_context.check_hostname = False +ssl_context.verify_mode = ssl.CERT_NONE + +# Create an opener object and pass in a custom SSL context +opener = urllib.request.build_opener(urllib.request.HTTPSHandler(context=ssl_context)) + +urllib.request.install_opener(opener) + +logger = logging.getLogger(__name__) + + +def http_request(url, timeout, headers={}): + try: + request = urllib.request.Request(url, headers=headers) + res = urllib.request.urlopen(request, timeout=timeout) + body = res.read().decode("utf-8") + return res.code, body + except HTTPError as e: + if e.code == 304: + logger.warning("http_request error,code is 304, maybe you should check secret") + return 304, None + logger.warning("http_request error,code is %d, msg is %s", e.code, e.msg) + raise e + + +def url_encode(params): + return parse.urlencode(params) + + +def makedirs_wrapper(path): + os.makedirs(path, exist_ok=True) diff --git a/api/configs/remote_settings_sources/apollo/utils.py b/api/configs/remote_settings_sources/apollo/utils.py new file mode 100644 index 00000000000000..6136112e03d18e --- /dev/null +++ b/api/configs/remote_settings_sources/apollo/utils.py @@ -0,0 +1,51 @@ +import hashlib +import socket + +from .python_3x import url_encode + +# define constants +CONFIGURATIONS = "configurations" +NOTIFICATION_ID = "notificationId" +NAMESPACE_NAME = "namespaceName" + + +# add timestamps uris and keys +def signature(timestamp, uri, secret): + import base64 + import hmac + + string_to_sign = "" + timestamp + "\n" + uri + hmac_code = hmac.new(secret.encode(), string_to_sign.encode(), hashlib.sha1).digest() + return base64.b64encode(hmac_code).decode() + + +def url_encode_wrapper(params): + return url_encode(params) + + +def no_key_cache_key(namespace, key): + return "{}{}{}".format(namespace, len(namespace), key) + + +# Returns whether the obtained value is obtained, and None if it does not +def get_value_from_dict(namespace_cache, key): + if namespace_cache: + kv_data = namespace_cache.get(CONFIGURATIONS) + if kv_data is None: + return None + if key in kv_data: + return kv_data[key] + return None + + +def init_ip(): + ip = "" + s = None + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("8.8.8.8", 53)) + ip = s.getsockname()[0] + finally: + if s: + s.close() + return ip diff --git a/api/configs/remote_settings_sources/base.py b/api/configs/remote_settings_sources/base.py new file mode 100644 index 00000000000000..a96ffdfb4bc7df --- /dev/null +++ b/api/configs/remote_settings_sources/base.py @@ -0,0 +1,15 @@ +from collections.abc import Mapping +from typing import Any + +from pydantic.fields import FieldInfo + + +class RemoteSettingsSource: + def __init__(self, configs: Mapping[str, Any]): + pass + + def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: + raise NotImplementedError + + def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any: + return value diff --git a/api/configs/remote_settings_sources/enums.py b/api/configs/remote_settings_sources/enums.py new file mode 100644 index 00000000000000..3081f2950ff707 --- /dev/null +++ b/api/configs/remote_settings_sources/enums.py @@ -0,0 +1,5 @@ +from enum import StrEnum + + +class RemoteSettingsSourceName(StrEnum): + APOLLO = "apollo" diff --git a/api/constants/__init__.py b/api/constants/__init__.py index 05795e11d7dcc5..4500ef4306fc2a 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -14,11 +14,11 @@ if dify_config.ETL_TYPE == "Unstructured": - DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"] + DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls"] DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "xml", "epub")) if dify_config.UNSTRUCTURED_API_URL: DOCUMENT_EXTENSIONS.append("ppt") DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) else: - DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"] + DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"] DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) diff --git a/api/constants/model_template.py b/api/constants/model_template.py index 7e1a196356c4e2..c26d8c018610d0 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -1,8 +1,9 @@ import json +from collections.abc import Mapping from models.model import AppMode -default_app_templates = { +default_app_templates: Mapping[AppMode, Mapping] = { # workflow default mode AppMode.WORKFLOW: { "app": { diff --git a/api/controllers/common/errors.py b/api/controllers/common/errors.py index c71f1ce5a31027..9f762b3135e2a4 100644 --- a/api/controllers/common/errors.py +++ b/api/controllers/common/errors.py @@ -4,3 +4,8 @@ class FilenameNotExistsError(HTTPException): code = 400 description = "The specified filename does not exist." + + +class RemoteFileUploadError(HTTPException): + code = 400 + description = "Error uploading remote file." diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index 79869916eda062..b1ebc444a51868 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore parameters__system_parameters = { "image_file_size_limit": fields.Integer, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index f46d5b6b138d59..cb6b0d097b1fc9 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -3,6 +3,25 @@ from libs.external_api import ExternalApi from .app.app_import import AppImportApi, AppImportConfirmApi +from .explore.audio import ChatAudioApi, ChatTextApi +from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi +from .explore.conversation import ( + ConversationApi, + ConversationListApi, + ConversationPinApi, + ConversationRenameApi, + ConversationUnPinApi, +) +from .explore.message import ( + MessageFeedbackApi, + MessageListApi, + MessageMoreLikeThisApi, + MessageSuggestedQuestionApi, +) +from .explore.workflow import ( + InstalledAppWorkflowRunApi, + InstalledAppWorkflowTaskStopApi, +) from .files import FileApi, FilePreviewApi, FileSupportTypeApi from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi @@ -66,15 +85,81 @@ # Import explore controllers from .explore import ( - audio, - completion, - conversation, installed_app, - message, parameter, recommended_app, saved_message, - workflow, +) + +# Explore Audio +api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio") +api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text") + +# Explore Completion +api.add_resource( + CompletionApi, "/installed-apps//completion-messages", endpoint="installed_app_completion" +) +api.add_resource( + CompletionStopApi, + "/installed-apps//completion-messages//stop", + endpoint="installed_app_stop_completion", +) +api.add_resource( + ChatApi, "/installed-apps//chat-messages", endpoint="installed_app_chat_completion" +) +api.add_resource( + ChatStopApi, + "/installed-apps//chat-messages//stop", + endpoint="installed_app_stop_chat_completion", +) + +# Explore Conversation +api.add_resource( + ConversationRenameApi, + "/installed-apps//conversations//name", + endpoint="installed_app_conversation_rename", +) +api.add_resource( + ConversationListApi, "/installed-apps//conversations", endpoint="installed_app_conversations" +) +api.add_resource( + ConversationApi, + "/installed-apps//conversations/", + endpoint="installed_app_conversation", +) +api.add_resource( + ConversationPinApi, + "/installed-apps//conversations//pin", + endpoint="installed_app_conversation_pin", +) +api.add_resource( + ConversationUnPinApi, + "/installed-apps//conversations//unpin", + endpoint="installed_app_conversation_unpin", +) + + +# Explore Message +api.add_resource(MessageListApi, "/installed-apps//messages", endpoint="installed_app_messages") +api.add_resource( + MessageFeedbackApi, + "/installed-apps//messages//feedbacks", + endpoint="installed_app_message_feedback", +) +api.add_resource( + MessageMoreLikeThisApi, + "/installed-apps//messages//more-like-this", + endpoint="installed_app_more_like_this", +) +api.add_resource( + MessageSuggestedQuestionApi, + "/installed-apps//messages//suggested-questions", + endpoint="installed_app_suggested_question", +) +# Explore Workflow +api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps//workflows/run") +api.add_resource( + InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop" ) # Import tag controllers diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index a70c4a31c7db94..52e0bb6c56bdc2 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -1,7 +1,7 @@ from functools import wraps from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config @@ -31,7 +31,7 @@ def decorated(*args, **kwargs): if auth_scheme != "bearer": raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - if dify_config.ADMIN_API_KEY != auth_token: + if auth_token != dify_config.ADMIN_API_KEY: raise Unauthorized("API key is invalid.") return view(*args, **kwargs) diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 953770868904d3..ca8ddc32094ac5 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,5 +1,7 @@ -import flask_restful -from flask_login import current_user +from typing import Any + +import flask_restful # type: ignore +from flask_login import current_user # type: ignore from flask_restful import Resource, fields, marshal_with from werkzeug.exceptions import Forbidden @@ -35,14 +37,15 @@ def _get_resource(resource_id, tenant_id, resource_model): class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type = None - resource_model = None - resource_id_field = None - token_prefix = None + resource_type: str | None = None + resource_model: Any = None + resource_id_field: str | None = None + token_prefix: str | None = None max_keys = 10 @marshal_with(api_key_list) def get(self, resource_id): + assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) keys = ( @@ -54,6 +57,7 @@ def get(self, resource_id): @marshal_with(api_key_fields) def post(self, resource_id): + assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) if not current_user.is_editor: @@ -86,11 +90,12 @@ def post(self, resource_id): class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type = None - resource_model = None - resource_id_field = None + resource_type: str | None = None + resource_model: Any = None + resource_id_field: str | None = None def delete(self, resource_id, api_key_id): + assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) api_key_id = str(api_key_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index c228743fa53591..8d0c5b84af5e37 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.wraps import account_initialization_required, setup_required diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index d4334158945e16..920cae0d859354 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index fd05cbc19bf04f..24f1020c18ec37 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,6 +1,6 @@ from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -110,7 +110,7 @@ def get(self, app_id): page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) - keyword = request.args.get("keyword", default=None, type=str) + keyword = request.args.get("keyword", default="", type=str) app_id = str(app_id) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index da72b704c71bd7..9cd56cef0b7039 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,8 +1,8 @@ import uuid from typing import cast -from flask_login import current_user -from flask_restful import Resource, inputs, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, inputs, marshal, marshal_with, reqparse # type: ignore from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, abort diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 244dcd75de29bc..7e2888d71c79c8 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,7 +1,7 @@ from typing import cast -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 695b8890e30f5c..9d26af276d2fc3 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -1,7 +1,7 @@ import logging from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError import services diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 9896fcaab8ad36..dba41e5c47d24f 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -1,7 +1,7 @@ import logging -import flask_login -from flask_restful import Resource, reqparse +import flask_login # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index a25004be4d16ae..8827f129d99317 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -1,9 +1,9 @@ from datetime import UTC, datetime -import pytz -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +import pytz # pip install pytz +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from sqlalchemy import func, or_ from sqlalchemy.orm import joinedload from werkzeug.exceptions import Forbidden, NotFound @@ -77,8 +77,9 @@ def get(self, app_model): query = query.where(Conversation.created_at < end_datetime_utc) + # FIXME, the type ignore in this file if args["annotation_status"] == "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( + query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) elif args["annotation_status"] == "not_annotated": @@ -222,7 +223,7 @@ def get(self, app_model): query = query.where(Conversation.created_at <= end_datetime_utc) if args["annotation_status"] == "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( + query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) elif args["annotation_status"] == "not_annotated": @@ -234,7 +235,7 @@ def get(self, app_model): if args["message_count_gte"] and args["message_count_gte"] >= 1: query = ( - query.options(joinedload(Conversation.messages)) + query.options(joinedload(Conversation.messages)) # type: ignore .join(Message, Message.conversation_id == Conversation.id) .group_by(Conversation.id) .having(func.count(Message.id) >= args["message_count_gte"]) diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index d49f433ba1f575..c0a20b7160e719 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, marshal_with, reqparse +from flask_restful import Resource, marshal_with, reqparse # type: ignore from sqlalchemy import select from sqlalchemy.orm import Session diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 9c3cbe4e3e049e..8518d34a8e5af2 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,7 +1,7 @@ import os -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.app.error import ( diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index b7a4c31a156b80..b5828b6b4b08c4 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,8 +1,8 @@ import logging -from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from controllers.console import api diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 8ba195f5a51053..8ecc8a9db5738d 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -1,8 +1,9 @@ import json +from typing import cast from flask import request -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model @@ -26,7 +27,9 @@ def post(self, app_model): """Modify app model config""" # validate config model_configuration = AppModelConfigService.validate_configuration( - tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode) + tenant_id=current_user.current_tenant_id, + config=cast(dict, request.json), + app_mode=AppMode.value_of(app_model.mode), ) new_app_model_config = AppModelConfig( @@ -38,9 +41,11 @@ def post(self, app_model): if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: # get original app model config - original_app_model_config: AppModelConfig = ( + original_app_model_config = ( db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() ) + if original_app_model_config is None: + raise ValueError("Original app model config not found") agent_mode = original_app_model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input parameter_map = {} @@ -65,7 +70,7 @@ def post(self, app_model): provider_type=agent_tool_entity.provider_type, identity_id=f"AGENT.{app_model.id}", ) - except Exception as e: + except Exception: continue # get decrypted parameters @@ -97,7 +102,7 @@ def post(self, app_model): app_id=app_model.id, agent_tool=agent_tool_entity, ) - except Exception as e: + except Exception: continue manager = ToolParameterConfigurationManager( diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index 47b58396a1a303..dd25af8ebf9312 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -1,4 +1,5 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore +from werkzeug.exceptions import BadRequest from controllers.console import api from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist @@ -26,7 +27,7 @@ def get(self, app_id): return {"has_not_configured": True} return trace_config except Exception as e: - raise e + raise BadRequest(str(e)) @setup_required @login_required @@ -48,7 +49,7 @@ def post(self, app_id): raise TracingConfigCheckError() return result except Exception as e: - raise e + raise BadRequest(str(e)) @setup_required @login_required @@ -68,7 +69,7 @@ def patch(self, app_id): raise TracingConfigNotExist() return {"result": "success"} except Exception as e: - raise e + raise BadRequest(str(e)) @setup_required @login_required @@ -85,7 +86,7 @@ def delete(self, app_id): raise TracingConfigNotExist() return {"result": "success"} except Exception as e: - raise e + raise BadRequest(str(e)) api.add_resource(TraceAppConfigApi, "/apps//trace-config") diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 407f6898199bae..db29b95c4140ff 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,7 +1,7 @@ from datetime import UTC, datetime -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden, NotFound from constants.languages import supported_language @@ -50,7 +50,7 @@ def post(self, app_model): if not current_user.is_editor: raise Forbidden() - site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404() + site = Site.query.filter(Site.app_id == app_model.id).one_or_404() for attr_name in [ "title", diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index db5e2824095ca0..3b21108ceaf76b 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -3,8 +3,8 @@ import pytz from flask import jsonify -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index c85d554069c4ef..26a3a022d401a4 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -2,10 +2,11 @@ import logging from flask import abort, request -from flask_restful import Resource, marshal_with, reqparse +from flask_restful import Resource, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services +from configs import dify_config from controllers.console import api from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.wraps import get_app_model @@ -426,7 +427,21 @@ def post(self, app_model: App): } +class WorkflowConfigApi(Resource): + """Resource for workflow configuration.""" + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App): + return { + "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT, + } + + api.add_resource(DraftWorkflowApi, "/apps//workflows/draft") +api.add_resource(WorkflowConfigApi, "/apps//workflows/draft/config") api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps//advanced-chat/workflows/draft/run") api.add_resource(DraftWorkflowRunApi, "/apps//workflows/draft/run") api.add_resource(WorkflowTaskStopApi, "/apps//workflow-runs/tasks//stop") diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 2940556f84ef4e..882c53e4fb9972 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -1,5 +1,5 @@ -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 08ab61bbb9c97e..25a99c1e1594ae 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,5 +1,5 @@ -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 6c7c73707bb204..097bf7d1888cf5 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -3,8 +3,8 @@ import pytz from flask import jsonify -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index c71ee8e5dfea1d..9ad8c158473df9 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -5,11 +5,10 @@ from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_user -from models import App -from models.model import AppMode +from models import App, AppMode -def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): +def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None): def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index d2aa7c903b046c..c56f551d49be8b 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,14 +1,14 @@ import datetime from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from constants.languages import supported_language from controllers.console import api from controllers.console.error import AlreadyActivateError from extensions.ext_database import db from libs.helper import StrLen, email, extract_remote_ip, timezone -from models.account import AccountStatus, Tenant +from models.account import AccountStatus from services.account_service import AccountService, RegisterService @@ -27,7 +27,7 @@ def get(self): invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) if invitation: data = invitation.get("data", {}) - tenant: Tenant = invitation.get("tenant", None) + tenant = invitation.get("tenant", None) workspace_name = tenant.name if tenant else None workspace_id = tenant.id if tenant else None invitee_email = data.get("email") if data else None diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 465c44e9b6dc2f..ea00c2b8c2272c 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index faca67bb177f10..e911c9a5e5b5ea 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -2,8 +2,8 @@ import requests from flask import current_app, redirect, request -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: ignore from werkzeug.exceptions import Forbidden from configs import dify_config @@ -17,8 +17,8 @@ def get_oauth_providers(): with current_app.app_context(): notion_oauth = NotionOAuth( - client_id=dify_config.NOTION_CLIENT_ID, - client_secret=dify_config.NOTION_CLIENT_SECRET, + client_id=dify_config.NOTION_CLIENT_ID or "", + client_secret=dify_config.NOTION_CLIENT_SECRET or "", redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion", ) diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index fb32bb2b60286d..140b9e145fa9cd 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -2,7 +2,7 @@ import secrets from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from constants.languages import languages from controllers.console import api @@ -122,8 +122,8 @@ def post(self): else: try: account = AccountService.create_account_and_tenant( - email=reset_data.get("email"), - name=reset_data.get("email"), + email=reset_data.get("email", ""), + name=reset_data.get("email", ""), password=password_confirm, interface_language=languages[0], ) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index f4463ce9cb3f30..78a80fc8d7e075 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,8 +1,8 @@ from typing import cast -import flask_login +import flask_login # type: ignore from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore import services from constants.languages import languages diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 5de8c6766d35a8..333b24142727f0 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -4,7 +4,7 @@ import requests from flask import current_app, redirect, request -from flask_restful import Resource +from flask_restful import Resource # type: ignore from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -76,8 +76,9 @@ def get(self, provider: str): try: token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) - except requests.exceptions.HTTPError as e: - logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") + except requests.exceptions.RequestException as e: + error_text = e.response.text if e.response else str(e) + logging.exception(f"An error occurred during the OAuth process with {provider}: {error_text}") return {"error": "OAuth process failed"}, 400 if invite_token and RegisterService.is_valid_invite_token(invite_token): @@ -129,7 +130,7 @@ def get(self, provider: str): def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: - account = Account.get_by_openid(provider, user_info.id) + account: Optional[Account] = Account.get_by_openid(provider, user_info.id) if not account: account = Account.query.filter_by(email=user_info.email).first() diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 4b0c82ae6c90c2..fd7b7bd8cb3ddd 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 278295ca39a696..d7c431b95080da 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -2,8 +2,8 @@ import json from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from werkzeug.exceptions import NotFound from controllers.console import api diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 95d4013e3a8f27..f3c3736b25acc5 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,7 +1,7 @@ -import flask_restful +import flask_restful # type: ignore from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore # type: ignore +from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden, NotFound import services diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index de3b4f62620501..ca41e504be7eda 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,12 +1,13 @@ import logging from argparse import ArgumentTypeError from datetime import UTC, datetime +from typing import cast from flask import request -from flask_login import current_user -from flask_restful import Resource, fields, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, marshal, marshal_with, reqparse # type: ignore from sqlalchemy import asc, desc -from transformers.hf_argparser import string_to_bool +from transformers.hf_argparser import string_to_bool # type: ignore from werkzeug.exceptions import Forbidden, NotFound import services @@ -733,8 +734,7 @@ def put(self, dataset_id, document_id): if not isinstance(doc_metadata, dict): raise ValueError("doc_metadata must be a dictionary.") - - metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type] + metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]) document.doc_metadata = {} if doc_type == "others": @@ -948,7 +948,7 @@ def post(self, dataset_id): if document.indexing_status == "completed": raise DocumentAlreadyFinishedError() retry_documents.append(document) - except Exception as e: + except Exception: logging.exception(f"Failed to retry document, document id: {document_id}") continue # retry document diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 6f7ef86d2c3fd3..2d5933ca23609a 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -3,8 +3,8 @@ import pandas as pd from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal, reqparse # type: ignore from werkzeug.exceptions import Forbidden, NotFound import services diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index bc6e3687c1c99d..48f360dcd179bc 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -1,6 +1,6 @@ from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal, reqparse # type: ignore from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 495f511275b4b9..18b746f547287c 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from controllers.console import api from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 3b4c07686361d0..bd944602c147cb 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,7 @@ import logging -from flask_login import current_user -from flask_restful import marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import marshal, reqparse # type: ignore from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services.dataset_service diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index 9127c8af455f6c..da995537e74753 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.datasets.error import WebsiteCrawlError diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 9690677f61b1c2..c7f9fec326945f 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -4,7 +4,6 @@ from werkzeug.exceptions import InternalServerError import services -from controllers.console import api from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -67,7 +66,7 @@ def post(self, installed_app): class ChatTextApi(InstalledAppResource): def post(self, installed_app): - from flask_restful import reqparse + from flask_restful import reqparse # type: ignore app_model = installed_app.app try: @@ -118,9 +117,3 @@ def post(self, installed_app): except Exception as e: logging.exception("internal server error.") raise InternalServerError() - - -api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio") -api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text") -# api.add_resource(ChatTextApiWithMessageId, '/installed-apps//text-to-audio/message-id', -# endpoint='installed_app_text_with_message_id') diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 85c43f8101028e..3331ded70f6620 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,12 +1,11 @@ import logging from datetime import UTC, datetime -from flask_login import current_user -from flask_restful import reqparse +from flask_login import current_user # type: ignore +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services -from controllers.console import api from controllers.console.app.error import ( AppUnavailableError, CompletionRequestError, @@ -147,21 +146,3 @@ def post(self, installed_app, task_id): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"}, 200 - - -api.add_resource( - CompletionApi, "/installed-apps//completion-messages", endpoint="installed_app_completion" -) -api.add_resource( - CompletionStopApi, - "/installed-apps//completion-messages//stop", - endpoint="installed_app_stop_completion", -) -api.add_resource( - ChatApi, "/installed-apps//chat-messages", endpoint="installed_app_chat_completion" -) -api.add_resource( - ChatStopApi, - "/installed-apps//chat-messages//stop", - endpoint="installed_app_stop_chat_completion", -) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 6f9d7769b942ce..91916cbc1ed85f 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,12 +1,13 @@ -from flask_login import current_user -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound -from controllers.console import api from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value from models.model import AppMode @@ -34,14 +35,16 @@ def get(self, installed_app): pinned = True if args["pinned"] == "true" else False try: - return WebConversationService.pagination_by_last_id( - app_model=app_model, - user=current_user, - last_id=args["last_id"], - limit=args["limit"], - invoke_from=InvokeFrom.EXPLORE, - pinned=pinned, - ) + with Session(db.engine) as session: + return WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=current_user, + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.EXPLORE, + pinned=pinned, + ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -114,28 +117,3 @@ def patch(self, installed_app, c_id): WebConversationService.unpin(app_model, conversation_id, current_user) return {"result": "success"} - - -api.add_resource( - ConversationRenameApi, - "/installed-apps//conversations//name", - endpoint="installed_app_conversation_rename", -) -api.add_resource( - ConversationListApi, "/installed-apps//conversations", endpoint="installed_app_conversations" -) -api.add_resource( - ConversationApi, - "/installed-apps//conversations/", - endpoint="installed_app_conversation", -) -api.add_resource( - ConversationPinApi, - "/installed-apps//conversations//pin", - endpoint="installed_app_conversation_pin", -) -api.add_resource( - ConversationUnPinApi, - "/installed-apps//conversations//unpin", - endpoint="installed_app_conversation_unpin", -) diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index b60c4e176b3214..86550b2bdf44b9 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -1,7 +1,9 @@ from datetime import UTC, datetime +from typing import Any -from flask_login import current_user -from flask_restful import Resource, inputs, marshal_with, reqparse +from flask import request +from flask_login import current_user # type: ignore +from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore from sqlalchemy import and_ from werkzeug.exceptions import BadRequest, Forbidden, NotFound @@ -20,11 +22,20 @@ class InstalledAppsListApi(Resource): @account_initialization_required @marshal_with(installed_app_list_fields) def get(self): + app_id = request.args.get("app_id", default=None, type=str) current_tenant_id = current_user.current_tenant_id - installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() + + if app_id: + installed_apps = ( + db.session.query(InstalledApp) + .filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)) + .all() + ) + else: + installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) - installed_apps = [ + installed_app_list: list[dict[str, Any]] = [ { "id": installed_app.id, "app": installed_app.app, @@ -37,7 +48,7 @@ def get(self): for installed_app in installed_apps if installed_app.app is not None ] - installed_apps.sort( + installed_app_list.sort( key=lambda app: ( -app["is_pinned"], app["last_used_at"] is None, @@ -45,7 +56,7 @@ def get(self): ) ) - return {"installed_apps": installed_apps} + return {"installed_apps": installed_app_list} @login_required @account_initialization_required diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 3d221ff30a6599..c3488de29929c9 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,12 +1,11 @@ import logging -from flask_login import current_user -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services -from controllers.console import api from controllers.console.app.error import ( AppMoreLikeThisDisabledError, CompletionRequestError, @@ -70,7 +69,7 @@ def post(self, installed_app, message_id): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, current_user, args["rating"]) + MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -153,21 +152,3 @@ def get(self, installed_app, message_id): raise InternalServerError() return {"data": questions} - - -api.add_resource(MessageListApi, "/installed-apps//messages", endpoint="installed_app_messages") -api.add_resource( - MessageFeedbackApi, - "/installed-apps//messages//feedbacks", - endpoint="installed_app_message_feedback", -) -api.add_resource( - MessageMoreLikeThisApi, - "/installed-apps//messages//more-like-this", - endpoint="installed_app_more_like_this", -) -api.add_resource( - MessageSuggestedQuestionApi, - "/installed-apps//messages//suggested-questions", - endpoint="installed_app_suggested_question", -) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index fee52248a698e0..5bc74d16e784af 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,4 +1,4 @@ -from flask_restful import marshal_with +from flask_restful import marshal_with # type: ignore from controllers.common import fields from controllers.common import helpers as controller_helpers diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 5daaa1e7c38ba8..be6b1f5d215fb4 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,9 +1,10 @@ -from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from constants.languages import languages from controllers.console import api from controllers.console.wraps import account_initialization_required +from libs.helper import AppIconUrlField from libs.login import login_required from services.recommended_app_service import RecommendedAppService @@ -12,6 +13,8 @@ "name": fields.String, "mode": fields.String, "icon": fields.String, + "icon_type": fields.String, + "icon_url": AppIconUrlField, "icon_background": fields.String, } diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 0fc963747981e1..9f0c4966457186 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,6 +1,6 @@ -from flask_login import current_user -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import NotFound from controllers.console import api diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 45f99b1db9fa9e..76d30299cd84a7 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,9 +1,8 @@ import logging -from flask_restful import reqparse +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError -from controllers.console import api from controllers.console.app.error import ( CompletionRequestError, ProviderModelCurrentlyNotSupportError, @@ -73,9 +72,3 @@ def post(self, installed_app: InstalledApp, task_id: str): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"} - - -api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps//workflows/run") -api.add_resource( - InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop" -) diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 49ea81a8a0f86d..b7ba81fba20f79 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -1,7 +1,7 @@ from functools import wraps -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: ignore from werkzeug.exceptions import NotFound from controllers.console.wraps import account_initialization_required diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 4ac0aa497e0866..ed6cedb220cf4b 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from constants import HIDDEN_VALUE from controllers.console import api diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 70ab4ff865cb48..da1171412fdb2d 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: ignore from libs.login import login_required from services.feature_service import FeatureService diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 946d3db37f587b..8cf754bbd686fd 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -1,6 +1,9 @@ +from typing import Literal + from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal_with +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with # type: ignore +from werkzeug.exceptions import Forbidden import services from configs import dify_config @@ -47,7 +50,8 @@ def get(self): @cloud_edition_billing_resource_check("documents") def post(self): file = request.files["file"] - source = request.form.get("source") + source_str = request.form.get("source") + source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None if "file" not in request.files: raise NoFileUploadedError() @@ -58,6 +62,9 @@ def post(self): if not file.filename: raise FilenameNotExistsError + if source == "datasets" and not current_user.is_dataset_editor: + raise Forbidden() + if source not in ("datasets", None): source = None diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index ae759bb752a30e..d9ae5cf29fc626 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -1,7 +1,7 @@ import os from flask import session -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from configs import dify_config from libs.helper import StrLen diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py index cd28cc946ee288..2a116112a3227c 100644 --- a/api/controllers/console/ping.py +++ b/api/controllers/console/ping.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from controllers.console import api diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index fac1341b39b596..30afc930a8e980 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -2,11 +2,12 @@ from typing import cast import httpx -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore import services from controllers.common import helpers +from controllers.common.errors import RemoteFileUploadError from core.file import helpers as file_helpers from core.helper import ssrf_proxy from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields @@ -43,10 +44,14 @@ def post(self): url = args["url"] - resp = ssrf_proxy.head(url=url) - if resp.status_code != httpx.codes.OK: - resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) - resp.raise_for_status() + try: + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) + if resp.status_code != httpx.codes.OK: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") + except httpx.RequestError as e: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") file_info = helpers.guess_file_info_from_response(resp) diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index e0b728d97739d3..aba6f0aad9ee54 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from configs import dify_config from libs.helper import StrLen, email, extract_remote_ip diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index ccd3293a6266fc..da83f64019161b 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,6 +1,6 @@ from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -23,7 +23,7 @@ class TagListApi(Resource): @account_initialization_required @marshal_with(tag_fields) def get(self): - tag_type = request.args.get("type", type=str) + tag_type = request.args.get("type", type=str, default="") keyword = request.args.get("keyword", default=None, type=str) tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword) diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 7dea8e554edd7a..7773c99944e42c 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -2,7 +2,7 @@ import logging import requests -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from packaging import version from configs import dify_config diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index f704783cfff56b..96ed4b7a570256 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -2,8 +2,8 @@ import pytz from flask import request -from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from configs import dify_config from constants.languages import supported_language diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index d2b2092b75a9ff..7009343d9923da 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -37,7 +37,7 @@ def post(self, provider: str): model_load_balancing_service = ModelLoadBalancingService() result = True - error = None + error = "" try: model_load_balancing_service.validate_load_balancing_credentials( @@ -86,7 +86,7 @@ def post(self, provider: str, config_id: str): model_load_balancing_service = ModelLoadBalancingService() result = True - error = None + error = "" try: model_load_balancing_service.validate_load_balancing_credentials( diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 38ed2316a58935..1afb41ea87660c 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,7 +1,7 @@ from urllib import parse -from flask_login import current_user -from flask_restful import Resource, abort, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, abort, marshal_with, reqparse # type: ignore import services from configs import dify_config @@ -89,19 +89,19 @@ class MemberCancelInviteApi(Resource): @account_initialization_required def delete(self, member_id): member = db.session.query(Account).filter(Account.id == str(member_id)).first() - if not member: + if member is None: abort(404) - - try: - TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user) - except services.errors.account.CannotOperateSelfError as e: - return {"code": "cannot-operate-self", "message": str(e)}, 400 - except services.errors.account.NoPermissionError as e: - return {"code": "forbidden", "message": str(e)}, 403 - except services.errors.account.MemberNotInTenantError as e: - return {"code": "member-not-found", "message": str(e)}, 404 - except Exception as e: - raise ValueError(str(e)) + else: + try: + TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user) + except services.errors.account.CannotOperateSelfError as e: + return {"code": "cannot-operate-self", "message": str(e)}, 400 + except services.errors.account.NoPermissionError as e: + return {"code": "forbidden", "message": str(e)}, 403 + except services.errors.account.MemberNotInTenantError as e: + return {"code": "member-not-found", "message": str(e)}, 404 + except Exception as e: + raise ValueError(str(e)) return {"result": "success"}, 204 @@ -122,10 +122,11 @@ def put(self, member_id): return {"code": "invalid-role", "message": "Invalid role"}, 400 member = db.session.get(Account, str(member_id)) - if not member: + if member: abort(404) try: + assert member is not None, "Member not found" TenantService.update_member_role(current_user.current_tenant, member, new_role, current_user) except Exception as e: raise ValueError(str(e)) diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 0e54126063be75..2d11295b0fdf61 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -1,8 +1,8 @@ import io from flask import send_file -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -66,7 +66,7 @@ def post(self, provider: str): model_provider_service = ModelProviderService() result = True - error = None + error = "" try: model_provider_service.provider_credentials_validate( @@ -132,7 +132,8 @@ def get(self, provider: str, icon_type: str, lang: str): icon_type=icon_type, lang=lang, ) - + if icon is None: + raise ValueError(f"icon not found for provider {provider}, icon_type {icon_type}, lang {lang}") return send_file(io.BytesIO(icon), mimetype=mimetype) diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index f804285f008120..618262e502ab33 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -1,7 +1,7 @@ import logging -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -308,7 +308,7 @@ def post(self, provider: str): model_provider_service = ModelProviderService() result = True - error = None + error = "" try: model_provider_service.model_credentials_validate( diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 9ecda2126d4b42..964f3862291a2e 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,14 +1,16 @@ import io from flask import send_file -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder +from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import login_required from services.tools.api_tools_manage_service import ApiToolManageService @@ -91,12 +93,16 @@ def post(self, provider): args = parser.parse_args() - return BuiltinToolManageService.update_builtin_tool_provider( - user_id, - tenant_id, - provider, - args["credentials"], - ) + with Session(db.engine) as session: + result = BuiltinToolManageService.update_builtin_tool_provider( + session=session, + user_id=user_id, + tenant_id=tenant_id, + provider_name=provider, + credentials=args["credentials"], + ) + session.commit() + return result class ToolBuiltinProviderGetCredentialsApi(Resource): @@ -104,13 +110,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): @login_required @account_initialization_required def get(self, provider): - user_id = current_user.id tenant_id = current_user.current_tenant_id return BuiltinToolManageService.get_builtin_tool_provider_credentials( - user_id, - tenant_id, - provider, + tenant_id=tenant_id, + provider_name=provider, ) @@ -368,6 +372,7 @@ def post(self): description=args["description"], parameters=args["parameters"], privacy_policy=args["privacy_policy"], + labels=args["labels"], ) diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 76d76f6b58fc3c..0f99bf62e3c251 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -1,8 +1,8 @@ import logging from flask import request -from flask_login import current_user -from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Unauthorized import services @@ -82,11 +82,7 @@ def get(self): parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - tenants = ( - db.session.query(Tenant) - .order_by(Tenant.created_at.desc()) - .paginate(page=args["page"], per_page=args["limit"]) - ) + tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate(page=args["page"], per_page=args["limit"]) has_more = False if len(tenants.items) == args["limit"]: @@ -151,6 +147,8 @@ def post(self): raise AccountNotLinkTenantError("Account not link tenant") new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant + if new_tenant is None: + raise ValueError("Tenant not found") return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} @@ -166,7 +164,7 @@ def post(self): parser.add_argument("replace_webapp_logo", type=str, location="json") args = parser.parse_args() - tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404() + tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404() custom_config_dict = { "remove_webapp_brand": args["remove_webapp_brand"], diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index d0df296c240686..111db7ccf2da04 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -3,7 +3,7 @@ from functools import wraps from flask import abort, request -from flask_login import current_user +from flask_login import current_user # type: ignore from configs import dify_config from controllers.console.workspace.error import AccountNotInitializedError @@ -121,8 +121,8 @@ def decorated(*args, **kwargs): utm_info = request.cookies.get("utm_info") if utm_info: - utm_info = json.loads(utm_info) - OperationService.record_utm(current_user.current_tenant_id, utm_info) + utm_info_dict: dict = json.loads(utm_info) + OperationService.record_utm(current_user.current_tenant_id, utm_info_dict) except Exception as e: pass return view(*args, **kwargs) diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index 6b3ac93cdf3d8f..2357288a50ae36 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -1,5 +1,5 @@ from flask import Response, request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import NotFound import services diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index a298701a2f8b11..cfcce8124761f5 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -1,5 +1,5 @@ from flask import Response -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden, NotFound from controllers.files import api diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 99d32af593991f..d7346b13b10a90 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from controllers.console.wraps import setup_required from controllers.inner_api import api diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 51ffe683ff40ad..d4587235f6aef8 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -45,14 +45,14 @@ def decorated(*args, **kwargs): if " " in user_id: user_id = user_id.split(" ")[1] - inner_api_key = request.headers.get("X-Inner-Api-Key") + inner_api_key = request.headers.get("X-Inner-Api-Key", "") data_to_sign = f"DIFY {user_id}" signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1) - signature = b64encode(signature.digest()).decode("utf-8") + signature_base64 = b64encode(signature.digest()).decode("utf-8") - if signature != token: + if signature_base64 != token: return view(*args, **kwargs) kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first() diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index ecff7d07e974d9..8388e2045dd34f 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, marshal_with +from flask_restful import Resource, marshal_with # type: ignore from controllers.common import fields from controllers.common import helpers as controller_helpers diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 5db41636471220..e6bcc0bfd25355 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -1,7 +1,7 @@ import logging from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError import services @@ -83,7 +83,7 @@ def post(self, app_model: App, end_user: EndUser): and app_model.workflow and app_model.workflow.features_dict ): - text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {}) voice = args.get("voice") or text_to_speech.get("voice") else: try: diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 8d8e356c4cb940..1be54b386bfe8c 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -1,6 +1,6 @@ import logging -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index c62fd77d367aa6..334f2c56206794 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,5 +1,6 @@ -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound import services @@ -7,6 +8,7 @@ from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db from fields.conversation_fields import ( conversation_delete_fields, conversation_infinite_scroll_pagination_fields, @@ -39,14 +41,16 @@ def get(self, app_model: App, end_user: EndUser): args = parser.parse_args() try: - return ConversationService.pagination_by_last_id( - app_model=app_model, - user=end_user, - last_id=args["last_id"], - limit=args["limit"], - invoke_from=InvokeFrom.SERVICE_API, - sort_by=args["sort_by"], - ) + with Session(db.engine) as session: + return ConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=end_user, + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.SERVICE_API, + sort_by=args["sort_by"], + ) except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index b0fd8e65ef97df..27b21b9f505633 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import Resource, marshal_with +from flask_restful import Resource, marshal_with # type: ignore import services from controllers.common.errors import FilenameNotExistsError diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index ada40ec9cb26bd..522c7509b9849d 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,7 +1,7 @@ import logging -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services @@ -104,10 +104,11 @@ def post(self, app_model: App, end_user: EndUser, message_id): parser = reqparse.RequestParser() parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") + parser.add_argument("content", type=str, location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 96d1337632826a..c7dd4de3452970 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,7 +1,7 @@ import logging -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import InternalServerError from controllers.service_api import api diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 799fccc228e21d..d6a3beb6b80b9d 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import marshal, reqparse +from flask_restful import marshal, reqparse # type: ignore from werkzeug.exceptions import NotFound import services.dataset_service diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 5c3fc7b241175a..34afe2837f4ca5 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -1,7 +1,7 @@ import json from flask import request -from flask_restful import marshal, reqparse +from flask_restful import marshal, reqparse # type: ignore from sqlalchemy import desc from werkzeug.exceptions import NotFound diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index e68f6b4dc40a36..34904574a8b88d 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import marshal, reqparse # type: ignore from werkzeug.exceptions import NotFound from controllers.service_api import api diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py index d24c4597e210fb..75d9141a6d0a3a 100644 --- a/api/controllers/service_api/index.py +++ b/api/controllers/service_api/index.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from configs import dify_config from controllers.service_api import api diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 2128c4c53f9909..740b92ef8e4faf 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -5,8 +5,8 @@ from typing import Optional from flask import current_app, request -from flask_login import user_logged_in -from flask_restful import Resource +from flask_login import user_logged_in # type: ignore +from flask_restful import Resource # type: ignore from pydantic import BaseModel from werkzeug.exceptions import Forbidden, Unauthorized @@ -49,6 +49,8 @@ def decorated_view(*args, **kwargs): raise Forbidden("The app's API service has been disabled.") tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() + if tenant is None: + raise ValueError("Tenant does not exist.") if tenant.status == TenantStatus.ARCHIVE: raise Forbidden("The workspace's status is archived.") @@ -154,8 +156,8 @@ def decorated(*args, **kwargs): # Login admin if account: account.current_tenant = tenant - current_app.login_manager._update_request_context_with_user(account) - user_logged_in.send(current_app._get_current_object(), user=_get_user()) + current_app.login_manager._update_request_context_with_user(account) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore else: raise Unauthorized("Tenant owner account does not exist.") else: diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index cc8255ccf4e748..20e071c834ad5b 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,4 +1,4 @@ -from flask_restful import marshal_with +from flask_restful import marshal_with # type: ignore from controllers.common import fields from controllers.common import helpers as controller_helpers diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index e8521307ad357a..97d980d07c13a7 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -65,7 +65,7 @@ def post(self, app_model: App, end_user): class TextApi(WebApiResource): def post(self, app_model: App, end_user): - from flask_restful import reqparse + from flask_restful import reqparse # type: ignore try: parser = reqparse.RequestParser() @@ -82,7 +82,7 @@ def post(self, app_model: App, end_user): and app_model.workflow and app_model.workflow.features_dict ): - text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {}) voice = args.get("voice") or text_to_speech.get("voice") else: try: diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 45b890dfc4899d..761771a81a4bb3 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,6 +1,6 @@ import logging -from flask_restful import reqparse +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index c3b0cd4f44b2ac..28feb1ca47effd 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,11 +1,13 @@ -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from controllers.web import api from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value from models.model import AppMode @@ -40,15 +42,17 @@ def get(self, app_model, end_user): pinned = True if args["pinned"] == "true" else False try: - return WebConversationService.pagination_by_last_id( - app_model=app_model, - user=end_user, - last_id=args["last_id"], - limit=args["limit"], - invoke_from=InvokeFrom.WEB_APP, - pinned=pinned, - sort_by=args["sort_by"], - ) + with Session(db.engine) as session: + return WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=end_user, + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.WEB_APP, + pinned=pinned, + sort_by=args["sort_by"], + ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py index 0563ed22382e6b..ce841a8814972d 100644 --- a/api/controllers/web/feature.py +++ b/api/controllers/web/feature.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from controllers.web import api from services.feature_service import FeatureService diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index a282fc63a8b056..1d4474015ab648 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import marshal_with +from flask_restful import marshal_with # type: ignore import services from controllers.common.errors import FilenameNotExistsError @@ -33,7 +33,7 @@ def post(self, app_model, end_user): content=file.read(), mimetype=file.mimetype, user=end_user, - source=source, + source="datasets" if source == "datasets" else None, ) except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 98891f5d00d7e0..0f47e643708570 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -1,7 +1,7 @@ import logging -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services @@ -108,7 +108,7 @@ def post(self, app_model, end_user, message_id): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index a01ffd861230a5..4625c1f43dfbd1 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -1,7 +1,7 @@ import uuid from flask import request -from flask_restful import Resource +from flask_restful import Resource # type: ignore from werkzeug.exceptions import NotFound, Unauthorized from controllers.web import api diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index d6b8eb2855725c..d559ab8e07e736 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,10 +1,11 @@ import urllib.parse import httpx -from flask_restful import marshal_with, reqparse +from flask_restful import marshal_with, reqparse # type: ignore import services from controllers.common import helpers +from controllers.common.errors import RemoteFileUploadError from controllers.web.wraps import WebApiResource from core.file import helpers as file_helpers from core.helper import ssrf_proxy @@ -38,10 +39,14 @@ def post(self, app_model, end_user): # Add app_model and end_user parameters url = args["url"] - resp = ssrf_proxy.head(url=url) - if resp.status_code != httpx.codes.OK: - resp = ssrf_proxy.get(url=url, timeout=3) - resp.raise_for_status() + try: + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) + if resp.status_code != httpx.codes.OK: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") + except httpx.RequestError as e: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") file_info = helpers.guess_file_info_from_response(resp) diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index b0492e6b6f0d31..6a9b8189076c3c 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,5 +1,5 @@ -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import NotFound from controllers.web import api diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 0564b15ea39855..e68dc7aa4afba5 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,4 +1,4 @@ -from flask_restful import fields, marshal_with +from flask_restful import fields, marshal_with # type: ignore from werkzeug.exceptions import Forbidden from configs import dify_config diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 55b0c3e2ab34c5..48d25e720c10c3 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,6 +1,6 @@ import logging -from flask_restful import reqparse +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError from controllers.web import api diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index c327c3df18526c..1b4d263bee4401 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,7 +1,7 @@ from functools import wraps from flask import request -from flask_restful import Resource +from flask_restful import Resource # type: ignore from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from controllers.web.error import WebSSOAuthRequiredError diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index ead293200ea3aa..8d69bdcec2c2ac 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -1,7 +1,6 @@ import json import logging import uuid -from collections.abc import Mapping, Sequence from datetime import UTC, datetime from typing import Optional, Union, cast @@ -53,6 +52,7 @@ class BaseAgentRunner(AppRunner): def __init__( self, + *, tenant_id: str, application_generate_entity: AgentChatAppGenerateEntity, conversation: Conversation, @@ -66,7 +66,7 @@ def __init__( prompt_messages: Optional[list[PromptMessage]] = None, variables_pool: Optional[ToolRuntimeVariablePool] = None, db_variables: Optional[ToolConversationVariables] = None, - model_instance: ModelInstance | None = None, + model_instance: ModelInstance, ) -> None: self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity @@ -117,7 +117,7 @@ def __init__( features = model_schema.features if model_schema and model_schema.features else [] self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features self.files = application_generate_entity.files if ModelFeature.VISION in features else [] - self.query = None + self.query: Optional[str] = "" self._current_thoughts: list[PromptMessage] = [] def _repack_app_generate_entity( @@ -145,7 +145,7 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P message_tool = PromptMessageTool( name=tool.tool_name, - description=tool_entity.description.llm, + description=tool_entity.description.llm if tool_entity.description else "", parameters={ "type": "object", "properties": {}, @@ -167,7 +167,7 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P continue enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: - enum = [option.value for option in parameter.options] + enum = [option.value for option in parameter.options] if parameter.options else [] message_tool.parameters["properties"][parameter.name] = { "type": parameter_type, @@ -187,8 +187,8 @@ def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRe convert dataset retriever tool to prompt message tool """ prompt_tool = PromptMessageTool( - name=tool.identity.name, - description=tool.description.llm, + name=tool.identity.name if tool.identity else "unknown", + description=tool.description.llm if tool.description else "", parameters={ "type": "object", "properties": {}, @@ -210,14 +210,14 @@ def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRe return prompt_tool - def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]: + def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]: """ Init tools """ tool_instances = {} prompt_messages_tools = [] - for tool in self.app_config.agent.tools if self.app_config.agent else []: + for tool in self.app_config.agent.tools or [] if self.app_config.agent else []: try: prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) except Exception: @@ -234,7 +234,8 @@ def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessage # save prompt tool prompt_messages_tools.append(prompt_tool) # save tool entity - tool_instances[dataset_tool.identity.name] = dataset_tool + if dataset_tool.identity is not None: + tool_instances[dataset_tool.identity.name] = dataset_tool return tool_instances, prompt_messages_tools @@ -258,7 +259,7 @@ def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) continue enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: - enum = [option.value for option in parameter.options] + enum = [option.value for option in parameter.options] if parameter.options else [] prompt_tool.parameters["properties"][parameter.name] = { "type": parameter_type, @@ -322,16 +323,21 @@ def save_agent_thought( tool_name: str, tool_input: Union[str, dict], thought: str, - observation: Union[str, dict], - tool_invoke_meta: Union[str, dict], + observation: Union[str, dict, None], + tool_invoke_meta: Union[str, dict, None], answer: str, messages_ids: list[str], - llm_usage: LLMUsage = None, - ) -> MessageAgentThought: + llm_usage: LLMUsage | None = None, + ): """ Save agent thought """ - agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() + queried_thought = ( + db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() + ) + if not queried_thought: + raise ValueError(f"Agent thought {agent_thought.id} not found") + agent_thought = queried_thought if thought is not None: agent_thought.thought = thought @@ -404,7 +410,7 @@ def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variab """ convert tool variables to db variables """ - db_variables = ( + queried_variables = ( db.session.query(ToolConversationVariables) .filter( ToolConversationVariables.conversation_id == self.message.conversation_id, @@ -412,6 +418,11 @@ def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variab .first() ) + if not queried_variables: + return + + db_variables = queried_variables + db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None) db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) db.session.commit() @@ -421,7 +432,7 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P """ Organize agent history """ - result = [] + result: list[PromptMessage] = [] # check if there is a system message in the beginning of the conversation for prompt_message in prompt_messages: if isinstance(prompt_message, SystemPromptMessage): diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index d98ba5a3fad846..e936acb6055cb8 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -1,7 +1,7 @@ import json from abc import ABC, abstractmethod -from collections.abc import Generator -from typing import Optional, Union +from collections.abc import Generator, Mapping +from typing import Any, Optional from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentScratchpadUnit @@ -12,6 +12,7 @@ from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, + PromptMessageTool, ToolPromptMessage, UserPromptMessage, ) @@ -26,18 +27,18 @@ class CotAgentRunner(BaseAgentRunner, ABC): _is_first_iteration = True _ignore_observation_providers = ["wenxin"] - _historic_prompt_messages: list[PromptMessage] = None - _agent_scratchpad: list[AgentScratchpadUnit] = None - _instruction: str = None - _query: str = None - _prompt_messages_tools: list[PromptMessage] = None + _historic_prompt_messages: list[PromptMessage] | None = None + _agent_scratchpad: list[AgentScratchpadUnit] | None = None + _instruction: str = "" # FIXME this must be str for now + _query: str | None = None + _prompt_messages_tools: list[PromptMessageTool] = [] def run( self, message: Message, query: str, - inputs: dict[str, str], - ) -> Union[Generator, LLMResult]: + inputs: Mapping[str, str], + ) -> Generator: """ Run Cot agent application """ @@ -57,19 +58,19 @@ def run( # init instruction inputs = inputs or {} instruction = app_config.prompt_template.simple_prompt_template - self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) + self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs) iteration_step = 1 - max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 + max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1 # convert tools into ModelRuntime Tool format tool_instances, self._prompt_messages_tools = self._init_prompt_tools() function_call_state = True - llm_usage = {"usage": None} + llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} final_answer = "" - def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): + def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): if not final_llm_usage_dict["usage"]: final_llm_usage_dict["usage"] = usage else: @@ -90,7 +91,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # the last iteration, remove all tools self._prompt_messages_tools = [] - message_file_ids = [] + message_file_ids: list[str] = [] agent_thought = self.create_agent_thought( message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids @@ -105,7 +106,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): prompt_messages = self._organize_prompt_messages() self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model - chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( + chunks = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=app_generate_entity.model_conf.parameters, tools=[], @@ -115,11 +116,14 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): callbacks=[], ) + if not isinstance(chunks, Generator): + raise ValueError("Expected streaming response from LLM") + # check llm result if not chunks: raise ValueError("failed to invoke llm") - usage_dict = {} + usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None} react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) scratchpad = AgentScratchpadUnit( agent_response="", @@ -139,25 +143,30 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if isinstance(chunk, AgentScratchpadUnit.Action): action = chunk # detect action - scratchpad.agent_response += json.dumps(chunk.model_dump()) + if scratchpad.agent_response is not None: + scratchpad.agent_response += json.dumps(chunk.model_dump()) scratchpad.action_str = json.dumps(chunk.model_dump()) scratchpad.action = action else: - scratchpad.agent_response += chunk - scratchpad.thought += chunk + if scratchpad.agent_response is not None: + scratchpad.agent_response += chunk + if scratchpad.thought is not None: + scratchpad.thought += chunk yield LLMResultChunk( model=self.model_config.model, prompt_messages=prompt_messages, system_fingerprint="", delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), ) - - scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" - self._agent_scratchpad.append(scratchpad) + if scratchpad.thought is not None: + scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" + if self._agent_scratchpad is not None: + self._agent_scratchpad.append(scratchpad) # get llm usage if "usage" in usage_dict: - increase_usage(llm_usage, usage_dict["usage"]) + if usage_dict["usage"] is not None: + increase_usage(llm_usage, usage_dict["usage"]) else: usage_dict["usage"] = LLMUsage.empty_usage() @@ -166,9 +175,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): tool_name=scratchpad.action.action_name if scratchpad.action else "", tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {}, tool_invoke_meta={}, - thought=scratchpad.thought, + thought=scratchpad.thought or "", observation="", - answer=scratchpad.agent_response, + answer=scratchpad.agent_response or "", messages_ids=[], llm_usage=usage_dict["usage"], ) @@ -209,7 +218,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): agent_thought=agent_thought, tool_name=scratchpad.action.action_name, tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, - thought=scratchpad.thought, + thought=scratchpad.thought or "", observation={scratchpad.action.action_name: tool_invoke_response}, tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()}, answer=scratchpad.agent_response, @@ -247,8 +256,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): answer=final_answer, messages_ids=[], ) - - self.update_db_variables(self.variables_pool, self.db_variables_pool) + if self.variables_pool is not None and self.db_variables_pool is not None: + self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event self.queue_manager.publish( QueueMessageEndEvent( @@ -307,8 +316,9 @@ def _handle_invoke_action( # publish files for message_file_id, save_as in message_files: - if save_as: - self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) + if save_as is not None and self.variables_pool: + # FIXME the save_as type is confusing, it should be a string or not + self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as)) # publish message file self.queue_manager.publish( @@ -325,7 +335,7 @@ def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action: """ return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"]) - def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str: + def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str: """ fill in inputs from external data tools """ @@ -376,11 +386,13 @@ def _organize_historic_prompt_messages( """ result: list[PromptMessage] = [] scratchpads: list[AgentScratchpadUnit] = [] - current_scratchpad: AgentScratchpadUnit = None + current_scratchpad: AgentScratchpadUnit | None = None for message in self.history_prompt_messages: if isinstance(message, AssistantPromptMessage): if not current_scratchpad: + if not isinstance(message.content, str | None): + raise NotImplementedError("expected str type") current_scratchpad = AgentScratchpadUnit( agent_response=message.content, thought=message.content or "I am thinking about how to help you", @@ -399,8 +411,12 @@ def _organize_historic_prompt_messages( except: pass elif isinstance(message, ToolPromptMessage): - if current_scratchpad: + if not current_scratchpad: + continue + if isinstance(message.content, str): current_scratchpad.observation = message.content + else: + raise NotImplementedError("expected str type") elif isinstance(message, UserPromptMessage): if scratchpads: result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index d8d047fe91cdbd..6a96c349b2611c 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -19,7 +19,12 @@ def _organize_system_prompt(self) -> SystemPromptMessage: """ Organize system prompt """ + if not self.app_config.agent: + raise ValueError("Agent configuration is not set") + prompt_entity = self.app_config.agent.prompt + if not prompt_entity: + raise ValueError("Agent prompt configuration is not set") first_prompt = prompt_entity.first_prompt system_prompt = ( @@ -75,6 +80,7 @@ def _organize_prompt_messages(self) -> list[PromptMessage]: assistant_messages = [] else: assistant_message = AssistantPromptMessage(content="") + assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str for unit in agent_scratchpad: if unit.is_final(): assistant_message.content += f"Final Answer: {unit.agent_response}" diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 0563090537e62c..3a4d31e047f5ae 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -2,7 +2,12 @@ from typing import Optional from core.agent.cot_agent_runner import CotAgentRunner -from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.utils.encoders import jsonable_encoder @@ -11,7 +16,11 @@ def _organize_instruction_prompt(self) -> str: """ Organize instruction prompt """ + if self.app_config.agent is None: + raise ValueError("Agent configuration is not set") prompt_entity = self.app_config.agent.prompt + if prompt_entity is None: + raise ValueError("prompt entity is not set") first_prompt = prompt_entity.first_prompt system_prompt = ( @@ -33,7 +42,13 @@ def _organize_historic_prompt(self, current_session_messages: Optional[list[Prom if isinstance(message, UserPromptMessage): historic_prompt += f"Question: {message.content}\n\n" elif isinstance(message, AssistantPromptMessage): - historic_prompt += message.content + "\n\n" + if isinstance(message.content, str): + historic_prompt += message.content + "\n\n" + elif isinstance(message.content, list): + for content in message.content: + if not isinstance(content, TextPromptMessageContent): + continue + historic_prompt += content.data return historic_prompt @@ -50,7 +65,7 @@ def _organize_prompt_messages(self) -> list[PromptMessage]: # organize current assistant messages agent_scratchpad = self._agent_scratchpad assistant_prompt = "" - for unit in agent_scratchpad: + for unit in agent_scratchpad or []: if unit.is_final(): assistant_prompt += f"Final Answer: {unit.agent_response}" else: diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 119a88fc7becbf..2ae87dca3f8cbd 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -78,5 +78,5 @@ class Strategy(Enum): model: str strategy: Strategy prompt: Optional[AgentPromptEntity] = None - tools: list[AgentToolEntity] = None + tools: list[AgentToolEntity] | None = None max_iteration: int = 5 diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index cd546dee124147..b862c96072aaa0 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -40,6 +40,8 @@ def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResul app_generate_entity = self.application_generate_entity app_config = self.app_config + assert app_config is not None, "app_config is required" + assert app_config.agent is not None, "app_config.agent is required" # convert tools into ModelRuntime Tool format tool_instances, prompt_messages_tools = self._init_prompt_tools() @@ -49,7 +51,7 @@ def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResul # continue to run until there is not any tool call function_call_state = True - llm_usage = {"usage": None} + llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()} final_answer = "" # get tracing instance @@ -75,7 +77,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # the last iteration, remove all tools prompt_messages_tools = [] - message_file_ids = [] + message_file_ids: list[str] = [] agent_thought = self.create_agent_thought( message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) @@ -105,7 +107,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): current_llm_usage = None - if self.stream_tool_call: + if self.stream_tool_call and isinstance(chunks, Generator): is_first_chunk = True for chunk in chunks: if is_first_chunk: @@ -116,7 +118,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # check if there is any tool call if self.check_tool_calls(chunk): function_call_state = True - tool_calls.extend(self.extract_tool_calls(chunk)) + tool_calls.extend(self.extract_tool_calls(chunk) or []) tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: tool_call_inputs = json.dumps( @@ -131,19 +133,19 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): for content in chunk.delta.message.content: response += content.data else: - response += chunk.delta.message.content + response += str(chunk.delta.message.content) if chunk.delta.usage: increase_usage(llm_usage, chunk.delta.usage) current_llm_usage = chunk.delta.usage yield chunk - else: - result: LLMResult = chunks + elif not self.stream_tool_call and isinstance(chunks, LLMResult): + result = chunks # check if there is any tool call if self.check_blocking_tool_calls(result): function_call_state = True - tool_calls.extend(self.extract_blocking_tool_calls(result)) + tool_calls.extend(self.extract_blocking_tool_calls(result) or []) tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: tool_call_inputs = json.dumps( @@ -162,7 +164,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): for content in result.message.content: response += content.data else: - response += result.message.content + response += str(result.message.content) if not result.message.content: result.message.content = "" @@ -181,6 +183,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): usage=result.usage, ), ) + else: + raise RuntimeError(f"invalid chunks type: {type(chunks)}") assistant_message = AssistantPromptMessage(content="", tool_calls=[]) if tool_calls: @@ -243,7 +247,10 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # publish files for message_file_id, save_as in message_files: if save_as: - self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) + if self.variables_pool: + self.variables_pool.set_file( + tool_name=tool_call_name, value=message_file_id, name=save_as + ) # publish message file self.queue_manager.publish( @@ -263,7 +270,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if tool_response["tool_response"] is not None: self._current_thoughts.append( ToolPromptMessage( - content=tool_response["tool_response"], + content=str(tool_response["tool_response"]), tool_call_id=tool_call_id, name=tool_call_name, ) @@ -273,9 +280,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # save agent thought self.save_agent_thought( agent_thought=agent_thought, - tool_name=None, - tool_input=None, - thought=None, + tool_name="", + tool_input="", + thought="", tool_invoke_meta={ tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses }, @@ -283,7 +290,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): tool_response["tool_call_name"]: tool_response["tool_response"] for tool_response in tool_responses }, - answer=None, + answer="", messages_ids=message_file_ids, ) self.queue_manager.publish( @@ -296,7 +303,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): iteration_step += 1 - self.update_db_variables(self.variables_pool, self.db_variables_pool) + if self.variables_pool and self.db_variables_pool: + self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event self.queue_manager.publish( QueueMessageEndEvent( @@ -389,9 +397,9 @@ def _init_system_message( if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template: prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) - return prompt_messages + return prompt_messages or [] - def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ Organize user query """ @@ -449,7 +457,7 @@ def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage] def _organize_prompt_messages(self): prompt_template = self.app_config.prompt_template.simple_prompt_template or "" self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) - query_prompt_messages = self._organize_user_query(self.query, []) + query_prompt_messages = self._organize_user_query(self.query or "", []) self.history_prompt_messages = AgentHistoryPromptTransform( model_config=self.model_config, diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 085bac8601b2da..61fa774ea5f390 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -38,7 +38,7 @@ def parse_action(json_str): except: return json_str or "" - def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]: + def extra_json_from_code_block(code_block) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]: code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL) if not code_blocks: return @@ -67,15 +67,15 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, for response in llm_response: if response.delta.usage: usage_dict["usage"] = response.delta.usage - response = response.delta.message.content - if not isinstance(response, str): + response_content = response.delta.message.content + if not isinstance(response_content, str): continue # stream index = 0 - while index < len(response): + while index < len(response_content): steps = 1 - delta = response[index : index + steps] + delta = response_content[index : index + steps] yield_delta = False if delta == "`": diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index b9aae7904f5e7c..646c4badb9f73a 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -66,6 +66,8 @@ def convert(cls, config: dict) -> Optional[DatasetEntity]: dataset_configs = config.get("dataset_configs") else: dataset_configs = {"retrieval_model": "multiple"} + if dataset_configs is None: + return None query_variable = config.get("dataset_query_variable") if dataset_configs["retrieval_model"] == "single": diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 5adcf26f1486e8..6426865115126f 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -94,7 +94,7 @@ def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> config["model"]["completion_params"] ) - return config, ["model"] + return dict(config), ["model"] @classmethod def validate_model_completion_params(cls, cp: dict) -> dict: diff --git a/api/core/app/app_config/features/opening_statement/manager.py b/api/core/app/app_config/features/opening_statement/manager.py index b4dacbc409044a..92b4185abf0183 100644 --- a/api/core/app/app_config/features/opening_statement/manager.py +++ b/api/core/app/app_config/features/opening_statement/manager.py @@ -7,10 +7,10 @@ def convert(cls, config: dict) -> tuple[str, list]: :param config: model config args """ # opening statement - opening_statement = config.get("opening_statement") + opening_statement = config.get("opening_statement", "") # suggested questions - suggested_questions_list = config.get("suggested_questions") + suggested_questions_list = config.get("suggested_questions", []) return opening_statement, suggested_questions_list diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index bd4fd9cd3b2646..a18b40712b7ce6 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -3,7 +3,7 @@ import threading import uuid from collections.abc import Generator, Mapping -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -29,6 +29,7 @@ from models.account import Account from models.model import App, Conversation, EndUser, Message from models.workflow import Workflow +from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -36,6 +37,29 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): _dialogue_count: int + @overload + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[True], + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[False], + ) -> Mapping[str, Any]: ... + + @overload def generate( self, app_model: App, @@ -44,7 +68,17 @@ def generate( args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, - ) -> Mapping[str, Any] | Generator[str, None, None]: + ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ... + + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + ): """ Generate App response. @@ -112,7 +146,7 @@ def generate( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, stream=streaming, @@ -280,6 +314,8 @@ def _generate_worker( # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError("Message not exists") # chatbot app runner = AdvancedChatAppRunner( diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py index 18b115dfe40d3c..a506447671abfb 100644 --- a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -4,14 +4,19 @@ import queue import re import threading +from collections.abc import Iterable +from typing import Optional from core.app.entities.queue_entities import ( + MessageQueueMessage, QueueAgentMessageEvent, QueueLLMChunkEvent, QueueNodeSucceededEvent, QueueTextChunkEvent, + WorkflowQueueMessage, ) -from core.model_manager import ModelManager +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.message_entities import TextPromptMessageContent from core.model_runtime.entities.model_entities import ModelType @@ -21,7 +26,7 @@ def __init__(self, status: str, audio): self.status = status -def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str): +def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str): if not text_content or text_content.isspace(): return return model_instance.invoke_tts( @@ -29,13 +34,19 @@ def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str): ) -def _process_future(future_queue, audio_queue): +def _process_future( + future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None], + audio_queue: queue.Queue[AudioTrunk], +): while True: try: future = future_queue.get() if future is None: break - for audio in future.result(): + invoke_result = future.result() + if not invoke_result: + continue + for audio in invoke_result: audio_base64 = base64.b64encode(bytes(audio)) audio_queue.put(AudioTrunk("responding", audio=audio_base64)) except Exception as e: @@ -49,8 +60,8 @@ def __init__(self, tenant_id: str, voice: str): self.logger = logging.getLogger(__name__) self.tenant_id = tenant_id self.msg_text = "" - self._audio_queue = queue.Queue() - self._msg_queue = queue.Queue() + self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue() + self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() self.match = re.compile(r"[。.!?]") self.model_manager = ModelManager() self.model_instance = self.model_manager.get_default_model_instance( @@ -62,18 +73,16 @@ def __init__(self, tenant_id: str, voice: str): if not voice or voice not in values: self.voice = self.voices[0].get("value") self.MAX_SENTENCE = 2 - self._last_audio_event = None - self._runtime_thread = threading.Thread(target=self._runtime).start() + self._last_audio_event: Optional[AudioTrunk] = None + # FIXME better way to handle this threading.start + threading.Thread(target=self._runtime).start() self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3) - def publish(self, message): - try: - self._msg_queue.put(message) - except Exception as e: - self.logger.warning(e) + def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /): + self._msg_queue.put(message) def _runtime(self): - future_queue = queue.Queue() + future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue() threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start() while True: try: @@ -86,10 +95,21 @@ def _runtime(self): future_queue.put(futures_result) break elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent): - self.msg_text += message.event.chunk.delta.message.content + message_content = message.event.chunk.delta.message.content + if not message_content: + continue + if isinstance(message_content, str): + self.msg_text += message_content + elif isinstance(message_content, list): + for content in message_content: + if not isinstance(content, TextPromptMessageContent): + continue + self.msg_text += content.data elif isinstance(message.event, QueueTextChunkEvent): self.msg_text += message.event.text elif isinstance(message.event, QueueNodeSucceededEvent): + if message.event.outputs is None: + continue self.msg_text += message.event.outputs.get("output", "") self.last_message = message sentence_arr, text_tmp = self._extract_sentence(self.msg_text) @@ -110,16 +130,15 @@ def _runtime(self): break future_queue.put(None) - def check_and_get_audio(self) -> AudioTrunk | None: + def check_and_get_audio(self): try: if self._last_audio_event and self._last_audio_event.status == "finish": if self.executor: self.executor.shutdown(wait=False) - return self.last_message + return self._last_audio_event audio = self._audio_queue.get_nowait() if audio and audio.status == "finish": self.executor.shutdown(wait=False) - self._runtime_thread = None if audio: self._last_audio_event = audio return audio diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index cf0c9d7593429a..6339d798984800 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -109,18 +109,18 @@ def run(self) -> None: ConversationVariable.conversation_id == self.conversation.id, ) with Session(db.engine) as session: - conversation_variables = session.scalars(stmt).all() - if not conversation_variables: + db_conversation_variables = session.scalars(stmt).all() + if not db_conversation_variables: # Create conversation variables if they don't exist. - conversation_variables = [ + db_conversation_variables = [ ConversationVariable.from_variable( app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable ) for variable in workflow.conversation_variables ] - session.add_all(conversation_variables) + session.add_all(db_conversation_variables) # Convert database entities to variables. - conversation_variables = [item.to_variable() for item in conversation_variables] + conversation_variables = [item.to_variable() for item in db_conversation_variables] session.commit() diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index cd12690e286103..1073a0f2e4f706 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -2,6 +2,7 @@ import logging import time from collections.abc import Generator, Mapping +from threading import Thread from typing import Any, Optional, Union from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -19,8 +20,10 @@ QueueIterationNextEvent, QueueIterationStartEvent, QueueMessageReplaceEvent, + QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, + QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -31,6 +34,7 @@ QueueStopEvent, QueueTextChunkEvent, QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) @@ -61,6 +65,7 @@ from models.workflow import ( Workflow, WorkflowNodeExecution, + WorkflowRun, WorkflowRunStatus, ) @@ -78,6 +83,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _user: Union[Account, EndUser] _workflow_system_variables: dict[SystemVariableKey, Any] _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] + _conversation_name_generate_thread: Optional[Thread] = None def __init__( self, @@ -127,9 +133,8 @@ def __init__( self._conversation_name_generate_thread = None self._recorded_files: list[Mapping[str, Any]] = [] - self.total_tokens: int = 0 - def process(self): + def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ Process generate task pipeline. :return: @@ -178,7 +183,7 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] else: continue - raise Exception("Queue listening stopped unexpectedly.") + raise ValueError("queue listening stopped unexpectedly.") def _to_stream_response( self, generator: Generator[StreamResponse, None, None] @@ -195,11 +200,11 @@ def _to_stream_response( stream_response=stream_response, ) - def _listen_audio_msg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.check_and_get_audio() - if audio_msg and audio_msg.status != "finish": + audio_msg = publisher.check_and_get_audio() + if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None @@ -220,7 +225,7 @@ def _wrapper_process_stream_response( for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) + audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -260,8 +265,8 @@ def _process_stream_response( :return: """ # init fake graph runtime state - graph_runtime_state = None - workflow_run = None + graph_runtime_state: Optional[GraphRuntimeState] = None + workflow_run: Optional[WorkflowRun] = None for queue_message in self._queue_manager.listen(): event = queue_message.event @@ -289,20 +294,38 @@ def _process_stream_response( yield self._workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) + elif isinstance( + event, + QueueNodeRetryEvent, + ): + if not workflow_run: + raise ValueError("workflow run not initialized.") + workflow_node_execution = self._handle_workflow_node_execution_retried( + workflow_run=workflow_run, event=event + ) + + response = self._workflow_node_retry_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if response: + yield response elif isinstance(event, QueueNodeStartedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) - response = self._workflow_node_start_to_stream_response( + response_start = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if response: - yield response + if response_start: + yield response_start elif isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._handle_workflow_node_execution_success(event) @@ -310,18 +333,18 @@ def _process_stream_response( if event.node_type in [NodeType.ANSWER, NodeType.END]: self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) - response = self._workflow_node_finish_to_stream_response( + response_finish = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if response: - yield response - elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): + if response_finish: + yield response_finish + elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) - response = self._workflow_node_finish_to_stream_response( + response_finish = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -329,54 +352,53 @@ def _process_stream_response( if response: yield response + elif isinstance(event, QueueParallelBranchRunStartedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_parallel_branch_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_parallel_branch_finished_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationStartEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_iteration_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationNextEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_iteration_next_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationCompletedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") - # FIXME for issue #11221 quick fix maybe have a better solution - self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0 yield self._workflow_iteration_completed_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueWorkflowSucceededEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("workflow run not initialized.") workflow_run = self._handle_workflow_run_success( workflow_run=workflow_run, start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens or self.total_tokens, + total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, outputs=event.outputs, conversation_id=self._conversation.id, @@ -387,13 +409,36 @@ def _process_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) + self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) + elif isinstance(event, QueueWorkflowPartialSuccessEvent): + if not workflow_run: + raise ValueError("workflow run not initialized.") + + if not graph_runtime_state: + raise ValueError("graph runtime state not initialized.") + + workflow_run = self._handle_workflow_run_partial_success( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + ) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowFailedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") workflow_run = self._handle_workflow_run_failed( workflow_run=workflow_run, @@ -404,6 +449,7 @@ def _process_stream_response( error=event.error, conversation_id=self._conversation.id, trace_manager=trace_manager, + exceptions_count=event.exceptions_count, ) yield self._workflow_finish_to_stream_response( @@ -471,7 +517,7 @@ def _process_stream_response( # only publish tts message at text chunk streaming if tts_publisher: - tts_publisher.publish(message=queue_message) + tts_publisher.publish(queue_message) self._task_state.answer += delta_text yield self._message_to_stream_response( @@ -482,7 +528,7 @@ def _process_stream_response( yield self._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueueAdvancedChatMessageEndEvent): if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) if output_moderation_answer: @@ -566,7 +612,10 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse: del extras["metadata"]["annotation_reply"] return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras + task_id=self._application_generate_entity.task_id, + id=self._message.id, + files=self._recorded_files, + metadata=extras.get("metadata", {}), ) def _handle_output_moderation_chunk(self, text: str) -> bool: diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 417d23eccfb553..55b6ee510f228c 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -61,7 +61,7 @@ def get_app_config( app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict + config_dict = override_config_dict or {} app_mode = AppMode.value_of(app_model.mode) app_config = AgentChatAppConfig( diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index b659c1855624b5..63e11bdaa27f74 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -2,7 +2,7 @@ import threading import uuid from collections.abc import Generator, Mapping -from typing import Any, Union +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -23,11 +23,45 @@ from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser +from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) class AgentChatAppGenerator(MessageBasedAppGenerator): + @overload + def generate( + self, + *, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[True], + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + *, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[False], + ) -> Mapping[str, Any]: ... + + @overload + def generate( + self, + *, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool, + ) -> Mapping[str, Any] | Generator[str, None, None]: ... + def generate( self, *, @@ -36,7 +70,7 @@ def generate( args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, - ) -> Mapping[str, Any] | Generator[str, None, None]: + ): """ Generate App response. @@ -64,7 +98,7 @@ def generate( # get conversation conversation = None if args.get("conversation_id"): - conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user) + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user) # get app model config app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) @@ -120,7 +154,7 @@ def generate( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, stream=streaming, @@ -147,7 +181,7 @@ def generate( worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "conversation_id": conversation.id, @@ -166,8 +200,8 @@ def generate( user=user, stream=streaming, ) - - return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # FIXME: Type hinting issue here, ignore it for now, will fix it later + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore def _generate_worker( self, @@ -191,6 +225,8 @@ def _generate_worker( # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError("Message not exists") # chatbot app runner = AgentChatAppRunner() diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 45b1bf00934d35..ac71f02b6de03d 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -173,6 +173,8 @@ def run( return agent_entity = app_config.agent + if not agent_entity: + raise ValueError("Agent entity not found") # load tool variables tool_conversation_variables = self._load_tool_variables( @@ -200,14 +202,21 @@ def run( # change function call strategy based on LLM model llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + if not model_schema or not model_schema.features: + raise ValueError("Model schema not found") if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING - conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() - message = db.session.query(Message).filter(Message.id == message.id).first() + conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() + if conversation_result is None: + raise ValueError("Conversation not found") + message_result = db.session.query(Message).filter(Message.id == message.id).first() + if message_result is None: + raise ValueError("Message not found") db.session.close() + runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner] # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: # check LLM mode @@ -225,12 +234,12 @@ def run( runner = runner_cls( tenant_id=app_config.tenant_id, application_generate_entity=application_generate_entity, - conversation=conversation, + conversation=conversation_result, app_config=app_config, model_config=application_generate_entity.model_conf, config=agent_entity, queue_manager=queue_manager, - message=message, + message=message_result, user_id=application_generate_entity.user_id, memory=memory, prompt_messages=prompt_message, @@ -257,7 +266,7 @@ def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: st """ load tool variables from database """ - tool_variables: ToolConversationVariables = ( + tool_variables: ToolConversationVariables | None = ( db.session.query(ToolConversationVariables) .filter( ToolConversationVariables.conversation_id == conversation_id, diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 629c309c065458..ce331d904cc826 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = ChatbotAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -37,7 +37,7 @@ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingRes return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response @@ -51,8 +51,9 @@ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingR return response @classmethod - def convert_stream_full_response( - cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + def convert_stream_full_response( # type: ignore[override] + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], ) -> Generator[str, None, None]: """ Convert stream full response. @@ -82,8 +83,9 @@ def convert_stream_full_response( yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response( - cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + def convert_stream_simple_response( # type: ignore[override] + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], ) -> Generator[str, None, None]: """ Convert stream simple response. diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 210609b504b9c7..be4027132ba903 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -82,7 +82,7 @@ def _get_simple_metadata(cls, metadata: dict[str, Any]): for resource in metadata["retriever_resources"]: updated_resources.append( { - "segment_id": resource["segment_id"], + "segment_id": resource.get("segment_id", ""), "position": resource["position"], "document_name": resource["document_name"], "score": resource["score"], diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 4c4d282e99b6ae..1842fc43033ab8 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -1,7 +1,6 @@ import queue import time from abc import abstractmethod -from collections.abc import Generator from enum import Enum from typing import Any @@ -11,9 +10,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, + MessageQueueMessage, QueueErrorEvent, QueuePingEvent, QueueStopEvent, + WorkflowQueueMessage, ) from extensions.ext_redis import redis_client @@ -37,11 +38,11 @@ def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None: AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" ) - q = queue.Queue() + q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() self._q = q - def listen(self) -> Generator: + def listen(self): """ Listen to queue :return: @@ -49,7 +50,7 @@ def listen(self) -> Generator: # wait for APP_MAX_EXECUTION_TIME seconds to stop listen listen_timeout = dify_config.APP_MAX_EXECUTION_TIME start_time = time.time() - last_ping_time = 0 + last_ping_time: int | float = 0 while True: try: message = self._q.get(timeout=1) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 609fd03f229da8..07a248d77aee86 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,5 +1,5 @@ import time -from collections.abc import Generator, Mapping +from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity @@ -36,8 +36,8 @@ def get_pre_calculate_rest_tokens( app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["File"], + inputs: Mapping[str, str], + files: Sequence["File"], query: Optional[str] = None, ) -> int: """ @@ -64,7 +64,7 @@ def get_pre_calculate_rest_tokens( ): max_tokens = ( model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 if model_context_tokens is None: @@ -85,7 +85,7 @@ def get_pre_calculate_rest_tokens( prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) - rest_tokens = model_context_tokens - max_tokens - prompt_tokens + rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens if rest_tokens < 0: raise InvokeBadRequestError( "Query or prefix prompt is too long, you can reduce the prefix prompt, " @@ -111,7 +111,7 @@ def recalc_llm_max_tokens( ): max_tokens = ( model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 if model_context_tokens is None: @@ -136,8 +136,8 @@ def organize_prompt_messages( app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["File"], + inputs: Mapping[str, str], + files: Sequence["File"], query: Optional[str] = None, context: Optional[str] = None, memory: Optional[TokenBufferMemory] = None, @@ -156,6 +156,7 @@ def organize_prompt_messages( """ # get prompt without memory and context if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + prompt_transform: Union[SimplePromptTransform, AdvancedPromptTransform] prompt_transform = SimplePromptTransform() prompt_messages, stop = prompt_transform.get_prompt( app_mode=AppMode.value_of(app_record.mode), @@ -171,8 +172,11 @@ def organize_prompt_messages( memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) model_mode = ModelMode.value_of(model_config.mode) + prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] if model_mode == ModelMode.COMPLETION: advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template + if not advanced_completion_prompt_template: + raise InvokeBadRequestError("Advanced completion prompt template is required.") prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt) if advanced_completion_prompt_template.role_prefix: @@ -181,6 +185,8 @@ def organize_prompt_messages( assistant=advanced_completion_prompt_template.role_prefix.assistant, ) else: + if not prompt_template_entity.advanced_chat_prompt_template: + raise InvokeBadRequestError("Advanced chat prompt template is required.") prompt_template = [] for message in prompt_template_entity.advanced_chat_prompt_template.messages: prompt_template.append(ChatModelMessage(text=message.text, role=message.role)) @@ -246,7 +252,7 @@ def direct_output( def _handle_invoke_result( self, - invoke_result: Union[LLMResult, Generator], + invoke_result: Union[LLMResult, Generator[Any, None, None]], queue_manager: AppQueueManager, stream: bool, agent: bool = False, @@ -259,10 +265,12 @@ def _handle_invoke_result( :param agent: agent :return: """ - if not stream: + if not stream and isinstance(invoke_result, LLMResult): self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) - else: + elif stream and isinstance(invoke_result, Generator): self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) + else: + raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") def _handle_invoke_result_direct( self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool @@ -291,8 +299,8 @@ def _handle_invoke_result_stream( :param agent: agent :return: """ - model = None - prompt_messages = [] + model: str = "" + prompt_messages: list[PromptMessage] = [] text = "" usage = None for result in invoke_result: @@ -328,13 +336,14 @@ def _handle_invoke_result_stream( def moderation_for_inputs( self, + *, app_id: str, tenant_id: str, app_generate_entity: AppGenerateEntity, inputs: Mapping[str, Any], - query: str, + query: str | None = None, message_id: str, - ) -> tuple[bool, dict, str]: + ) -> tuple[bool, Mapping[str, Any], str]: """ Process sensitive_word_avoidance. :param app_id: app id @@ -350,7 +359,7 @@ def moderation_for_inputs( app_id=app_id, tenant_id=tenant_id, app_config=app_generate_entity.app_config, - inputs=inputs, + inputs=dict(inputs), query=query or "", message_id=message_id, trace_manager=app_generate_entity.trace_manager, @@ -390,9 +399,9 @@ def fill_in_inputs_from_external_data_tools( tenant_id: str, app_id: str, external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, + inputs: Mapping[str, Any], query: str, - ) -> dict: + ) -> Mapping[str, Any]: """ Fill in variable inputs from external data tools if exists. diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 6a9e1623881e17..6ed71fcd843083 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -1,7 +1,7 @@ import logging import threading import uuid -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload from flask import Flask, current_app @@ -24,6 +24,7 @@ from factories import file_factory from models.account import Account from models.model import App, EndUser +from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -34,9 +35,9 @@ def generate( self, app_model: App, user: Union[Account, EndUser], - args: Any, + args: Mapping[str, Any], invoke_from: InvokeFrom, - stream: Literal[True] = True, + streaming: Literal[True], ) -> Generator[str, None, None]: ... @overload @@ -44,19 +45,29 @@ def generate( self, app_model: App, user: Union[Account, EndUser], - args: Any, + args: Mapping[str, Any], invoke_from: InvokeFrom, - stream: Literal[False] = False, - ) -> dict: ... + streaming: Literal[False], + ) -> Mapping[str, Any]: ... + + @overload + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool, + ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ... def generate( self, app_model: App, user: Union[Account, EndUser], - args: Any, + args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, - ) -> Union[dict, Generator[str, None, None]]: + ): """ Generate App response. @@ -81,7 +92,7 @@ def generate( # get conversation conversation = None if args.get("conversation_id"): - conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user) + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user) # get app model config app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) @@ -94,7 +105,7 @@ def generate( # validate config override_model_config_dict = ChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=args.get("model_config") + tenant_id=app_model.tenant_id, config=args.get("model_config", {}) ) # always enable retriever resource in debugger mode @@ -136,7 +147,7 @@ def generate( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, invoke_from=invoke_from, @@ -162,7 +173,7 @@ def generate( worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "conversation_id": conversation.id, @@ -206,6 +217,8 @@ def _generate_worker( # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError("Message not exists") # chatbot app runner = ChatAppRunner() diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 0fa7af0a7fa36d..9024c3a98273d1 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = ChatbotAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -37,7 +37,7 @@ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingRes return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response @@ -52,7 +52,8 @@ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingR @classmethod def convert_stream_full_response( - cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream full response. @@ -83,7 +84,8 @@ def convert_stream_full_response( @classmethod def convert_stream_simple_response( - cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream simple response. diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index 1193c4b7a43632..02e5d475684cdc 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -42,7 +42,7 @@ def get_app_config( app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict + config_dict = override_config_dict or {} app_mode = AppMode.value_of(app_model.mode) app_config = CompletionAppConfig( diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 324e837a1c8f29..17d0d52497ceee 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -1,7 +1,7 @@ import logging import threading import uuid -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload from flask import Flask, current_app @@ -34,9 +34,9 @@ def generate( self, app_model: App, user: Union[Account, EndUser], - args: dict, + args: Mapping[str, Any], invoke_from: InvokeFrom, - stream: Literal[True] = True, + streaming: Literal[True], ) -> Generator[str, None, None]: ... @overload @@ -44,14 +44,29 @@ def generate( self, app_model: App, user: Union[Account, EndUser], - args: dict, + args: Mapping[str, Any], invoke_from: InvokeFrom, - stream: Literal[False] = False, - ) -> dict: ... + streaming: Literal[False], + ) -> Mapping[str, Any]: ... + + @overload + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool, + ) -> Mapping[str, Any] | Generator[str, None, None]: ... def generate( - self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, streaming: bool = True - ) -> Union[dict, Generator[str, None, None]]: + self, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + ): """ Generate App response. @@ -68,8 +83,6 @@ def generate( query = query.replace("\x00", "") inputs = args["inputs"] - extras = {} - # get conversation conversation = None @@ -84,7 +97,7 @@ def generate( # validate config override_model_config_dict = CompletionAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=args.get("model_config") + tenant_id=app_model.tenant_id, config=args.get("model_config", {}) ) # parse files @@ -117,11 +130,11 @@ def generate( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), user_id=user.id, stream=streaming, invoke_from=invoke_from, - extras=extras, + extras={}, trace_manager=trace_manager, ) @@ -142,7 +155,7 @@ def generate( worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "message_id": message.id, @@ -182,6 +195,8 @@ def _generate_worker( try: # get message message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError() # chatbot app runner = CompletionAppRunner() @@ -216,7 +231,7 @@ def generate_more_like_this( user: Union[Account, EndUser], invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[str, None, None]]: + ) -> Union[Mapping[str, Any], Generator[str, None, None]]: """ Generate App response. @@ -278,7 +293,7 @@ def generate_more_like_this( model_conf=ModelConfigConverter.convert(app_config), inputs=message.inputs, query=message.query, - files=file_objs, + files=list(file_objs), user_id=user.id, stream=stream, invoke_from=invoke_from, @@ -302,7 +317,7 @@ def generate_more_like_this( worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "message_id": message.id, diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 908d74ff539a5a..41278b75b42bf4 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -76,7 +76,7 @@ def run( tenant_id=app_config.tenant_id, app_generate_entity=application_generate_entity, inputs=inputs, - query=query, + query=query or "", message_id=message.id, ) except ModerationError as e: @@ -122,7 +122,7 @@ def run( tenant_id=app_record.tenant_id, model_config=application_generate_entity.model_conf, config=dataset_config, - query=query, + query=query or "", invoke_from=application_generate_entity.invoke_from, show_retrieve_source=app_config.additional_features.show_retrieve_source, hit_callback=hit_callback, diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index 697f0273a5673e..73f38c3d0bcb96 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = CompletionAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -36,7 +36,7 @@ def convert_blocking_full_response(cls, blocking_response: CompletionAppBlocking return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response @@ -51,7 +51,8 @@ def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlocki @classmethod def convert_stream_full_response( - cls, stream_response: Generator[CompletionAppStreamResponse, None, None] + cls, + stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream full response. @@ -81,7 +82,8 @@ def convert_stream_full_response( @classmethod def convert_stream_simple_response( - cls, stream_response: Generator[CompletionAppStreamResponse, None, None] + cls, + stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream simple response. diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 95ae798ec1ac74..c2e35faf89ba15 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -2,11 +2,11 @@ import logging from collections.abc import Generator from datetime import UTC, datetime -from typing import Optional, Union +from typing import Optional, Union, cast from sqlalchemy import and_ -from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError from core.app.entities.app_invoke_entities import ( @@ -42,7 +42,7 @@ def _handle_response( ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity, + AgentChatAppGenerateEntity, ], queue_manager: AppQueueManager, conversation: Conversation, @@ -144,7 +144,7 @@ def _init_generate_records( :conversation conversation :return: """ - app_config = application_generate_entity.app_config + app_config: EasyUIBasedAppConfig = cast(EasyUIBasedAppConfig, application_generate_entity.app_config) # get from source end_user_id = None @@ -267,7 +267,7 @@ def _get_conversation_introduction(self, application_generate_entity: AppGenerat except KeyError: pass - return introduction + return introduction or "" def _get_conversation(self, conversation_id: str): """ @@ -282,7 +282,7 @@ def _get_conversation(self, conversation_id: str): return conversation - def _get_message(self, message_id: str) -> Message: + def _get_message(self, message_id: str) -> Optional[Message]: """ Get message by message id :param message_id: message id diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 7acf05326efe3f..1d5f21b9e0cc07 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -3,7 +3,7 @@ import threading import uuid from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -30,6 +30,35 @@ class WorkflowAppGenerator(BaseAppGenerator): + @overload + def generate( + self, + *, + app_model: App, + workflow: Workflow, + user: Account | EndUser, + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[True], + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + *, + app_model: App, + workflow: Workflow, + user: Account | EndUser, + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[False], + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, + ) -> Mapping[str, Any]: ... + + @overload def generate( self, *, @@ -41,7 +70,20 @@ def generate( streaming: bool = True, call_depth: int = 0, workflow_thread_pool_id: Optional[str] = None, - ) -> Mapping[str, Any] | Generator[str, None, None]: + ) -> Mapping[str, Any] | Generator[str, None, None]: ... + + def generate( + self, + *, + app_model: App, + workflow: Workflow, + user: Account | EndUser, + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, + ): files: Sequence[Mapping[str, Any]] = args.get("files") or [] # parse files @@ -74,7 +116,7 @@ def generate( inputs=self._prepare_user_inputs( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), - files=system_files, + files=list(system_files), user_id=user.id, stream=streaming, invoke_from=invoke_from, diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index 76371f800ba1e5..349b8eb51b1546 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -6,6 +6,7 @@ QueueMessageEndEvent, QueueStopEvent, QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, QueueWorkflowSucceededEvent, WorkflowQueueMessage, ) @@ -34,7 +35,8 @@ def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: | QueueErrorEvent | QueueMessageEndEvent | QueueWorkflowSucceededEvent - | QueueWorkflowFailedEvent, + | QueueWorkflowFailedEvent + | QueueWorkflowPartialSuccessEvent, ): self.stop_listen() diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 08d00ee1805aa2..5cdac6ad28fdaa 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -17,16 +17,16 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = WorkflowAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response :return: """ - return blocking_response.to_dict() + return dict(blocking_response.to_dict()) @classmethod - def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response @@ -36,7 +36,8 @@ def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlocking @classmethod def convert_stream_full_response( - cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] + cls, + stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream full response. @@ -65,7 +66,8 @@ def convert_stream_full_response( @classmethod def convert_stream_simple_response( - cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] + cls, + stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream simple response. diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 9966a1a9d13cf5..c47b38f5600f4d 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -15,8 +15,10 @@ QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, + QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, + QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -26,6 +28,7 @@ QueueStopEvent, QueueTextChunkEvent, QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) @@ -106,7 +109,6 @@ def __init__( self._task_state = WorkflowTaskState() self._wip_workflow_node_executions = {} - self.total_tokens: int = 0 def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ @@ -153,7 +155,7 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] else: continue - raise Exception("Queue listening stopped unexpectedly.") + raise ValueError("queue listening stopped unexpectedly.") def _to_stream_response( self, generator: Generator[StreamResponse, None, None] @@ -169,11 +171,11 @@ def _to_stream_response( yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) - def _listen_audio_msg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.check_and_get_audio() - if audio_msg and audio_msg.status != "finish": + audio_msg = publisher.check_and_get_audio() + if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None @@ -194,7 +196,7 @@ def _wrapper_process_stream_response( for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) + audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -216,7 +218,7 @@ def _wrapper_process_stream_response( break else: yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) - except Exception as e: + except Exception: logger.exception(f"Fails to get audio trunk, task_id: {task_id}") break if tts_publisher: @@ -252,90 +254,106 @@ def _process_stream_response( yield self._workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) + elif isinstance( + event, + QueueNodeRetryEvent, + ): + if not workflow_run: + raise ValueError("workflow run not initialized.") + workflow_node_execution = self._handle_workflow_node_execution_retried( + workflow_run=workflow_run, event=event + ) + + response = self._workflow_node_retry_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if response: + yield response elif isinstance(event, QueueNodeStartedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) - response = self._workflow_node_start_to_stream_response( + node_start_response = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if response: - yield response + if node_start_response: + yield node_start_response elif isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._handle_workflow_node_execution_success(event) - response = self._workflow_node_finish_to_stream_response( + node_success_response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if response: - yield response - elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): + if node_success_response: + yield node_success_response + elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) - response = self._workflow_node_finish_to_stream_response( + node_failed_response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) + if node_failed_response: + yield node_failed_response - if response: - yield response elif isinstance(event, QueueParallelBranchRunStartedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_parallel_branch_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_parallel_branch_finished_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationStartEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_iteration_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationNextEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_iteration_next_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationCompletedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") - # FIXME for issue #11221 quick fix maybe have a better solution - self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0 yield self._workflow_iteration_completed_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueWorkflowSucceededEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") workflow_run = self._handle_workflow_run_success( workflow_run=workflow_run, start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens or self.total_tokens, + total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, outputs=event.outputs, conversation_id=None, @@ -348,13 +366,36 @@ def _process_stream_response( yield self._workflow_finish_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) - elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): + elif isinstance(event, QueueWorkflowPartialSuccessEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") + workflow_run = self._handle_workflow_run_partial_success( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + ) + + # save workflow app log + self._save_workflow_app_log(workflow_run) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): + if not workflow_run: + raise ValueError("workflow run not initialized.") + + if not graph_runtime_state: + raise ValueError("graph runtime state not initialized.") workflow_run = self._handle_workflow_run_failed( workflow_run=workflow_run, start_at=graph_runtime_state.start_at, @@ -366,6 +407,7 @@ def _process_stream_response( error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), conversation_id=None, trace_manager=trace_manager, + exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, ) # save workflow app log @@ -381,7 +423,7 @@ def _process_stream_response( # only publish tts message at text chunk streaming if tts_publisher: - tts_publisher.publish(message=queue_message) + tts_publisher.publish(queue_message) self._task_state.answer += delta_text yield self._text_chunk_to_stream_response( diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 3d46b8bab03e17..63f516bcc60682 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -8,8 +8,10 @@ QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, + QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, + QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -18,13 +20,16 @@ QueueRetrieverResourcesEvent, QueueTextChunkEvent, QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, GraphRunFailedEvent, + GraphRunPartialSucceededEvent, GraphRunStartedEvent, GraphRunSucceededEvent, IterationRunFailedEvent, @@ -32,8 +37,10 @@ IterationRunStartedEvent, IterationRunSucceededEvent, NodeInIterationFailedEvent, + NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunRetrieverResourceEvent, + NodeRunRetryEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, @@ -176,8 +183,46 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) ) elif isinstance(event, GraphRunSucceededEvent): self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs)) + elif isinstance(event, GraphRunPartialSucceededEvent): + self._publish_event( + QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count) + ) elif isinstance(event, GraphRunFailedEvent): - self._publish_event(QueueWorkflowFailedEvent(error=event.error)) + self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count)) + elif isinstance(event, NodeRunRetryEvent): + node_run_result = event.route_node_state.node_run_result + inputs: Mapping[str, Any] | None = {} + process_data: Mapping[str, Any] | None = {} + outputs: Mapping[str, Any] | None = {} + execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {} + if node_run_result: + inputs = node_run_result.inputs + process_data = node_run_result.process_data + outputs = node_run_result.outputs + execution_metadata = node_run_result.metadata + self._publish_event( + QueueNodeRetryEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.start_at, + node_run_index=event.route_node_state.index, + predecessor_node_id=event.predecessor_node_id, + in_iteration_id=event.in_iteration_id, + parallel_mode_run_id=event.parallel_mode_run_id, + inputs=inputs, + process_data=process_data, + outputs=outputs, + error=event.error, + execution_metadata=execution_metadata, + retry_index=event.retry_index, + ) + ) elif isinstance(event, NodeRunStartedEvent): self._publish_event( QueueNodeStartedEvent( @@ -197,8 +242,38 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) ) ) elif isinstance(event, NodeRunSucceededEvent): + node_run_result = event.route_node_state.node_run_result + if node_run_result: + inputs = node_run_result.inputs + process_data = node_run_result.process_data + outputs = node_run_result.outputs + execution_metadata = node_run_result.metadata + else: + inputs = {} + process_data = {} + outputs = {} + execution_metadata = {} self._publish_event( QueueNodeSucceededEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=inputs, + process_data=process_data, + outputs=outputs, + execution_metadata=execution_metadata, + in_iteration_id=event.in_iteration_id, + ) + ) + elif isinstance(event, NodeRunFailedEvent): + self._publish_event( + QueueNodeFailedEvent( node_execution_id=event.id, node_id=event.node_id, node_type=event.node_type, @@ -214,18 +289,21 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) process_data=event.route_node_state.node_run_result.process_data if event.route_node_state.node_run_result else {}, - outputs=event.route_node_state.node_run_result.outputs + outputs=event.route_node_state.node_run_result.outputs or {} if event.route_node_state.node_run_result else {}, + error=event.route_node_state.node_run_result.error + if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error + else "Unknown error", execution_metadata=event.route_node_state.node_run_result.metadata if event.route_node_state.node_run_result else {}, in_iteration_id=event.in_iteration_id, ) ) - elif isinstance(event, NodeRunFailedEvent): + elif isinstance(event, NodeRunExceptionEvent): self._publish_event( - QueueNodeFailedEvent( + QueueNodeExceptionEvent( node_execution_id=event.id, node_id=event.node_id, node_type=event.node_type, @@ -271,7 +349,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) process_data=event.route_node_state.node_run_result.process_data if event.route_node_state.node_run_result else {}, - outputs=event.route_node_state.node_run_result.outputs + outputs=event.route_node_state.node_run_result.outputs or {} if event.route_node_state.node_run_result else {}, execution_metadata=event.route_node_state.node_run_result.metadata diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 31c3a996e19286..16dc91bb777a9b 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from constants import UUID_NIL -from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig +from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle from core.file import File, FileUploadConfig from core.model_runtime.entities.model_entities import AIModelEntity @@ -79,7 +79,7 @@ class AppGenerateEntity(BaseModel): task_id: str # app config - app_config: AppConfig + app_config: Any file_upload_config: Optional[FileUploadConfig] = None inputs: Mapping[str, Any] diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 15543638fc0020..a93e533ff45d26 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,8 +1,9 @@ +from collections.abc import Mapping from datetime import datetime from enum import Enum, StrEnum from typing import Any, Optional -from pydantic import BaseModel, field_validator +from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.workflow.entities.node_entities import NodeRunMetadataKey @@ -25,12 +26,14 @@ class QueueEvent(StrEnum): WORKFLOW_STARTED = "workflow_started" WORKFLOW_SUCCEEDED = "workflow_succeeded" WORKFLOW_FAILED = "workflow_failed" + WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded" ITERATION_START = "iteration_start" ITERATION_NEXT = "iteration_next" ITERATION_COMPLETED = "iteration_completed" NODE_STARTED = "node_started" NODE_SUCCEEDED = "node_succeeded" NODE_FAILED = "node_failed" + NODE_EXCEPTION = "node_exception" RETRIEVER_RESOURCES = "retriever_resources" ANNOTATION_REPLY = "annotation_reply" AGENT_THOUGHT = "agent_thought" @@ -41,6 +44,7 @@ class QueueEvent(StrEnum): ERROR = "error" PING = "ping" STOP = "stop" + RETRY = "retry" class AppQueueEvent(BaseModel): @@ -82,9 +86,9 @@ class QueueIterationStartEvent(AppQueueEvent): start_at: datetime node_run_index: int - inputs: Optional[dict[str, Any]] = None + inputs: Optional[Mapping[str, Any]] = None predecessor_node_id: Optional[str] = None - metadata: Optional[dict[str, Any]] = None + metadata: Optional[Mapping[str, Any]] = None class QueueIterationNextEvent(AppQueueEvent): @@ -113,18 +117,6 @@ class QueueIterationNextEvent(AppQueueEvent): output: Optional[Any] = None # output for the current iteration duration: Optional[float] = None - @field_validator("output", mode="before") - @classmethod - def set_output(cls, v): - """ - Set output - """ - if v is None: - return None - if isinstance(v, int | float | str | bool | dict | list): - return v - raise ValueError("output must be a valid type") - class QueueIterationCompletedEvent(AppQueueEvent): """ @@ -148,9 +140,9 @@ class QueueIterationCompletedEvent(AppQueueEvent): start_at: datetime node_run_index: int - inputs: Optional[dict[str, Any]] = None - outputs: Optional[dict[str, Any]] = None - metadata: Optional[dict[str, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + metadata: Optional[Mapping[str, Any]] = None steps: int = 0 error: Optional[str] = None @@ -249,6 +241,17 @@ class QueueWorkflowFailedEvent(AppQueueEvent): event: QueueEvent = QueueEvent.WORKFLOW_FAILED error: str + exceptions_count: int + + +class QueueWorkflowPartialSuccessEvent(AppQueueEvent): + """ + QueueWorkflowFailedEvent entity + """ + + event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED + exceptions_count: int + outputs: Optional[dict[str, Any]] = None class QueueNodeStartedEvent(AppQueueEvent): @@ -302,16 +305,30 @@ class QueueNodeSucceededEvent(AppQueueEvent): """iteration id if node is in iteration""" start_at: datetime - inputs: Optional[dict[str, Any]] = None - process_data: Optional[dict[str, Any]] = None - outputs: Optional[dict[str, Any]] = None - execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None error: Optional[str] = None """single iteration duration map""" iteration_duration_map: Optional[dict[str, float]] = None +class QueueNodeRetryEvent(QueueNodeStartedEvent): + """QueueNodeRetryEvent entity""" + + event: QueueEvent = QueueEvent.RETRY + + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None + + error: str + retry_index: int # retry index + + class QueueNodeInIterationFailedEvent(AppQueueEvent): """ QueueNodeInIterationFailedEvent entity @@ -335,10 +352,41 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent): """iteration id if node is in iteration""" start_at: datetime - inputs: Optional[dict[str, Any]] = None - process_data: Optional[dict[str, Any]] = None - outputs: Optional[dict[str, Any]] = None - execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None + + error: str + + +class QueueNodeExceptionEvent(AppQueueEvent): + """ + QueueNodeExceptionEvent entity + """ + + event: QueueEvent = QueueEvent.NODE_EXCEPTION + + node_execution_id: str + node_id: str + node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None error: str @@ -366,10 +414,10 @@ class QueueNodeFailedEvent(AppQueueEvent): """iteration id if node is in iteration""" start_at: datetime - inputs: Optional[dict[str, Any]] = None - process_data: Optional[dict[str, Any]] = None - outputs: Optional[dict[str, Any]] = None - execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None error: str diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 03cc6941a84623..5e845eba2da1d3 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -52,6 +52,7 @@ class StreamEvent(Enum): WORKFLOW_FINISHED = "workflow_finished" NODE_STARTED = "node_started" NODE_FINISHED = "node_finished" + NODE_RETRY = "node_retry" PARALLEL_BRANCH_STARTED = "parallel_branch_started" PARALLEL_BRANCH_FINISHED = "parallel_branch_finished" ITERATION_STARTED = "iteration_started" @@ -69,7 +70,7 @@ class StreamResponse(BaseModel): event: StreamEvent task_id: str - def to_dict(self) -> dict: + def to_dict(self): return jsonable_encoder(self) @@ -213,6 +214,7 @@ class Data(BaseModel): created_by: Optional[dict] = None created_at: int finished_at: int + exceptions_count: Optional[int] = 0 files: Optional[Sequence[Mapping[str, Any]]] = [] event: StreamEvent = StreamEvent.WORKFLOW_FINISHED @@ -341,6 +343,75 @@ def to_ignore_detail_dict(self): } +class NodeRetryStreamResponse(StreamResponse): + """ + NodeFinishStreamResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + + id: str + node_id: str + node_type: str + title: str + index: int + predecessor_node_id: Optional[str] = None + inputs: Optional[dict] = None + process_data: Optional[dict] = None + outputs: Optional[dict] = None + status: str + error: Optional[str] = None + elapsed_time: float + execution_metadata: Optional[dict] = None + created_at: int + finished_at: int + files: Optional[Sequence[Mapping[str, Any]]] = [] + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None + retry_index: int = 0 + + event: StreamEvent = StreamEvent.NODE_RETRY + workflow_run_id: str + data: Data + + def to_ignore_detail_dict(self): + return { + "event": self.event.value, + "task_id": self.task_id, + "workflow_run_id": self.workflow_run_id, + "data": { + "id": self.data.id, + "node_id": self.data.node_id, + "node_type": self.data.node_type, + "title": self.data.title, + "index": self.data.index, + "predecessor_node_id": self.data.predecessor_node_id, + "inputs": None, + "process_data": None, + "outputs": None, + "status": self.data.status, + "error": None, + "elapsed_time": self.data.elapsed_time, + "execution_metadata": None, + "created_at": self.data.created_at, + "finished_at": self.data.finished_at, + "files": [], + "parallel_id": self.data.parallel_id, + "parallel_start_node_id": self.data.parallel_start_node_id, + "parent_parallel_id": self.data.parent_parallel_id, + "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, + "iteration_id": self.data.iteration_id, + "retry_index": self.data.retry_index, + }, + } + + class ParallelBranchStartStreamResponse(StreamResponse): """ ParallelBranchStartStreamResponse entity @@ -403,8 +474,8 @@ class Data(BaseModel): title: str created_at: int extras: dict = {} - metadata: dict = {} - inputs: dict = {} + metadata: Mapping = {} + inputs: Mapping = {} parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None @@ -455,15 +526,15 @@ class Data(BaseModel): node_id: str node_type: str title: str - outputs: Optional[dict] = None + outputs: Optional[Mapping] = None created_at: int extras: Optional[dict] = None - inputs: Optional[dict] = None + inputs: Optional[Mapping] = None status: WorkflowNodeExecutionStatus error: Optional[str] = None elapsed_time: float total_tokens: int - execution_metadata: Optional[dict] = None + execution_metadata: Optional[Mapping] = None finished_at: int steps: int parallel_id: Optional[str] = None @@ -557,7 +628,7 @@ class AppBlockingResponse(BaseModel): task_id: str - def to_dict(self) -> dict: + def to_dict(self): return jsonable_encoder(self) diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 77b6bb554c65ec..83fd3debad4cf1 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -58,7 +58,7 @@ def query( query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]} ) - if documents: + if documents and documents[0].metadata: annotation_id = documents[0].metadata["annotation_id"] score = documents[0].metadata["score"] annotation = AppAnnotationService.get_annotation_by_id(annotation_id) diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 154a49ebda2b88..dcc2b4e55f6ae1 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -17,7 +17,7 @@ class RateLimit: _UNLIMITED_REQUEST_ID = "unlimited_request_id" _REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes - _instance_dict = {} + _instance_dict: dict[str, "RateLimit"] = {} def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: @@ -110,7 +110,7 @@ def __next__(self): raise StopIteration try: return next(self.generator) - except StopIteration: + except Exception: self.close() raise diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 51d610e2cbedc6..03a81353d02625 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -62,6 +62,7 @@ def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = Non """ logger.debug("error: %s", event.error) e = event.error + err: Exception if isinstance(e, InvokeAuthorizationError): err = InvokeAuthorizationError("Incorrect API key provided") @@ -130,6 +131,7 @@ def _init_output_moderation(self) -> Optional[OutputModeration]: rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config), queue_manager=self._queue_manager, ) + return None def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: """ diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 917649f34e769c..b9f8e7ca560ce7 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -2,6 +2,7 @@ import logging import time from collections.abc import Generator +from threading import Thread from typing import Optional, Union, cast from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -103,7 +104,7 @@ def __init__( ) ) - self._conversation_name_generate_thread = None + self._conversation_name_generate_thread: Optional[Thread] = None def process( self, @@ -123,7 +124,7 @@ def process( if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, self._application_generate_entity.query + self._conversation, self._application_generate_entity.query or "" ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) @@ -146,7 +147,7 @@ def _to_blocking_response( extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} if self._task_state.metadata: extras["metadata"] = self._task_state.metadata - + response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] if self._conversation.mode == AppMode.COMPLETION.value: response = CompletionAppBlockingResponse( task_id=self._application_generate_entity.task_id, @@ -154,7 +155,7 @@ def _to_blocking_response( id=self._message.id, mode=self._conversation.mode, message_id=self._message.id, - answer=self._task_state.llm_result.message.content, + answer=cast(str, self._task_state.llm_result.message.content), created_at=int(self._message.created_at.timestamp()), **extras, ), @@ -167,7 +168,7 @@ def _to_blocking_response( mode=self._conversation.mode, conversation_id=self._conversation.id, message_id=self._message.id, - answer=self._task_state.llm_result.message.content, + answer=cast(str, self._task_state.llm_result.message.content), created_at=int(self._message.created_at.timestamp()), **extras, ), @@ -177,7 +178,7 @@ def _to_blocking_response( else: continue - raise Exception("Queue listening stopped unexpectedly.") + raise RuntimeError("queue listening stopped unexpectedly.") def _to_stream_response( self, generator: Generator[StreamResponse, None, None] @@ -201,11 +202,11 @@ def _to_stream_response( stream_response=stream_response, ) - def _listen_audio_msg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str): if publisher is None: return None - audio_msg: AudioTrunk = publisher.check_and_get_audio() - if audio_msg and audio_msg.status != "finish": + audio_msg = publisher.check_and_get_audio() + if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish": # audio_str = audio_msg.audio.decode('utf-8', errors='ignore') return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None @@ -252,7 +253,7 @@ def _wrapper_process_stream_response( yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( - self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None + self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. @@ -269,13 +270,14 @@ def _process_stream_response( break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent): - self._task_state.llm_result = event.llm_result + if event.llm_result: + self._task_state.llm_result = event.llm_result else: self._handle_stop(event) # handle output moderation output_moderation_answer = self._handle_output_moderation_when_task_finished( - self._task_state.llm_result.message.content + cast(str, self._task_state.llm_result.message.content) ) if output_moderation_answer: self._task_state.llm_result.message.content = output_moderation_answer @@ -292,7 +294,9 @@ def _process_stream_response( if annotation: self._task_state.llm_result.message.content = annotation.content elif isinstance(event, QueueAgentThoughtEvent): - yield self._agent_thought_to_stream_response(event) + agent_thought_response = self._agent_thought_to_stream_response(event) + if agent_thought_response is not None: + yield agent_thought_response elif isinstance(event, QueueMessageFileEvent): response = self._message_file_to_stream_response(event) if response: @@ -307,16 +311,18 @@ def _process_stream_response( self._task_state.llm_result.prompt_messages = chunk.prompt_messages # handle output moderation chunk - should_direct_answer = self._handle_output_moderation_chunk(delta_text) + should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text)) if should_direct_answer: continue - self._task_state.llm_result.message.content += delta_text + current_content = cast(str, self._task_state.llm_result.message.content) + current_content += cast(str, delta_text) + self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): - yield self._message_to_stream_response(delta_text, self._message.id) + yield self._message_to_stream_response(cast(str, delta_text), self._message.id) else: - yield self._agent_message_to_stream_response(delta_text, self._message.id) + yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id) elif isinstance(event, QueueMessageReplaceEvent): yield self._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueuePingEvent): @@ -336,8 +342,14 @@ def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> No llm_result = self._task_state.llm_result usage = llm_result.usage - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() - self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() + message = db.session.query(Message).filter(Message.id == self._message.id).first() + if not message: + raise Exception(f"Message {self._message.id} not found") + self._message = message + conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() + if not conversation: + raise Exception(f"Conversation {self._conversation.id} not found") + self._conversation = conversation self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( self._model_config.mode, self._task_state.llm_result.prompt_messages @@ -346,7 +358,7 @@ def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> No self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit self._message.answer = ( - PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) + PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip()) if llm_result.message.content else "" ) @@ -374,6 +386,7 @@ def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> No application_generate_entity=self._application_generate_entity, conversation=self._conversation, is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT} + and hasattr(self._application_generate_entity, "conversation_id") and self._application_generate_entity.conversation_id is None, extras=self._application_generate_entity.extras, ) @@ -420,7 +433,9 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse: extras["metadata"] = self._task_state.metadata return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, id=self._message.id, **extras + task_id=self._application_generate_entity.task_id, + id=self._message.id, + metadata=extras.get("metadata", {}), ) def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: @@ -440,7 +455,7 @@ def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Op :param event: agent thought event :return: """ - agent_thought: MessageAgentThought = ( + agent_thought: Optional[MessageAgentThought] = ( db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first() ) db.session.refresh(agent_thought) diff --git a/api/core/app/task_pipeline/exc.py b/api/core/app/task_pipeline/exc.py new file mode 100644 index 00000000000000..e4b4168d0881e0 --- /dev/null +++ b/api/core/app/task_pipeline/exc.py @@ -0,0 +1,17 @@ +class TaskPipilineError(ValueError): + pass + + +class RecordNotFoundError(TaskPipilineError): + def __init__(self, record_name: str, record_id: str): + super().__init__(f"{record_name} with id {record_id} not found") + + +class WorkflowRunNotFoundError(RecordNotFoundError): + def __init__(self, workflow_run_id: str): + super().__init__("WorkflowRun", workflow_run_id) + + +class WorkflowNodeExecutionNotFoundError(RecordNotFoundError): + def __init__(self, workflow_node_execution_id: str): + super().__init__("WorkflowNodeExecution", workflow_node_execution_id) diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index e818a090ed7d0f..007543f6d0d1f2 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -128,7 +128,7 @@ def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Opti """ message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first() - if message_file: + if message_file and message_file.url is not None: # get tool file id tool_file_id = message_file.url.split("/")[-1] # trim extension diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 57a02f8bc85eac..f581e564f224ce 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -12,8 +12,10 @@ QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, + QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, + QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -25,6 +27,7 @@ IterationNodeNextStreamResponse, IterationNodeStartStreamResponse, NodeFinishStreamResponse, + NodeRetryStreamResponse, NodeStartStreamResponse, ParallelBranchFinishedStreamResponse, ParallelBranchStartStreamResponse, @@ -55,6 +58,8 @@ WorkflowRunStatus, ) +from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError + class WorkflowCycleManage: _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] @@ -88,7 +93,7 @@ def _handle_workflow_run_start(self) -> WorkflowRun: ) # handle special values - inputs = WorkflowEntry.handle_special_values(inputs) + inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) # init workflow run with Session(db.engine, expire_on_commit=False) as session: @@ -164,6 +169,55 @@ def _handle_workflow_run_success( return workflow_run + def _handle_workflow_run_partial_success( + self, + workflow_run: WorkflowRun, + start_at: float, + total_tokens: int, + total_steps: int, + outputs: Mapping[str, Any] | None = None, + exceptions_count: int = 0, + conversation_id: Optional[str] = None, + trace_manager: Optional[TraceQueueManager] = None, + ) -> WorkflowRun: + """ + Workflow run success + :param workflow_run: workflow run + :param start_at: start time + :param total_tokens: total tokens + :param total_steps: total steps + :param outputs: outputs + :param conversation_id: conversation id + :return: + """ + workflow_run = self._refetch_workflow_run(workflow_run.id) + + outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) + + workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value + workflow_run.outputs = json.dumps(outputs or {}) + workflow_run.elapsed_time = time.perf_counter() - start_at + workflow_run.total_tokens = total_tokens + workflow_run.total_steps = total_steps + workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) + workflow_run.exceptions_count = exceptions_count + db.session.commit() + db.session.refresh(workflow_run) + + if trace_manager: + trace_manager.add_trace_task( + TraceTask( + TraceTaskName.WORKFLOW_TRACE, + workflow_run=workflow_run, + conversation_id=conversation_id, + user_id=trace_manager.user_id, + ) + ) + + db.session.close() + + return workflow_run + def _handle_workflow_run_failed( self, workflow_run: WorkflowRun, @@ -174,6 +228,7 @@ def _handle_workflow_run_failed( error: str, conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, + exceptions_count: int = 0, ) -> WorkflowRun: """ Workflow run failed @@ -193,7 +248,7 @@ def _handle_workflow_run_failed( workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) - + workflow_run.exceptions_count = exceptions_count db.session.commit() running_workflow_node_executions = ( @@ -220,9 +275,9 @@ def _handle_workflow_run_failed( db.session.close() - with Session(db.engine, expire_on_commit=False) as session: - session.add(workflow_run) - session.refresh(workflow_run) + # with Session(db.engine, expire_on_commit=False) as session: + # session.add(workflow_run) + # session.refresh(workflow_run) if trace_manager: trace_manager.add_trace_task( @@ -318,7 +373,7 @@ def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent return workflow_node_execution def _handle_workflow_node_execution_failed( - self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent + self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent ) -> WorkflowNodeExecution: """ Workflow node execution failed @@ -337,7 +392,11 @@ def _handle_workflow_node_execution_failed( ) db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( { - WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value, + WorkflowNodeExecution.status: ( + WorkflowNodeExecutionStatus.FAILED.value + if not isinstance(event, QueueNodeExceptionEvent) + else WorkflowNodeExecutionStatus.EXCEPTION.value + ), WorkflowNodeExecution.error: event.error, WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None, @@ -351,8 +410,11 @@ def _handle_workflow_node_execution_failed( db.session.commit() db.session.close() process_data = WorkflowEntry.handle_special_values(event.process_data) - - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.status = ( + WorkflowNodeExecutionStatus.FAILED.value + if not isinstance(event, QueueNodeExceptionEvent) + else WorkflowNodeExecutionStatus.EXCEPTION.value + ) workflow_node_execution.error = event.error workflow_node_execution.inputs = json.dumps(inputs) if inputs else None workflow_node_execution.process_data = json.dumps(process_data) if process_data else None @@ -365,6 +427,59 @@ def _handle_workflow_node_execution_failed( return workflow_node_execution + def _handle_workflow_node_execution_retried( + self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent + ) -> WorkflowNodeExecution: + """ + Workflow node execution failed + :param event: queue node failed event + :return: + """ + created_at = event.start_at + finished_at = datetime.now(UTC).replace(tzinfo=None) + elapsed_time = (finished_at - created_at).total_seconds() + inputs = WorkflowEntry.handle_special_values(event.inputs) + outputs = WorkflowEntry.handle_special_values(event.outputs) + origin_metadata = { + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + } + merged_metadata = ( + {**jsonable_encoder(event.execution_metadata), **origin_metadata} + if event.execution_metadata is not None + else origin_metadata + ) + execution_metadata = json.dumps(merged_metadata) + + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.tenant_id = workflow_run.tenant_id + workflow_node_execution.app_id = workflow_run.app_id + workflow_node_execution.workflow_id = workflow_run.workflow_id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + workflow_node_execution.workflow_run_id = workflow_run.id + workflow_node_execution.predecessor_node_id = event.predecessor_node_id + workflow_node_execution.node_execution_id = event.node_execution_id + workflow_node_execution.node_id = event.node_id + workflow_node_execution.node_type = event.node_type.value + workflow_node_execution.title = event.node_data.title + workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value + workflow_node_execution.created_by_role = workflow_run.created_by_role + workflow_node_execution.created_by = workflow_run.created_by + workflow_node_execution.created_at = created_at + workflow_node_execution.finished_at = finished_at + workflow_node_execution.elapsed_time = elapsed_time + workflow_node_execution.error = event.error + workflow_node_execution.inputs = json.dumps(inputs) if inputs else None + workflow_node_execution.outputs = json.dumps(outputs) if outputs else None + workflow_node_execution.execution_metadata = execution_metadata + workflow_node_execution.index = event.node_run_index + + db.session.add(workflow_node_execution) + db.session.commit() + db.session.refresh(workflow_node_execution) + + return workflow_node_execution + ################################################# # to stream responses # ################################################# @@ -385,7 +500,7 @@ def _workflow_start_to_stream_response( id=workflow_run.id, workflow_id=workflow_run.workflow_id, sequence_number=workflow_run.sequence_number, - inputs=workflow_run.inputs_dict, + inputs=dict(workflow_run.inputs_dict or {}), created_at=int(workflow_run.created_at.timestamp()), ), ) @@ -399,6 +514,12 @@ def _workflow_finish_to_stream_response( :param workflow_run: workflow run :return: """ + # Attach WorkflowRun to an active session so "created_by_role" can be accessed. + workflow_run = db.session.merge(workflow_run) + + # Refresh to ensure any expired attributes are fully loaded + db.session.refresh(workflow_run) + created_by = None if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value: created_by_account = workflow_run.created_by_account @@ -424,7 +545,7 @@ def _workflow_finish_to_stream_response( workflow_id=workflow_run.workflow_id, sequence_number=workflow_run.sequence_number, status=workflow_run.status, - outputs=workflow_run.outputs_dict, + outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None, error=workflow_run.error, elapsed_time=workflow_run.elapsed_time, total_tokens=workflow_run.total_tokens, @@ -432,7 +553,8 @@ def _workflow_finish_to_stream_response( created_by=created_by, created_at=int(workflow_run.created_at.timestamp()), finished_at=int(workflow_run.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict), + files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)), + exceptions_count=workflow_run.exceptions_count, ), ) @@ -483,7 +605,10 @@ def _workflow_node_start_to_stream_response( def _workflow_node_finish_to_stream_response( self, - event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent, + event: QueueNodeSucceededEvent + | QueueNodeFailedEvent + | QueueNodeInIterationFailedEvent + | QueueNodeExceptionEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeFinishStreamResponse]: @@ -525,6 +650,51 @@ def _workflow_node_finish_to_stream_response( ), ) + def _workflow_node_retry_to_stream_response( + self, + event: QueueNodeRetryEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, + ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: + """ + Workflow node finish to stream response. + :param event: queue node succeeded or failed event + :param task_id: task id + :param workflow_node_execution: workflow node execution + :return: + """ + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: + return None + + return NodeRetryStreamResponse( + task_id=task_id, + workflow_run_id=workflow_node_execution.workflow_run_id, + data=NodeRetryStreamResponse.Data( + id=workflow_node_execution.id, + node_id=workflow_node_execution.node_id, + node_type=workflow_node_execution.node_type, + index=workflow_node_execution.index, + title=workflow_node_execution.title, + predecessor_node_id=workflow_node_execution.predecessor_node_id, + inputs=workflow_node_execution.inputs_dict, + process_data=workflow_node_execution.process_data_dict, + outputs=workflow_node_execution.outputs_dict, + status=workflow_node_execution.status, + error=workflow_node_execution.error, + elapsed_time=workflow_node_execution.elapsed_time, + execution_metadata=workflow_node_execution.execution_metadata_dict, + created_at=int(workflow_node_execution.created_at.timestamp()), + finished_at=int(workflow_node_execution.finished_at.timestamp()), + files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + retry_index=event.retry_index, + ), + ) + def _workflow_parallel_branch_start_to_stream_response( self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent ) -> ParallelBranchStartStreamResponse: @@ -668,7 +838,7 @@ def _workflow_iteration_completed_to_stream_response( ), ) - def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]: + def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]: """ Fetch files from node outputs :param outputs_dict: node outputs dict @@ -681,9 +851,11 @@ def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping # Remove None files = [file for file in files if file] # Flatten list - files = [file for sublist in files for file in sublist] + # Flatten the list of sequences into a single list of mappings + flattened_files = [file for sublist in files if sublist for file in sublist] - return files + # Convert to tuple to match Sequence type + return tuple(flattened_files) def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]: """ @@ -721,6 +893,8 @@ def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any elif isinstance(value, File): return value.to_dict() + return None + def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: """ Refetch workflow run @@ -730,7 +904,7 @@ def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() if not workflow_run: - raise Exception(f"Workflow run not found: {workflow_run_id}") + raise WorkflowRunNotFoundError(workflow_run_id) return workflow_run @@ -743,6 +917,6 @@ def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNo workflow_node_execution = self._wip_workflow_node_executions.get(node_execution_id) if not workflow_node_execution: - raise Exception(f"Workflow node execution not found: {node_execution_id}") + raise WorkflowNodeExecutionNotFoundError(node_execution_id) return workflow_node_execution diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index d826edf6a0fc19..effc7eff9179ae 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -57,7 +57,7 @@ def on_tool_end( self, tool_name: str, tool_inputs: Mapping[str, Any], - tool_outputs: Sequence[ToolInvokeMessage], + tool_outputs: Sequence[ToolInvokeMessage] | str, message_id: Optional[str] = None, timer: Optional[Any] = None, trace_manager: Optional[TraceQueueManager] = None, diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 1481578630f63b..8f8aaa93d6f986 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -40,17 +40,18 @@ def on_query(self, query: str, dataset_id: str) -> None: def on_tool_end(self, documents: list[Document]) -> None: """Handle tool end.""" for document in documents: - query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata["doc_id"] - ) + if document.metadata is not None: + query = db.session.query(DocumentSegment).filter( + DocumentSegment.index_node_id == document.metadata["doc_id"] + ) - if "dataset_id" in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) - # add hit count to document segment - query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + # add hit count to document segment + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) - db.session.commit() + db.session.commit() def return_retriever_resource_info(self, resource: list): """Handle return_retriever_resource_info.""" diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 9ed5528e43b9b8..5017835565789c 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from enum import Enum from typing import Optional @@ -72,7 +73,7 @@ class DefaultModelProviderEntity(BaseModel): label: I18nObject icon_small: Optional[I18nObject] = None icon_large: Optional[I18nObject] = None - supported_model_types: list[ModelType] + supported_model_types: Sequence[ModelType] = [] class DefaultModelEntity(BaseModel): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index d1b34db2fe7172..bff5a0ec9c6be7 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -40,7 +40,7 @@ logger = logging.getLogger(__name__) -original_provider_configurate_methods = {} +original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {} class ProviderConfiguration(BaseModel): @@ -99,7 +99,8 @@ def get_current_credentials(self, model_type: ModelType, model: str) -> Optional continue restrict_models = quota_configuration.restrict_models - + if self.system_configuration.credentials is None: + return None copy_credentials = self.system_configuration.credentials.copy() if restrict_models: for restrict_model in restrict_models: @@ -124,7 +125,7 @@ def get_current_credentials(self, model_type: ModelType, model: str) -> Optional return credentials - def get_system_configuration_status(self) -> SystemConfigurationStatus: + def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]: """ Get system configuration status. :return: @@ -136,6 +137,8 @@ def get_system_configuration_status(self) -> SystemConfigurationStatus: current_quota_configuration = next( (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None ) + if current_quota_configuration is None: + return None return ( SystemConfigurationStatus.ACTIVE @@ -150,7 +153,7 @@ def is_custom_configuration_available(self) -> bool: """ return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 - def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: + def get_custom_credentials(self, obfuscated: bool = False): """ Get custom credentials. @@ -172,7 +175,7 @@ def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: else [], ) - def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]: + def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]: """ Validate custom credentials. :param credentials: provider credentials @@ -324,7 +327,7 @@ def get_custom_model_credentials( def custom_model_credentials_validate( self, model_type: ModelType, model: str, credentials: dict - ) -> tuple[ProviderModel, dict]: + ) -> tuple[Optional[ProviderModel], dict]: """ Validate custom model credentials. @@ -740,10 +743,10 @@ def get_provider_models( if model_type: model_types.append(model_type) else: - model_types = provider_instance.get_provider_schema().supported_model_types + model_types = list(provider_instance.get_provider_schema().supported_model_types) # Group model settings by model type and model - model_setting_map = defaultdict(dict) + model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict) for model_setting in self.model_settings: model_setting_map[model_setting.model_type][model_setting.model] = model_setting @@ -822,54 +825,57 @@ def _get_system_provider_models( ]: # only customizable model for restrict_model in restrict_models: - copy_credentials = self.system_configuration.credentials.copy() - if restrict_model.base_model_name: - copy_credentials["base_model_name"] = restrict_model.base_model_name - - try: - custom_model_schema = provider_instance.get_model_instance( - restrict_model.model_type - ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials) - except Exception as ex: - logger.warning(f"get custom model schema failed, {ex}") - continue - - if not custom_model_schema: - continue - - if custom_model_schema.model_type not in model_types: - continue - - status = ModelStatus.ACTIVE - if ( - custom_model_schema.model_type in model_setting_map - and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] - ): - model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] - if model_setting.enabled is False: - status = ModelStatus.DISABLED - - provider_models.append( - ModelWithProviderEntity( - model=custom_model_schema.model, - label=custom_model_schema.label, - model_type=custom_model_schema.model_type, - features=custom_model_schema.features, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties=custom_model_schema.model_properties, - deprecated=custom_model_schema.deprecated, - provider=SimpleModelProviderEntity(self.provider), - status=status, + if self.system_configuration.credentials is not None: + copy_credentials = self.system_configuration.credentials.copy() + if restrict_model.base_model_name: + copy_credentials["base_model_name"] = restrict_model.base_model_name + + try: + custom_model_schema = provider_instance.get_model_instance( + restrict_model.model_type + ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials) + except Exception as ex: + logger.warning(f"get custom model schema failed, {ex}") + continue + + if not custom_model_schema: + continue + + if custom_model_schema.model_type not in model_types: + continue + + status = ModelStatus.ACTIVE + if ( + custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] + ): + model_setting = model_setting_map[custom_model_schema.model_type][ + custom_model_schema.model + ] + if model_setting.enabled is False: + status = ModelStatus.DISABLED + + provider_models.append( + ModelWithProviderEntity( + model=custom_model_schema.model, + label=custom_model_schema.label, + model_type=custom_model_schema.model_type, + features=custom_model_schema.features, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties=custom_model_schema.model_properties, + deprecated=custom_model_schema.deprecated, + provider=SimpleModelProviderEntity(self.provider), + status=status, + ) ) - ) # if llm name not in restricted llm list, remove it restrict_model_names = [rm.model for rm in restrict_models] - for m in provider_models: - if m.model_type == ModelType.LLM and m.model not in restrict_model_names: - m.status = ModelStatus.NO_PERMISSION + for model in provider_models: + if model.model_type == ModelType.LLM and model.model not in restrict_model_names: + model.status = ModelStatus.NO_PERMISSION elif not quota_configuration.is_valid: - m.status = ModelStatus.QUOTA_EXCEEDED + model.status = ModelStatus.QUOTA_EXCEEDED return provider_models @@ -1043,7 +1049,7 @@ def __iter__(self): return iter(self.configurations) def values(self) -> Iterator[ProviderConfiguration]: - return self.configurations.values() + return iter(self.configurations.values()) def get(self, key, default=None): return self.configurations.get(key, default) diff --git a/api/core/errors/error.py b/api/core/errors/error.py index 3b186476ebe977..ad921bc2556ffe 100644 --- a/api/core/errors/error.py +++ b/api/core/errors/error.py @@ -1,7 +1,7 @@ from typing import Optional -class LLMError(Exception): +class LLMError(ValueError): """Base class for all LLM exceptions.""" description: Optional[str] = None @@ -16,7 +16,7 @@ class LLMBadRequestError(LLMError): description = "Bad Request" -class ProviderTokenNotInitError(Exception): +class ProviderTokenNotInitError(ValueError): """ Custom exception raised when the provider token is not initialized. """ @@ -27,7 +27,7 @@ def __init__(self, *args, **kwargs): self.description = args[0] if args else self.description -class QuotaExceededError(Exception): +class QuotaExceededError(ValueError): """ Custom exception raised when the quota for a provider has been exceeded. """ @@ -35,7 +35,7 @@ class QuotaExceededError(Exception): description = "Quota Exceeded" -class AppInvokeQuotaExceededError(Exception): +class AppInvokeQuotaExceededError(ValueError): """ Custom exception raised when the quota for an app has been exceeded. """ @@ -43,7 +43,7 @@ class AppInvokeQuotaExceededError(Exception): description = "App Invoke Quota Exceeded" -class ModelCurrentlyNotSupportError(Exception): +class ModelCurrentlyNotSupportError(ValueError): """ Custom exception raised when the model not support """ @@ -51,7 +51,7 @@ class ModelCurrentlyNotSupportError(Exception): description = "Model Currently Not Support" -class InvokeRateLimitError(Exception): +class InvokeRateLimitError(ValueError): """Raised when the Invoke returns rate limit error.""" description = "Rate Limit Error" diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index 38cebb6b6b1c36..3f4e20ec245302 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -1,3 +1,5 @@ +from typing import cast + import requests from configs import dify_config @@ -5,7 +7,7 @@ class APIBasedExtensionRequestor: - timeout: (int, int) = (5, 60) + timeout: tuple[int, int] = (5, 60) """timeout for request connect and read""" def __init__(self, api_endpoint: str, api_key: str) -> None: @@ -51,4 +53,4 @@ def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: "request error, status_code: {}, content: {}".format(response.status_code, response.text[:100]) ) - return response.json() + return cast(dict, response.json()) diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 97dbaf2026e790..231743bf2a948c 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -38,8 +38,8 @@ def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None: @classmethod def scan_extensions(cls): - extensions: list[ModuleExtension] = [] - position_map = {} + extensions = [] + position_map: dict[str, int] = {} # get the path of the current class current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py") @@ -58,7 +58,8 @@ def scan_extensions(cls): # is builtin extension, builtin extension # in the front-end page and business logic, there are special treatments. builtin = False - position = None + # default position is 0 can not be None for sort_to_dict_by_position_map + position = 0 if "__builtin__" in file_names: builtin = True @@ -89,7 +90,7 @@ def scan_extensions(cls): logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.") continue - json_data = {} + json_data: dict[str, Any] = {} if not builtin: if "schema.json" not in file_names: logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") diff --git a/api/core/extension/extension.py b/api/core/extension/extension.py index 3da170455e3398..9eb9e0306b577f 100644 --- a/api/core/extension/extension.py +++ b/api/core/extension/extension.py @@ -1,4 +1,6 @@ -from core.extension.extensible import ExtensionModule, ModuleExtension +from typing import cast + +from core.extension.extensible import Extensible, ExtensionModule, ModuleExtension from core.external_data_tool.base import ExternalDataTool from core.moderation.base import Moderation @@ -10,7 +12,8 @@ class Extension: def init(self): for module, module_class in self.module_classes.items(): - self.__module_extensions[module.value] = module_class.scan_extensions() + m = cast(Extensible, module_class) + self.__module_extensions[module.value] = m.scan_extensions() def module_extensions(self, module: str) -> list[ModuleExtension]: module_extensions = self.__module_extensions.get(module) @@ -35,7 +38,8 @@ def module_extension(self, module: ExtensionModule, extension_name: str) -> Modu def extension_class(self, module: ExtensionModule, extension_name: str) -> type: module_extension = self.module_extension(module, extension_name) - return module_extension.extension_class + t: type = module_extension.extension_class + return t def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None: module_extension = self.module_extension(module, extension_name) diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 54ec97a4933a94..9989c8a09013bd 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -48,7 +48,10 @@ def query(self, inputs: dict, query: Optional[str] = None) -> str: :return: the tool query result """ # get params from config + if not self.config: + raise ValueError("config is required, config: {}".format(self.config)) api_based_extension_id = self.config.get("api_based_extension_id") + assert api_based_extension_id is not None, "api_based_extension_id is required" # get api_based_extension api_based_extension = ( diff --git a/api/core/external_data_tool/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py index 84b94e117ff5f9..6a9703a569b308 100644 --- a/api/core/external_data_tool/external_data_fetch.py +++ b/api/core/external_data_tool/external_data_fetch.py @@ -1,7 +1,7 @@ -import concurrent import logging -from concurrent.futures import ThreadPoolExecutor -from typing import Optional +from collections.abc import Mapping +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from typing import Any, Optional from flask import Flask, current_app @@ -17,9 +17,9 @@ def fetch( tenant_id: str, app_id: str, external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, + inputs: Mapping[str, Any], query: str, - ) -> dict: + ) -> Mapping[str, Any]: """ Fill in variable inputs from external data tools if exists. @@ -30,13 +30,14 @@ def fetch( :param query: the query :return: the filled inputs """ - results = {} + results: dict[str, Any] = {} + inputs = dict(inputs) with ThreadPoolExecutor() as executor: futures = {} for tool in external_data_tools: - future = executor.submit( + future: Future[tuple[str | None, str | None]] = executor.submit( self._query_external_data_tool, - current_app._get_current_object(), + current_app._get_current_object(), # type: ignore tenant_id, app_id, tool, @@ -46,9 +47,10 @@ def fetch( futures[future] = tool - for future in concurrent.futures.as_completed(futures): + for future in as_completed(futures): tool_variable, result = future.result() - results[tool_variable] = result + if tool_variable is not None: + results[tool_variable] = result inputs.update(results) return inputs @@ -59,7 +61,7 @@ def _query_external_data_tool( tenant_id: str, app_id: str, external_data_tool: ExternalDataVariableEntity, - inputs: dict, + inputs: Mapping[str, Any], query: str, ) -> tuple[Optional[str], Optional[str]]: """ diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 28721098594962..245507e17c7032 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -1,4 +1,5 @@ -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional, cast from core.extension.extensible import ExtensionModule from extensions.ext_code_based_extension import code_based_extension @@ -23,9 +24,10 @@ def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: """ code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config) extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) - extension_class.validate_config(tenant_id, config) + # FIXME mypy issue here, figure out how to fix it + extension_class.validate_config(tenant_id, config) # type: ignore - def query(self, inputs: dict, query: Optional[str] = None) -> str: + def query(self, inputs: Mapping[str, Any], query: Optional[str] = None) -> str: """ Query the external data tool. @@ -33,4 +35,4 @@ def query(self, inputs: dict, query: Optional[str] = None) -> str: :param query: the query of chat app :return: the tool query result """ - return self.__extension_instance.query(inputs, query) + return cast(str, self.__extension_instance.query(inputs, query)) diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index 6d8086435d5b29..4a50fb85c9cca3 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -1,15 +1,15 @@ import base64 +from collections.abc import Mapping from configs import dify_config -from core.file import file_repository from core.helper import ssrf_proxy from core.model_runtime.entities import ( AudioPromptMessageContent, DocumentPromptMessageContent, ImagePromptMessageContent, + MultiModalPromptMessageContent, VideoPromptMessageContent, ) -from extensions.ext_database import db from extensions.ext_storage import storage from . import helpers @@ -41,53 +41,42 @@ def to_prompt_message_content( /, *, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, -): - match f.type: - case FileType.IMAGE: - image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url": - data = _to_url(f) - else: - data = _to_base64_data_string(f) - - return ImagePromptMessageContent(data=data, detail=image_detail_config) - case FileType.AUDIO: - encoded_string = _get_encoded_string(f) - if f.extension is None: - raise ValueError("Missing file extension") - return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) - case FileType.VIDEO: - if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url": - data = _to_url(f) - else: - data = _to_base64_data_string(f) - if f.extension is None: - raise ValueError("Missing file extension") - return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) - case FileType.DOCUMENT: - data = _get_encoded_string(f) - if f.mime_type is None: - raise ValueError("Missing file mime_type") - return DocumentPromptMessageContent( - encode_format="base64", - mime_type=f.mime_type, - data=data, - ) - case _: - raise ValueError(f"file type {f.type} is not supported") +) -> MultiModalPromptMessageContent: + if f.extension is None: + raise ValueError("Missing file extension") + if f.mime_type is None: + raise ValueError("Missing file mime_type") + + params = { + "base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "", + "url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "", + "format": f.extension.removeprefix("."), + "mime_type": f.mime_type, + } + if f.type == FileType.IMAGE: + params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + + prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = { + FileType.IMAGE: ImagePromptMessageContent, + FileType.AUDIO: AudioPromptMessageContent, + FileType.VIDEO: VideoPromptMessageContent, + FileType.DOCUMENT: DocumentPromptMessageContent, + } + + try: + return prompt_class_map[f.type].model_validate(params) + except KeyError: + raise ValueError(f"file type {f.type} is not supported") def download(f: File, /): - if f.transfer_method == FileTransferMethod.TOOL_FILE: - tool_file = file_repository.get_tool_file(session=db.session(), file=f) - return _download_file_content(tool_file.file_key) - elif f.transfer_method == FileTransferMethod.LOCAL_FILE: - upload_file = file_repository.get_upload_file(session=db.session(), file=f) - return _download_file_content(upload_file.key) - # remote file - response = ssrf_proxy.get(f.remote_url, follow_redirects=True) - response.raise_for_status() - return response.content + if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE): + return _download_file_content(f._storage_key) + elif f.transfer_method == FileTransferMethod.REMOTE_URL: + response = ssrf_proxy.get(f.remote_url, follow_redirects=True) + response.raise_for_status() + return response.content + raise ValueError(f"unsupported transfer method: {f.transfer_method}") def _download_file_content(path: str, /): @@ -118,21 +107,14 @@ def _get_encoded_string(f: File, /): response.raise_for_status() data = response.content case FileTransferMethod.LOCAL_FILE: - upload_file = file_repository.get_upload_file(session=db.session(), file=f) - data = _download_file_content(upload_file.key) + data = _download_file_content(f._storage_key) case FileTransferMethod.TOOL_FILE: - tool_file = file_repository.get_tool_file(session=db.session(), file=f) - data = _download_file_content(tool_file.file_key) + data = _download_file_content(f._storage_key) encoded_string = base64.b64encode(data).decode("utf-8") return encoded_string -def _to_base64_data_string(f: File, /): - encoded_string = _get_encoded_string(f) - return f"data:{f.mime_type};base64,{encoded_string}" - - def _to_url(f: File, /): if f.transfer_method == FileTransferMethod.REMOTE_URL: if f.remote_url is None: @@ -141,7 +123,7 @@ def _to_url(f: File, /): elif f.transfer_method == FileTransferMethod.LOCAL_FILE: if f.related_id is None: raise ValueError("Missing file related_id") - return helpers.get_signed_file_url(upload_file_id=f.related_id) + return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id) elif f.transfer_method == FileTransferMethod.TOOL_FILE: # add sign url if f.related_id is None or f.extension is None: diff --git a/api/core/file/file_repository.py b/api/core/file/file_repository.py deleted file mode 100644 index 975e1e72db0e0a..00000000000000 --- a/api/core/file/file_repository.py +++ /dev/null @@ -1,32 +0,0 @@ -from sqlalchemy import select -from sqlalchemy.orm import Session - -from models import ToolFile, UploadFile - -from .models import File - - -def get_upload_file(*, session: Session, file: File): - if file.related_id is None: - raise ValueError("Missing file related_id") - stmt = select(UploadFile).filter( - UploadFile.id == file.related_id, - UploadFile.tenant_id == file.tenant_id, - ) - record = session.scalar(stmt) - if not record: - raise ValueError(f"upload file {file.related_id} not found") - return record - - -def get_tool_file(*, session: Session, file: File): - if file.related_id is None: - raise ValueError("Missing file related_id") - stmt = select(ToolFile).filter( - ToolFile.id == file.related_id, - ToolFile.tenant_id == file.tenant_id, - ) - record = session.scalar(stmt) - if not record: - raise ValueError(f"tool file {file.related_id} not found") - return record diff --git a/api/core/file/models.py b/api/core/file/models.py index 3e7e189c62cade..4b4674da095d34 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -47,6 +47,38 @@ class File(BaseModel): mime_type: Optional[str] = None size: int = -1 + # Those properties are private, should not be exposed to the outside. + _storage_key: str + + def __init__( + self, + *, + id: Optional[str] = None, + tenant_id: str, + type: FileType, + transfer_method: FileTransferMethod, + remote_url: Optional[str] = None, + related_id: Optional[str] = None, + filename: Optional[str] = None, + extension: Optional[str] = None, + mime_type: Optional[str] = None, + size: int = -1, + storage_key: str, + ): + super().__init__( + id=id, + tenant_id=tenant_id, + type=type, + transfer_method=transfer_method, + remote_url=remote_url, + related_id=related_id, + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + ) + self._storage_key = storage_key + def to_dict(self) -> Mapping[str, str | int | None]: data = self.model_dump(mode="json") return { diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index a17b7be3675ab1..6fa101cf36192b 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast if TYPE_CHECKING: from core.tools.tool_file_manager import ToolFileManager @@ -9,4 +9,4 @@ class ToolFileParser: @staticmethod def get_tool_file_manager() -> "ToolFileManager": - return tool_file_manager["manager"] + return cast("ToolFileManager", tool_file_manager["manager"]) diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 011ff382ead462..15b501780e766c 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -38,7 +38,7 @@ class CodeLanguage(StrEnum): class CodeExecutor: - dependencies_cache = {} + dependencies_cache: dict[str, str] = {} dependencies_cache_lock = Lock() code_template_transformers: dict[CodeLanguage, type[TemplateTransformer]] = { @@ -103,22 +103,22 @@ def execute_code(cls, language: CodeLanguage, preload: str, code: str) -> str: ) try: - response = response.json() + response_data = response.json() except: raise CodeExecutionError("Failed to parse response") - if (code := response.get("code")) != 0: - raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}") + if (code := response_data.get("code")) != 0: + raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}") - response = CodeExecutionResponse(**response) + response_code = CodeExecutionResponse(**response_data) - if response.data.error: - raise CodeExecutionError(response.data.error) + if response_code.data.error: + raise CodeExecutionError(response_code.data.error) - return response.data.stdout or "" + return response_code.data.stdout or "" @classmethod - def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]) -> dict: + def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]): """ Execute code :param language: code language diff --git a/api/core/helper/code_executor/jinja2/jinja2_formatter.py b/api/core/helper/code_executor/jinja2/jinja2_formatter.py index db2eb5ebb6b19a..264947b5686d0e 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_formatter.py +++ b/api/core/helper/code_executor/jinja2/jinja2_formatter.py @@ -1,9 +1,11 @@ +from collections.abc import Mapping + from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage class Jinja2Formatter: @classmethod - def format(cls, template: str, inputs: dict) -> str: + def format(cls, template: str, inputs: Mapping[str, str]) -> str: """ Format template :param template: template @@ -11,5 +13,4 @@ def format(cls, template: str, inputs: dict) -> str: :return: """ result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs) - - return result["result"] + return str(result.get("result", "")) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index b7a07b21e1d784..baa792b5bc6c41 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -25,21 +25,28 @@ def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, st return runner_script, preload_script @classmethod - def extract_result_str_from_response(cls, response: str) -> str: + def extract_result_str_from_response(cls, response: str): result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL) if not result: raise ValueError("Failed to parse result") - result = result.group(1) - return result + return result.group(1) @classmethod - def transform_response(cls, response: str) -> dict: + def transform_response(cls, response: str) -> Mapping[str, Any]: """ Transform response to dict :param response: response :return: """ - return json.loads(cls.extract_result_str_from_response(response)) + try: + result = json.loads(cls.extract_result_str_from_response(response)) + except json.JSONDecodeError: + raise ValueError("failed to parse response") + if not isinstance(result, dict): + raise ValueError("result must be a dict") + if not all(isinstance(k, str) for k in result): + raise ValueError("result keys must be strings") + return result @classmethod @abstractmethod diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 96341a1b780a80..744fce1cf99cfe 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -1,6 +1,5 @@ import base64 -from extensions.ext_database import db from libs import rsa @@ -14,6 +13,7 @@ def obfuscated_token(token: str): def encrypt_token(tenant_id: str, token: str): from models.account import Tenant + from models.engine import db if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): raise ValueError(f"Tenant with id {tenant_id} not found") diff --git a/api/core/helper/lru_cache.py b/api/core/helper/lru_cache.py index 518962c1652df7..81501d2e4e23b2 100644 --- a/api/core/helper/lru_cache.py +++ b/api/core/helper/lru_cache.py @@ -4,7 +4,7 @@ class LRUCache: def __init__(self, capacity: int): - self.cache = OrderedDict() + self.cache: OrderedDict[Any, Any] = OrderedDict() self.capacity = capacity def get(self, key: Any) -> Any: diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 5e274f8916869d..35349210bd53ab 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -30,7 +30,7 @@ def get(self) -> Optional[dict]: except JSONDecodeError: return None - return cached_provider_credentials + return dict(cached_provider_credentials) else: return None diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index da0fd0031cc6dc..543444463b9f1a 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -22,6 +22,7 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) provider_name = model_config.provider if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers: hosting_openai_config = hosting_configuration.provider_map["openai"] + assert hosting_openai_config is not None # 2000 text per chunk length = 2000 @@ -34,8 +35,9 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) try: model_type_instance = OpenAIModerationModel() + # FIXME, for type hint using assert or raise ValueError is better here? moderation_result = model_type_instance.invoke( - model="text-moderation-stable", credentials=hosting_openai_config.credentials, text=text_chunk + model="text-moderation-stable", credentials=hosting_openai_config.credentials or {}, text=text_chunk ) if moderation_result is True: diff --git a/api/core/helper/module_import_helper.py b/api/core/helper/module_import_helper.py index 1e2fefce88b632..9a041667e46df5 100644 --- a/api/core/helper/module_import_helper.py +++ b/api/core/helper/module_import_helper.py @@ -14,12 +14,13 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz if existed_spec: spec = existed_spec if not spec.loader: - raise Exception(f"Failed to load module {module_name} from {py_file_path}") + raise Exception(f"Failed to load module {module_name} from {py_file_path!r}") else: # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly - spec = importlib.util.spec_from_file_location(module_name, py_file_path) + # FIXME: mypy does not support the type of spec.loader + spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore if not spec or not spec.loader: - raise Exception(f"Failed to load module {module_name} from {py_file_path}") + raise Exception(f"Failed to load module {module_name} from {py_file_path!r}") if use_lazy_loader: # Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports spec.loader = importlib.util.LazyLoader(spec.loader) @@ -29,7 +30,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz spec.loader.exec_module(module) return module except Exception as e: - logging.exception(f"Failed to load module {module_name} from script file '{py_file_path}'") + logging.exception(f"Failed to load module {module_name} from script file '{py_file_path!r}'") raise e @@ -57,6 +58,6 @@ def load_single_subclass_from_source( case 1: return subclasses[0] case 0: - raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path}") + raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path!r}") case _: - raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path}") + raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path!r}") diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 566293d1250402..424983a819ec28 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -24,6 +24,12 @@ STATUS_FORCELIST = [429, 500, 502, 503, 504] +class MaxRetriesExceededError(ValueError): + """Raised when the maximum number of retries is exceeded.""" + + pass + + def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): if "allow_redirects" in kwargs: allow_redirects = kwargs.pop("allow_redirects") @@ -59,12 +65,13 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): except httpx.RequestError as e: logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}") + if max_retries == 0: + raise retries += 1 if retries <= max_retries: time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1))) - - raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}") + raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}") def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index e848b46c5633ab..3b67b3f84838d3 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -33,7 +33,7 @@ def get(self) -> Optional[dict]: except JSONDecodeError: return None - return cached_tool_parameter + return dict(cached_tool_parameter) else: return None diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py index 94b02cf98578b1..6de5e704abf4f5 100644 --- a/api/core/helper/tool_provider_cache.py +++ b/api/core/helper/tool_provider_cache.py @@ -28,7 +28,7 @@ def get(self) -> Optional[dict]: except JSONDecodeError: return None - return cached_provider_credentials + return dict(cached_provider_credentials) else: return None diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index b47ba67f2fa64f..f9fb7275f3624f 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -42,7 +42,7 @@ class HostedModerationConfig(BaseModel): class HostingConfiguration: provider_map: dict[str, HostingProvider] = {} - moderation_config: HostedModerationConfig = None + moderation_config: Optional[HostedModerationConfig] = None def init_app(self, app: Flask) -> None: if dify_config.EDITION != "CLOUD": @@ -67,7 +67,7 @@ def init_azure_openai() -> HostingProvider: "base_model_name": "gpt-35-turbo", } - quotas = [] + quotas: list[HostingQuota] = [] hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT trial_quota = TrialHostingQuota( quota_limit=hosted_quota_limit, @@ -123,7 +123,7 @@ def init_azure_openai() -> HostingProvider: def init_openai(self) -> HostingProvider: quota_unit = QuotaUnit.CREDITS - quotas = [] + quotas: list[HostingQuota] = [] if dify_config.HOSTED_OPENAI_TRIAL_ENABLED: hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT @@ -157,7 +157,7 @@ def init_openai(self) -> HostingProvider: @staticmethod def init_anthropic() -> HostingProvider: quota_unit = QuotaUnit.TOKENS - quotas = [] + quotas: list[HostingQuota] = [] if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED: hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT @@ -187,7 +187,7 @@ def init_anthropic() -> HostingProvider: def init_minimax() -> HostingProvider: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_MINIMAX_ENABLED: - quotas = [FreeHostingQuota()] + quotas: list[HostingQuota] = [FreeHostingQuota()] return HostingProvider( enabled=True, @@ -205,7 +205,7 @@ def init_minimax() -> HostingProvider: def init_spark() -> HostingProvider: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_SPARK_ENABLED: - quotas = [FreeHostingQuota()] + quotas: list[HostingQuota] = [FreeHostingQuota()] return HostingProvider( enabled=True, @@ -223,7 +223,7 @@ def init_spark() -> HostingProvider: def init_zhipuai() -> HostingProvider: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_ZHIPUAI_ENABLED: - quotas = [FreeHostingQuota()] + quotas: list[HostingQuota] = [FreeHostingQuota()] return HostingProvider( enabled=True, diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 29e161cb747284..1f0a0d0ef1dda4 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -6,10 +6,10 @@ import threading import time import uuid -from typing import Optional, cast +from typing import Any, Optional, cast from flask import Flask, current_app -from flask_login import current_user +from flask_login import current_user # type: ignore from sqlalchemy.orm.exc import ObjectDeletedError from configs import dify_config @@ -62,6 +62,8 @@ def run(self, dataset_documents: list[DatasetDocument]): .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) + if not processing_rule: + raise ValueError("no process rule found") index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() # extract @@ -120,6 +122,8 @@ def run_in_splitting_status(self, dataset_document: DatasetDocument): .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) + if not processing_rule: + raise ValueError("no process rule found") index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -254,7 +258,7 @@ def indexing_estimate( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) - preview_texts = [] + preview_texts: list[str] = [] total_segments = 0 index_type = doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -285,7 +289,8 @@ def indexing_estimate( for upload_file_id in image_upload_file_ids: image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() try: - storage.delete(image_file.key) + if image_file: + storage.delete(image_file.key) except Exception: logging.exception( "Delete image_files failed while indexing_estimate, \ @@ -379,8 +384,9 @@ def _extract( # replace doc id to document model id text_docs = cast(list[Document], text_docs) for text_doc in text_docs: - text_doc.metadata["document_id"] = dataset_document.id - text_doc.metadata["dataset_id"] = dataset_document.dataset_id + if text_doc.metadata is not None: + text_doc.metadata["document_id"] = dataset_document.id + text_doc.metadata["dataset_id"] = dataset_document.dataset_id return text_docs @@ -400,6 +406,7 @@ def _get_splitter( """ Get the NodeParser object according to the processing rule. """ + character_splitter: TextSplitter if processing_rule.mode == "custom": # The user-defined segmentation rule rules = json.loads(processing_rule.rules) @@ -426,9 +433,10 @@ def _get_splitter( ) else: # Automatic segmentation + automatic_rules: dict[str, Any] = dict(DatasetProcessRule.AUTOMATIC_RULES["segmentation"]) character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( - chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"], - chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"], + chunk_size=automatic_rules["max_tokens"], + chunk_overlap=automatic_rules["chunk_overlap"], separators=["\n\n", "。", ". ", " ", ""], embedding_model_instance=embedding_model_instance, ) @@ -497,8 +505,8 @@ def _split_to_documents( """ Split the text documents into nodes. """ - all_documents = [] - all_qa_documents = [] + all_documents: list[Document] = [] + all_qa_documents: list[Document] = [] for text_doc in text_docs: # document clean document_text = self._document_clean(text_doc.page_content, processing_rule) @@ -509,10 +517,11 @@ def _split_to_documents( split_documents = [] for document_node in documents: if document_node.page_content.strip(): - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata["doc_id"] = doc_id - document_node.metadata["doc_hash"] = hash + if document_node.metadata is not None: + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document_node.page_content) + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content document_node.page_content = remove_leading_symbols(page_content) @@ -529,7 +538,7 @@ def _split_to_documents( document_format_thread = threading.Thread( target=self.format_qa_document, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "tenant_id": tenant_id, "document_node": doc, "all_qa_documents": all_qa_documents, @@ -557,11 +566,12 @@ def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, al qa_document = Document( page_content=result["question"], metadata=document_node.metadata.model_copy() ) - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result["question"]) - qa_document.metadata["answer"] = result["answer"] - qa_document.metadata["doc_id"] = doc_id - qa_document.metadata["doc_hash"] = hash + if qa_document.metadata is not None: + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(result["question"]) + qa_document.metadata["answer"] = result["answer"] + qa_document.metadata["doc_id"] = doc_id + qa_document.metadata["doc_hash"] = hash qa_documents.append(qa_document) format_documents.extend(qa_documents) except Exception as e: @@ -575,7 +585,7 @@ def _split_to_documents_for_estimate( """ Split the text documents into nodes. """ - all_documents = [] + all_documents: list[Document] = [] for text_doc in text_docs: # document clean document_text = self._document_clean(text_doc.page_content, processing_rule) @@ -588,11 +598,11 @@ def _split_to_documents_for_estimate( for document in documents: if document.page_content is None or not document.page_content.strip(): continue - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(document.page_content) - - document.metadata["doc_id"] = doc_id - document.metadata["doc_hash"] = hash + if document.metadata is not None: + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document.page_content) + document.metadata["doc_id"] = doc_id + document.metadata["doc_hash"] = hash split_documents.append(document) @@ -648,7 +658,7 @@ def _load( # create keyword index create_keyword_thread = threading.Thread( target=self._process_keyword_index, - args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), + args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore ) create_keyword_thread.start() if dataset.indexing_technique == "high_quality": @@ -659,7 +669,7 @@ def _load( futures.append( executor.submit( self._process_chunk, - current_app._get_current_object(), + current_app._get_current_object(), # type: ignore index_processor, chunk_documents, dataset, diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 3a92c8d9d22562..9fe3f68f2a8af5 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -1,7 +1,7 @@ import json import logging import re -from typing import Optional +from typing import Optional, cast from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser @@ -13,6 +13,7 @@ WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, ) from core.model_manager import ModelManager +from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError @@ -44,10 +45,13 @@ def generate_conversation_name( prompts = [UserPromptMessage(content=prompt)] with measure_time() as timer: - response = model_instance.invoke_llm( - prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False + ), ) - answer = response.message.content + answer = cast(str, response.message.content) cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL) if cleaned_answer is None: return "" @@ -94,11 +98,16 @@ def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: st prompt_messages = [UserPromptMessage(content=prompt)] try: - response = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters={"max_tokens": 256, "temperature": 0}, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters={"max_tokens": 256, "temperature": 0}, + stream=False, + ), ) - questions = output_parser.parse(response.message.content) + questions = output_parser.parse(cast(str, response.message.content)) except InvokeError: questions = [] except Exception as e: @@ -138,11 +147,14 @@ def generate_rule_config( ) try: - response = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + ), ) - rule_config["prompt"] = response.message.content + rule_config["prompt"] = cast(str, response.message.content) except InvokeError as e: error = str(e) @@ -178,15 +190,18 @@ def generate_rule_config( model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider") if model_config else None, - model=model_config.get("name") if model_config else None, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), ) try: try: # the first step to generate the task prompt - prompt_content = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + prompt_content = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + ), ) except InvokeError as e: error = str(e) @@ -195,8 +210,10 @@ def generate_rule_config( return rule_config - rule_config["prompt"] = prompt_content.message.content + rule_config["prompt"] = cast(str, prompt_content.message.content) + if not isinstance(prompt_content.message.content, str): + raise NotImplementedError("prompt content is not a string") parameter_generate_prompt = parameter_template.format( inputs={ "INPUT_TEXT": prompt_content.message.content, @@ -216,19 +233,25 @@ def generate_rule_config( statement_messages = [UserPromptMessage(content=statement_generate_prompt)] try: - parameter_content = model_instance.invoke_llm( - prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False + parameter_content = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False + ), ) - rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.content) + rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content)) except InvokeError as e: error = str(e) error_step = "generate variables" try: - statement_content = model_instance.invoke_llm( - prompt_messages=statement_messages, model_parameters=model_parameters, stream=False + statement_content = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=statement_messages, model_parameters=model_parameters, stream=False + ), ) - rule_config["opening_statement"] = statement_content.message.content + rule_config["opening_statement"] = cast(str, statement_content.message.content) except InvokeError as e: error = str(e) error_step = "generate conversation opener" @@ -267,19 +290,22 @@ def generate_code( model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider") if model_config else None, - model=model_config.get("name") if model_config else None, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), ) prompt_messages = [UserPromptMessage(content=prompt)] model_parameters = {"max_tokens": max_tokens, "temperature": 0.01} try: - response = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + ), ) - generated_code = response.message.content + generated_code = cast(str, response.message.content) return {"code": generated_code, "language": code_language, "error": ""} except InvokeError as e: @@ -303,9 +329,14 @@ def generate_qa_document(cls, tenant_id: str, query, document_language: str): prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] - response = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters={"temperature": 0.01, "max_tokens": 2000}, + stream=False, + ), ) - answer = response.message.content + answer = cast(str, response.message.content) return answer.strip() diff --git a/api/core/llm_generator/output_parser/errors.py b/api/core/llm_generator/output_parser/errors.py index 1e743f1757473e..0922806ca88ce6 100644 --- a/api/core/llm_generator/output_parser/errors.py +++ b/api/core/llm_generator/output_parser/errors.py @@ -1,2 +1,2 @@ -class OutputParserError(Exception): +class OutputParserError(ValueError): pass diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 81d08dc8854f80..003a0c85b1f12e 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -68,7 +68,7 @@ def get_history_prompt_messages( messages = list(reversed(thread_messages)) - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] for message in messages: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() if files: diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 1986688551b601..d1e71148cd6023 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -124,17 +124,20 @@ def invoke_llm( raise Exception("Model type instance is not LargeLanguageModel") self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - callbacks=callbacks, + return cast( + Union[LLMResult, Generator], + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + callbacks=callbacks, + ), ) def get_llm_num_tokens( @@ -151,12 +154,15 @@ def get_llm_num_tokens( raise Exception("Model type instance is not LargeLanguageModel") self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, - model=self.model, - credentials=self.credentials, - prompt_messages=prompt_messages, - tools=tools, + return cast( + int, + self._round_robin_invoke( + function=self.model_type_instance.get_num_tokens, + model=self.model, + credentials=self.credentials, + prompt_messages=prompt_messages, + tools=tools, + ), ) def invoke_text_embedding( @@ -174,13 +180,16 @@ def invoke_text_embedding( raise Exception("Model type instance is not TextEmbeddingModel") self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - texts=texts, - user=user, - input_type=input_type, + return cast( + TextEmbeddingResult, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + texts=texts, + user=user, + input_type=input_type, + ), ) def get_text_embedding_num_tokens(self, texts: list[str]) -> int: @@ -194,11 +203,14 @@ def get_text_embedding_num_tokens(self, texts: list[str]) -> int: raise Exception("Model type instance is not TextEmbeddingModel") self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, - model=self.model, - credentials=self.credentials, - texts=texts, + return cast( + int, + self._round_robin_invoke( + function=self.model_type_instance.get_num_tokens, + model=self.model, + credentials=self.credentials, + texts=texts, + ), ) def invoke_rerank( @@ -223,15 +235,18 @@ def invoke_rerank( raise Exception("Model type instance is not RerankModel") self.model_type_instance = cast(RerankModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - user=user, + return cast( + RerankResult, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + user=user, + ), ) def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool: @@ -246,12 +261,15 @@ def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool: raise Exception("Model type instance is not ModerationModel") self.model_type_instance = cast(ModerationModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - text=text, - user=user, + return cast( + bool, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + text=text, + user=user, + ), ) def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str: @@ -266,12 +284,15 @@ def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str raise Exception("Model type instance is not Speech2TextModel") self.model_type_instance = cast(Speech2TextModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - file=file, - user=user, + return cast( + str, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + file=file, + user=user, + ), ) def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]: @@ -288,17 +309,20 @@ def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Option raise Exception("Model type instance is not TTSModel") self.model_type_instance = cast(TTSModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - content_text=content_text, - user=user, - tenant_id=tenant_id, - voice=voice, + return cast( + Iterable[bytes], + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + content_text=content_text, + user=user, + tenant_id=tenant_id, + voice=voice, + ), ) - def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): + def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any: """ Round-robin invoke :param function: function to invoke @@ -309,7 +333,7 @@ def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): if not self.load_balancing_manager: return function(*args, **kwargs) - last_exception = None + last_exception: Union[InvokeRateLimitError, InvokeAuthorizationError, InvokeConnectionError, None] = None while True: lb_config = self.load_balancing_manager.fetch_next() if not lb_config: @@ -463,7 +487,7 @@ def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]: if real_index > max_index: real_index = 0 - config = self._load_balancing_configs[real_index] + config: ModelLoadBalancingConfiguration = self._load_balancing_configs[real_index] if self.in_cooldown(config): cooldown_load_balancing_configs.append(config) @@ -507,8 +531,7 @@ def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool: self._tenant_id, self._provider, self._model_type.value, self._model, config.id ) - res = redis_client.exists(cooldown_cache_key) - res = cast(bool, res) + res: bool = redis_client.exists(cooldown_cache_key) return res @staticmethod diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index 3b6b825244dfdc..1f21a2d3763c4a 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -1,7 +1,8 @@ import json import logging import sys -from typing import Optional +from collections.abc import Sequence +from typing import Optional, cast from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk @@ -20,7 +21,7 @@ def on_before_invoke( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: @@ -76,7 +77,7 @@ def on_new_chunk( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ): @@ -94,7 +95,7 @@ def on_new_chunk( :param stream: is stream response :param user: unique user id """ - sys.stdout.write(chunk.delta.message.content) + sys.stdout.write(cast(str, chunk.delta.message.content)) sys.stdout.flush() def on_after_invoke( @@ -106,7 +107,7 @@ def on_after_invoke( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: @@ -147,7 +148,7 @@ def on_invoke_error( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: diff --git a/api/core/model_runtime/entities/__init__.py b/api/core/model_runtime/entities/__init__.py index 1c73755cffd62e..c3e1351e3b6d26 100644 --- a/api/core/model_runtime/entities/__init__.py +++ b/api/core/model_runtime/entities/__init__.py @@ -4,6 +4,7 @@ AudioPromptMessageContent, DocumentPromptMessageContent, ImagePromptMessageContent, + MultiModalPromptMessageContent, PromptMessage, PromptMessageContent, PromptMessageContentType, @@ -27,6 +28,7 @@ "LLMResultChunkDelta", "LLMUsage", "ModelPropertyKey", + "MultiModalPromptMessageContent", "PromptMessage", "PromptMessage", "PromptMessageContent", diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index f2870209bb5e00..2f682ceef578dc 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,7 +1,7 @@ from abc import ABC from collections.abc import Sequence from enum import Enum, StrEnum -from typing import Literal, Optional +from typing import Optional from pydantic import BaseModel, Field, field_validator @@ -67,7 +67,6 @@ class PromptMessageContent(BaseModel): """ type: PromptMessageContentType - data: str class TextPromptMessageContent(PromptMessageContent): @@ -76,21 +75,34 @@ class TextPromptMessageContent(PromptMessageContent): """ type: PromptMessageContentType = PromptMessageContentType.TEXT + data: str + + +class MultiModalPromptMessageContent(PromptMessageContent): + """ + Model class for multi-modal prompt message content. + """ + + type: PromptMessageContentType + format: str = Field(default=..., description="the format of multi-modal file") + base64_data: str = Field(default="", description="the base64 data of multi-modal file") + url: str = Field(default="", description="the url of multi-modal file") + mime_type: str = Field(default=..., description="the mime type of multi-modal file") + @property + def data(self): + return self.url or f"data:{self.mime_type};base64,{self.base64_data}" -class VideoPromptMessageContent(PromptMessageContent): + +class VideoPromptMessageContent(MultiModalPromptMessageContent): type: PromptMessageContentType = PromptMessageContentType.VIDEO - data: str = Field(..., description="Base64 encoded video data") - format: str = Field(..., description="Video format") -class AudioPromptMessageContent(PromptMessageContent): +class AudioPromptMessageContent(MultiModalPromptMessageContent): type: PromptMessageContentType = PromptMessageContentType.AUDIO - data: str = Field(..., description="Base64 encoded audio data") - format: str = Field(..., description="Audio format") -class ImagePromptMessageContent(PromptMessageContent): +class ImagePromptMessageContent(MultiModalPromptMessageContent): """ Model class for image prompt message content. """ @@ -103,11 +115,8 @@ class DETAIL(StrEnum): detail: DETAIL = DETAIL.LOW -class DocumentPromptMessageContent(PromptMessageContent): +class DocumentPromptMessageContent(MultiModalPromptMessageContent): type: PromptMessageContentType = PromptMessageContentType.DOCUMENT - encode_format: Literal["base64"] - mime_type: str - data: str class PromptMessage(ABC, BaseModel): diff --git a/api/core/model_runtime/errors/invoke.py b/api/core/model_runtime/errors/invoke.py index edfb19c7d07d4c..76754253611568 100644 --- a/api/core/model_runtime/errors/invoke.py +++ b/api/core/model_runtime/errors/invoke.py @@ -1,7 +1,7 @@ from typing import Optional -class InvokeError(Exception): +class InvokeError(ValueError): """Base class for all LLM exceptions.""" description: Optional[str] = None diff --git a/api/core/model_runtime/errors/validate.py b/api/core/model_runtime/errors/validate.py index 7fcd2133f9f8d1..16bebcc67db062 100644 --- a/api/core/model_runtime/errors/validate.py +++ b/api/core/model_runtime/errors/validate.py @@ -1,4 +1,4 @@ -class CredentialsValidateFailedError(Exception): +class CredentialsValidateFailedError(ValueError): """ Credentials validate failed error """ diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 79a1d28ebe637e..e2b95603379348 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -1,7 +1,6 @@ import decimal import os from abc import ABC, abstractmethod -from collections.abc import Mapping from typing import Optional from pydantic import ConfigDict @@ -36,7 +35,7 @@ class AIModel(ABC): model_config = ConfigDict(protected_namespaces=()) @abstractmethod - def validate_credentials(self, model: str, credentials: Mapping) -> None: + def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials @@ -214,7 +213,7 @@ def predefined_models(self) -> list[AIModelEntity]: return model_schemas - def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]: + def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]: """ Get model schema by model name and credentials @@ -236,9 +235,7 @@ def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> return None - def get_customizable_model_schema_from_credentials( - self, model: str, credentials: Mapping - ) -> Optional[AIModelEntity]: + def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ Get customizable model schema from credentials @@ -248,7 +245,7 @@ def get_customizable_model_schema_from_credentials( """ return self._get_customizable_model_schema(model, credentials) - def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: + def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ Get customizable model schema and fill in the template """ @@ -301,7 +298,7 @@ def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Op return schema - def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ Get customizable model schema diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 8faeffa872b40f..402a30376b7546 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -2,7 +2,7 @@ import re import time from abc import abstractmethod -from collections.abc import Generator, Mapping, Sequence +from collections.abc import Generator, Sequence from typing import Optional, Union from pydantic import ConfigDict @@ -48,7 +48,7 @@ def invoke( prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -291,12 +291,12 @@ def _code_block_mode_stream_processor( content = piece.delta.message.content piece.delta.message.content = "" yield piece - piece = content + content_piece = content else: yield piece continue new_piece: str = "" - for char in piece: + for char in content_piece: char = str(char) if state == "normal": if char == "`": @@ -350,7 +350,7 @@ def _code_block_mode_stream_processor_with_backtick( piece.delta.message.content = "" # Yield a piece with cleared content before processing it to maintain the generator structure yield piece - piece = content + content_piece = content else: # Yield pieces without content directly yield piece @@ -360,7 +360,7 @@ def _code_block_mode_stream_processor_with_backtick( continue new_piece: str = "" - for char in piece: + for char in content_piece: if state == "search_start": if char == "`": backtick_count += 1 @@ -535,7 +535,7 @@ def get_parameter_rules(self, model: str, credentials: dict) -> list[ParameterRu return [] - def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode: + def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode: """ Get model mode diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index 4374093de4ab38..36e3e7bd557163 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -104,9 +104,10 @@ def get_model_instance(self, model_type: ModelType) -> AIModel: mod = import_module_from_source( module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path ) + # FIXME "type" has no attribute "__abstractmethods__" ignore it for now fix it later model_class = next( filter( - lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, + lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, # type: ignore get_subclasses_from_module(mod, AIModel), ), None, diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index 2d38fba955fb86..33135129082b1d 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -89,7 +89,8 @@ def _get_context_size(self, model: str, credentials: dict) -> int: model_schema = self.get_model_schema(model, credentials) if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE] + content_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE] + return content_size return 1000 @@ -104,6 +105,7 @@ def _get_max_chunks(self, model: str, credentials: dict) -> int: model_schema = self.get_model_schema(model, credentials) if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + return max_chunks return 1 diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py index 5fe6dda6ad5d79..6dab0aaf2d41e7 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py @@ -2,9 +2,9 @@ from threading import Lock from typing import Any -from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer +from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore -_tokenizer = None +_tokenizer: Any = None _lock = Lock() diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index b394ea4e9d22fe..6ce316b137abb4 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -127,7 +127,8 @@ def _get_model_audio_type(self, model: str, credentials: dict) -> str: if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties: raise ValueError("this model does not support audio type") - return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE] + audio_type: str = model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE] + return audio_type def _get_model_word_limit(self, model: str, credentials: dict) -> int: """ @@ -138,8 +139,9 @@ def _get_model_word_limit(self, model: str, credentials: dict) -> int: if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties: raise ValueError("this model does not support word limit") + world_limit: int = model_schema.model_properties[ModelPropertyKey.WORD_LIMIT] - return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT] + return world_limit def _get_model_workers_limit(self, model: str, credentials: dict) -> int: """ @@ -150,8 +152,9 @@ def _get_model_workers_limit(self, model: str, credentials: dict) -> int: if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties: raise ValueError("this model does not support max workers") + workers_limit: int = model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] - return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] + return workers_limit @staticmethod def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"): diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 3faf5abbe87f58..c0ea8c6325c845 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -1,5 +1,4 @@ import base64 -import io import json from collections.abc import Generator, Sequence from typing import Optional, Union, cast @@ -18,7 +17,6 @@ ) from anthropic.types.beta.tools import ToolsBetaMessage from httpx import Timeout -from PIL import Image from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities import ( @@ -498,22 +496,19 @@ def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) -> sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) - if not message_content.data.startswith("data:"): + if not message_content.base64_data: # fetch image data from url try: - image_content = requests.get(message_content.data).content - with Image.open(io.BytesIO(image_content)) as img: - mime_type = f"image/{img.format.lower()}" + image_content = requests.get(message_content.url).content base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: raise ValueError( f"Failed to fetch image data from url {message_content.data}, {ex}" ) else: - data_split = message_content.data.split(";base64,") - mime_type = data_split[0].replace("data:", "") - base64_data = data_split[1] + base64_data = message_content.base64_data + mime_type = message_content.mime_type if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: raise ValueError( f"Unsupported image type {mime_type}, " @@ -534,9 +529,9 @@ def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) -> sub_message_dict = { "type": "document", "source": { - "type": message_content.encode_format, + "type": "base64", "media_type": message_content.mime_type, - "data": message_content.data, + "data": message_content.base64_data, }, } sub_messages.append(sub_message_dict) diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index 4cf58275d79fe3..3bd6375aa9117b 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -819,6 +819,82 @@ class AzureBaseModel(BaseModel): ), ), ), + AzureBaseModel( + base_model_name="gpt-4o-2024-11-20", + entity=AIModelEntity( + model="fake-deployment-name", + label=I18nObject( + en_US="fake-deployment-name-label", + ), + model_type=ModelType.LLM, + features=[ + ModelFeature.AGENT_THOUGHT, + ModelFeature.VISION, + ModelFeature.MULTI_TOOL_CALL, + ModelFeature.STREAM_TOOL_CALL, + ], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.MODE: LLMMode.CHAT.value, + ModelPropertyKey.CONTEXT_SIZE: 128000, + }, + parameter_rules=[ + ParameterRule( + name="temperature", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], + ), + ParameterRule( + name="top_p", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], + ), + ParameterRule( + name="presence_penalty", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], + ), + ParameterRule( + name="frequency_penalty", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], + ), + _get_max_tokens(default=512, min_val=1, max_val=16384), + ParameterRule( + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, + required=False, + precision=2, + min=0, + max=1, + ), + ParameterRule( + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", + help=I18nObject( + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" + ), + required=False, + options=["text", "json_object", "json_schema"], + ), + ParameterRule( + name="json_schema", + label=I18nObject(en_US="JSON Schema"), + type="text", + help=I18nObject( + zh_Hans="设置返回的json schema,llm将按照它返回", + en_US="Set a response json schema will ensure LLM to adhere it.", + ), + required=False, + ), + ], + pricing=PriceConfig( + input=5.00, + output=15.00, + unit=0.000001, + currency="USD", + ), + ), + ), AzureBaseModel( base_model_name="gpt-4-turbo", entity=AIModelEntity( diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml index 1ef5e83abca6de..a6ae47b28e5906 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml @@ -86,6 +86,9 @@ model_credential_schema: - label: en_US: '2024-06-01' value: '2024-06-01' + - label: + en_US: '2024-10-21' + value: '2024-10-21' placeholder: zh_Hans: 在此选择您的 API 版本 en_US: Select your API Version here @@ -168,6 +171,12 @@ model_credential_schema: show_on: - variable: __model_type value: llm + - label: + en_US: gpt-4o-2024-11-20 + value: gpt-4o-2024-11-20 + show_on: + - variable: __model_type + value: llm - label: en_US: gpt-4-turbo value: gpt-4-turbo diff --git a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py index a2b14cf3dbe6d4..4aa09e61fd3599 100644 --- a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py @@ -64,10 +64,12 @@ def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) + if not ai_model_entity: + return None return ai_model_entity.entity @staticmethod - def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: + def _get_ai_model_entity(base_model_name: str, model: str) -> Optional[AzureBaseModel]: for ai_model_entity in SPEECH2TEXT_BASE_MODELS: if ai_model_entity.base_model_name == base_model_name: ai_model_entity_copy = copy.deepcopy(ai_model_entity) diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index c45ce87ea76838..69d2cfaded453f 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -92,7 +92,10 @@ def _invoke( average = embeddings_batch[0] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embedding = (average / np.linalg.norm(average)).tolist() + if np.isnan(embedding).any(): + raise ValueError("Normalized embedding is nan please try again") + embeddings[i] = embedding # calc usage usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py index 173b9d250c1743..6d50ba9163984f 100644 --- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -114,6 +114,8 @@ def _process_sentence(self, sentence: str, model: str, voice, credentials: dict) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) + if not ai_model_entity: + return None return ai_model_entity.entity @staticmethod diff --git a/api/core/model_runtime/model_providers/baichuan/llm/llm.py b/api/core/model_runtime/model_providers/baichuan/llm/llm.py index 91a14bf1009006..e8d18cfff1ff06 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/llm.py @@ -10,6 +10,7 @@ from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, + PromptMessageContentType, PromptMessageTool, SystemPromptMessage, ToolPromptMessage, @@ -105,7 +106,11 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: if isinstance(message.content, str): message_dict = {"role": "user", "content": message.content} else: - raise ValueError("User message content must be str") + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + message_dict = {"role": "user", "content": message_content.data} + elif message_content.type == PromptMessageContentType.IMAGE: + raise ValueError("Content object type not support image_url") elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) message_dict = {"role": "assistant", "content": message.content} diff --git a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py new file mode 100644 index 00000000000000..2ad37cef3b38f1 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py @@ -0,0 +1,29 @@ +from collections.abc import Mapping + +import boto3 +from botocore.config import Config + +from core.model_runtime.errors.invoke import InvokeBadRequestError + + +def get_bedrock_client(service_name: str, credentials: Mapping[str, str]): + region_name = credentials.get("aws_region") + if not region_name: + raise InvokeBadRequestError("aws_region is required") + client_config = Config(region_name=region_name) + aws_access_key_id = credentials.get("aws_access_key_id") + aws_secret_access_key = credentials.get("aws_secret_access_key") + + if aws_access_key_id and aws_secret_access_key: + # use aksk to call bedrock + client = boto3.client( + service_name=service_name, + config=client_config, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + ) + else: + # use iam without aksk to call + client = boto3.client(service_name=service_name, config=client_config) + + return client diff --git a/api/core/model_runtime/model_providers/bedrock/llm/amazon.nova-lite-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/amazon.nova-lite-v1.yaml index 5aaf50473e680f..41a7f96d80c7fc 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/amazon.nova-lite-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/amazon.nova-lite-v1.yaml @@ -6,6 +6,7 @@ features: - agent-thought - tool-call - stream-tool-call + - vision model_properties: mode: chat context_size: 300000 diff --git a/api/core/model_runtime/model_providers/bedrock/llm/amazon.nova-pro-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/amazon.nova-pro-v1.yaml index 75e53e74a94925..53cf560610f00a 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/amazon.nova-pro-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/amazon.nova-pro-v1.yaml @@ -6,6 +6,7 @@ features: - agent-thought - tool-call - stream-tool-call + - vision model_properties: mode: chat context_size: 300000 diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index e6e8a765ee9e05..29bd673d576fc9 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -6,9 +6,9 @@ from typing import Optional, Union, cast # 3rd import -import boto3 -from botocore.config import Config -from botocore.exceptions import ( +import boto3 # type: ignore +from botocore.config import Config # type: ignore +from botocore.exceptions import ( # type: ignore ClientError, EndpointConnectionError, NoRegionError, @@ -40,6 +40,7 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client logger = logging.getLogger(__name__) ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. @@ -173,13 +174,7 @@ def _generate_with_converse( :param stream: is stream response :return: full response or stream response chunk generator result """ - bedrock_client = boto3.client( - service_name="bedrock-runtime", - aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key"), - region_name=credentials["aws_region"], - ) - + bedrock_client = get_bedrock_client("bedrock-runtime", credentials) system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages) inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop) diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.amazon.nova-lite-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.amazon.nova-lite-v1.yaml index 594f304617f478..522353489d7e20 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/us.amazon.nova-lite-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.amazon.nova-lite-v1.yaml @@ -6,6 +6,7 @@ features: - agent-thought - tool-call - stream-tool-call + - vision model_properties: mode: chat context_size: 300000 diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.amazon.nova-pro-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.amazon.nova-pro-v1.yaml index dfb3e5f21059e3..20e9d233e1fa7c 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/us.amazon.nova-pro-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.amazon.nova-pro-v1.yaml @@ -6,6 +6,7 @@ features: - agent-thought - tool-call - stream-tool-call + - vision model_properties: mode: chat context_size: 300000 diff --git a/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py b/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py index 397f65e8c960c8..9da23ba1b0f08f 100644 --- a/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py @@ -1,8 +1,5 @@ from typing import Optional -import boto3 -from botocore.config import Config - from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, @@ -14,6 +11,7 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.rerank_model import RerankModel +from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client class BedrockRerankModel(RerankModel): @@ -48,13 +46,7 @@ def _invoke( return RerankResult(model=model, docs=docs) # initialize client - client_config = Config(region_name=credentials["aws_region"]) - bedrock_runtime = boto3.client( - service_name="bedrock-agent-runtime", - config=client_config, - aws_access_key_id=credentials.get("aws_access_key_id", ""), - aws_secret_access_key=credentials.get("aws_secret_access_key"), - ) + bedrock_runtime = get_bedrock_client("bedrock-agent-runtime", credentials) queries = [{"type": "TEXT", "textQuery": {"text": query}}] text_sources = [] for text in docs: @@ -70,7 +62,10 @@ def _invoke( } ) modelId = model - region = credentials["aws_region"] + region = credentials.get("aws_region") + # region is a required field + if not region: + raise InvokeBadRequestError("aws_region is required in credentials") model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}" rerankingConfiguration = { "type": "BEDROCK_RERANKING_MODEL", diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index 2f998d8bdaee90..5505797f7658b3 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -3,8 +3,6 @@ import time from typing import Optional -import boto3 -from botocore.config import Config from botocore.exceptions import ( ClientError, EndpointConnectionError, @@ -25,6 +23,7 @@ InvokeServerUnavailableError, ) from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client logger = logging.getLogger(__name__) @@ -48,14 +47,7 @@ def _invoke( :param input_type: input type :return: embeddings result """ - client_config = Config(region_name=credentials["aws_region"]) - - bedrock_runtime = boto3.client( - service_name="bedrock-runtime", - config=client_config, - aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key"), - ) + bedrock_runtime = get_bedrock_client("bedrock-runtime", credentials) embeddings = [] token_usage = 0 diff --git a/api/core/model_runtime/model_providers/cohere/rerank/_position.yaml b/api/core/model_runtime/model_providers/cohere/rerank/_position.yaml index 4dd58fc1708b0a..2f6f80ab74ad2f 100644 --- a/api/core/model_runtime/model_providers/cohere/rerank/_position.yaml +++ b/api/core/model_runtime/model_providers/cohere/rerank/_position.yaml @@ -2,3 +2,4 @@ - rerank-english-v3.0 - rerank-multilingual-v2.0 - rerank-multilingual-v3.0 +- rerank-v3.5 diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank-v3.5.yaml b/api/core/model_runtime/model_providers/cohere/rerank/rerank-v3.5.yaml new file mode 100644 index 00000000000000..5de71fb9892dce --- /dev/null +++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank-v3.5.yaml @@ -0,0 +1,4 @@ +model: rerank-v3.5 +model_type: rerank +model_properties: + context_size: 5120 diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py index aba8fedbc097e5..3a0a241f7ea0c0 100644 --- a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py @@ -44,7 +44,7 @@ def _invoke( :return: rerank result """ if len(docs) == 0: - return RerankResult(model=model, docs=docs) + return RerankResult(model=model, docs=[]) # initialize client client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) @@ -62,7 +62,7 @@ def _invoke( # format document rerank_document = RerankDocument( index=result.index, - text=result.document.text, + text=result.document.text if result.document else "", score=result.relevance_score, ) diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index 5fd4d637be7643..9e4df2706080f9 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -88,7 +88,10 @@ def _invoke( average = embeddings_batch[0] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embedding = (average / np.linalg.norm(average)).tolist() + if np.isnan(embedding).any(): + raise ValueError("Normalized embedding is nan please try again") + embeddings[i] = embedding # calc usage usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) diff --git a/api/core/model_runtime/model_providers/deepseek/llm/llm.py b/api/core/model_runtime/model_providers/deepseek/llm/llm.py index 610dc7b4589e9d..0a81f0c094a637 100644 --- a/api/core/model_runtime/model_providers/deepseek/llm/llm.py +++ b/api/core/model_runtime/model_providers/deepseek/llm/llm.py @@ -24,6 +24,9 @@ def _invoke( user: Optional[str] = None, ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) + # {"response_format": "xx"} need convert to {"response_format": {"type": "xx"}} + if "response_format" in model_parameters: + model_parameters["response_format"] = {"type": model_parameters.get("response_format")} return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) def validate_credentials(self, model: str, credentials: dict) -> None: diff --git a/api/core/model_runtime/model_providers/fireworks/_common.py b/api/core/model_runtime/model_providers/fireworks/_common.py index 378ced3a4019ba..38d0a9dfbcadee 100644 --- a/api/core/model_runtime/model_providers/fireworks/_common.py +++ b/api/core/model_runtime/model_providers/fireworks/_common.py @@ -1,5 +1,3 @@ -from collections.abc import Mapping - import openai from core.model_runtime.errors.invoke import ( @@ -13,7 +11,7 @@ class _CommonFireworks: - def _to_credential_kwargs(self, credentials: Mapping) -> dict: + def _to_credential_kwargs(self, credentials: dict) -> dict: """ Transform credentials to kwargs for model instance diff --git a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py index c745a7e978f4be..4c036283893fcc 100644 --- a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py @@ -1,5 +1,4 @@ import time -from collections.abc import Mapping from typing import Optional, Union import numpy as np @@ -93,7 +92,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int """ return sum(self._get_num_tokens_by_gpt2(text) for text in texts) - def validate_credentials(self, model: str, credentials: Mapping) -> None: + def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials diff --git a/api/core/model_runtime/model_providers/gitee_ai/_common.py b/api/core/model_runtime/model_providers/gitee_ai/_common.py index 0750f3b75d0542..ad6600faf7bc15 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/_common.py +++ b/api/core/model_runtime/model_providers/gitee_ai/_common.py @@ -1,4 +1,4 @@ -from dashscope.common.error import ( +from dashscope.common.error import ( # type: ignore AuthenticationError, InvalidParameter, RequestFailure, diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/InternVL2-8B.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/InternVL2-8B.yaml new file mode 100644 index 00000000000000..d288c3dd3948b2 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/InternVL2-8B.yaml @@ -0,0 +1,93 @@ +model: InternVL2-8B +label: + en_US: InternVL2-8B +model_type: llm +features: + - vision + - agent-thought +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/InternVL2.5-26B.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/InternVL2.5-26B.yaml new file mode 100644 index 00000000000000..b2dee88c0285e7 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/InternVL2.5-26B.yaml @@ -0,0 +1,93 @@ +model: InternVL2.5-26B +label: + en_US: InternVL2.5-26B +model_type: llm +features: + - vision + - agent-thought +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/_position.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/_position.yaml index 13c31ad02b2bc4..c942cda3b26640 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/_position.yaml @@ -6,3 +6,5 @@ - deepseek-coder-33B-instruct-chat - deepseek-coder-33B-instruct-completions - codegeex4-all-9b +- InternVL2.5-26B +- InternVL2-8B diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/llm.py b/api/core/model_runtime/model_providers/gitee_ai/llm/llm.py index 0c253a4a0ab036..68aaad2e3f1675 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/llm.py @@ -29,18 +29,26 @@ def _invoke( user: Optional[str] = None, ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials, model, model_parameters) - return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._invoke( + GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model), + credentials, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, + ) def validate_credentials(self, model: str, credentials: dict) -> None: - self._add_custom_parameters(credentials, None) - super().validate_credentials(model, credentials) + self._add_custom_parameters(credentials, model, None) + super().validate_credentials(GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model), credentials) - def _add_custom_parameters(self, credentials: dict, model: Optional[str]) -> None: + def _add_custom_parameters(self, credentials: dict, model: Optional[str], model_parameters: dict) -> None: if model is None: model = "Qwen2-72B-Instruct" - model_identity = GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model) - credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model_identity}/" + credentials["endpoint_url"] = "https://ai.gitee.com/v1" if model.endswith("completions"): credentials["mode"] = LLMMode.COMPLETION.value else: diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py index 832ba927406c4c..737d3d5c931221 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional import httpx @@ -51,7 +51,7 @@ def _invoke( base_url = base_url.removesuffix("/") try: - body = {"model": model, "query": query, "documents": docs} + body: dict[str, Any] = {"model": model, "query": query, "documents": docs} if top_n is not None: body["top_n"] = top_n response = httpx.post( diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py index b833c5652c650a..a1fa89c5b34af6 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py @@ -24,7 +24,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: super().validate_credentials(model, credentials) @staticmethod - def _add_custom_parameters(credentials: dict, model: str) -> None: + def _add_custom_parameters(credentials: dict, model: Optional[str]) -> None: if model is None: model = "bge-m3" diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py index 36dcea405d0974..dc91257daf9d4e 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional import requests @@ -13,9 +13,10 @@ class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel): Model class for OpenAI text2speech model. """ + # FIXME this Any return will be better type def _invoke( self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None - ) -> any: + ) -> Any: """ _invoke text2speech model @@ -47,7 +48,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None: except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: + # FIXME this Any return will be better type + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any: """ _tts_invoke_streaming text2speech model :param model: model name diff --git a/api/core/model_runtime/model_providers/google/llm/_position.yaml b/api/core/model_runtime/model_providers/google/llm/_position.yaml index ab3081db38fc37..4ad0670e119af6 100644 --- a/api/core/model_runtime/model_providers/google/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/google/llm/_position.yaml @@ -1,3 +1,5 @@ +- gemini-2.0-flash-exp +- gemini-2.0-flash-thinking-exp-1219 - gemini-1.5-pro - gemini-1.5-pro-latest - gemini-1.5-pro-001 @@ -11,6 +13,8 @@ - gemini-1.5-flash-exp-0827 - gemini-1.5-flash-8b-exp-0827 - gemini-1.5-flash-8b-exp-0924 +- gemini-exp-1206 +- gemini-exp-1121 - gemini-exp-1114 - gemini-pro - gemini-pro-vision diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml index 43f4e4787d2e07..86bba2154a527c 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 1048576 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml index 7b9add6af16ebd..9ad57a19339515 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 1048576 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml index d6de82012ef2d9..72205f15a8760f 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 1048576 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml index 23b8d318fc14bc..1193e60669e2e2 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 1048576 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml index 9762706cd7666c..7eba1f3d4de1b8 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 1048576 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml index b9739d068e9907..b8c50241581670 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 1048576 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml index d8ab4efc918ad5..ea0c42dda88457 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 1048576 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml index 05184823e4ca27..16df30857c6761 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 2097152 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml index 548fe6ddb22d80..717d9481b91953 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 2097152 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml index defab26acf4d8d..bf9704f0d54879 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 2097152 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml index 9cbc889f1776aa..714ff35f3443f3 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 2097152 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml index e5aefcdb990aa7..bbca2ba3852869 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 2097152 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml index 00bd3e8d99db50..ae127fb4e2dea0 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 2097152 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-2.0-flash-exp.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-2.0-flash-exp.yaml new file mode 100644 index 00000000000000..966617e9020d57 --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-2.0-flash-exp.yaml @@ -0,0 +1,41 @@ +model: gemini-2.0-flash-exp +label: + en_US: Gemini 2.0 Flash Exp +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call + - document + - video + - audio +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: max_output_tokens + use_template: max_tokens + default: 8192 + min: 1 + max: 8192 + - name: json_schema + use_template: json_schema +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-2.0-flash-thinking-exp-1219.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-2.0-flash-thinking-exp-1219.yaml new file mode 100644 index 00000000000000..dfcf8fd050ef06 --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-2.0-flash-thinking-exp-1219.yaml @@ -0,0 +1,39 @@ +model: gemini-2.0-flash-thinking-exp-1219 +label: + en_US: Gemini 2.0 Flash Thinking Exp 1219 +model_type: llm +features: + - agent-thought + - vision + - document + - video + - audio +model_properties: + mode: chat + context_size: 32767 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: max_output_tokens + use_template: max_tokens + default: 8192 + min: 1 + max: 8192 + - name: json_schema + use_template: json_schema +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml index 0515e706c2c79a..bd49b476938eee 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml @@ -8,6 +8,8 @@ features: - tool-call - stream-tool-call - document + - video + - audio model_properties: mode: chat context_size: 32767 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-exp-1121.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-exp-1121.yaml index 9ca4f6e6756348..8e3f218df41971 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-exp-1121.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-exp-1121.yaml @@ -7,6 +7,9 @@ features: - vision - tool-call - stream-tool-call + - document + - video + - audio model_properties: mode: chat context_size: 32767 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-exp-1206.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-exp-1206.yaml new file mode 100644 index 00000000000000..7a7c361c43e18f --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-exp-1206.yaml @@ -0,0 +1,41 @@ +model: gemini-exp-1206 +label: + en_US: Gemini exp 1206 +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call + - document + - video + - audio +model_properties: + mode: chat + context_size: 2097152 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: max_output_tokens + use_template: max_tokens + default: 8192 + min: 1 + max: 8192 + - name: json_schema + use_template: json_schema +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/learnlm-1.5-pro-experimental.yaml b/api/core/model_runtime/model_providers/google/llm/learnlm-1.5-pro-experimental.yaml index 0b29814289d8b5..f6d90d52ec8820 100644 --- a/api/core/model_runtime/model_providers/google/llm/learnlm-1.5-pro-experimental.yaml +++ b/api/core/model_runtime/model_providers/google/llm/learnlm-1.5-pro-experimental.yaml @@ -7,6 +7,9 @@ features: - vision - tool-call - stream-tool-call + - document + - video + - audio model_properties: mode: chat context_size: 32767 diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index c19e860d2e4b4b..98273f60a41190 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -1,27 +1,27 @@ import base64 -import io import json +import os +import tempfile +import time from collections.abc import Generator -from typing import Optional, Union, cast +from typing import Optional, Union import google.ai.generativelanguage as glm -import google.generativeai as genai +import google.generativeai as genai # type: ignore import requests from google.api_core import exceptions -from google.generativeai.client import _ClientManager -from google.generativeai.types import ContentType, GenerateContentResponse +from google.generativeai.types import ContentType, File, GenerateContentResponse from google.generativeai.types.content_types import to_part -from PIL import Image from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, - DocumentPromptMessageContent, - ImagePromptMessageContent, PromptMessage, + PromptMessageContent, PromptMessageContentType, PromptMessageTool, SystemPromptMessage, + TextPromptMessageContent, ToolPromptMessage, UserPromptMessage, ) @@ -35,21 +35,7 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel - -GOOGLE_AVAILABLE_MIMETYPE = [ - "application/pdf", - "application/x-javascript", - "text/javascript", - "application/x-python", - "text/x-python", - "text/plain", - "text/html", - "text/css", - "text/md", - "text/csv", - "text/xml", - "text/rtf", -] +from extensions.ext_redis import redis_client class GoogleLargeLanguageModel(LargeLanguageModel): @@ -158,7 +144,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: """ try: - ping_message = SystemPromptMessage(content="ping") + ping_message = UserPromptMessage(content="ping") self._generate(model, credentials, [ping_message], {"max_output_tokens": 5}) except Exception as ex: @@ -201,30 +187,24 @@ def _generate( if stop: config_kwargs["stop_sequences"] = stop - google_model = genai.GenerativeModel(model_name=model) + genai.configure(api_key=credentials["google_api_key"]) history = [] + system_instruction = None + + for msg in prompt_messages: # makes message roles strictly alternating + content = self._format_message_to_glm_content(msg) + if history and history[-1]["role"] == content["role"]: + history[-1]["parts"].extend(content["parts"]) + elif content["role"] == "system": + system_instruction = content["parts"][0] + else: + history.append(content) - # hack for gemini-pro-vision, which currently does not support multi-turn chat - if model == "gemini-pro-vision": - last_msg = prompt_messages[-1] - content = self._format_message_to_glm_content(last_msg) - history.append(content) - else: - for msg in prompt_messages: # makes message roles strictly alternating - content = self._format_message_to_glm_content(msg) - if history and history[-1]["role"] == content["role"]: - history[-1]["parts"].extend(content["parts"]) - else: - history.append(content) - - # Create a new ClientManager with tenant's API key - new_client_manager = _ClientManager() - new_client_manager.configure(api_key=credentials["google_api_key"]) - new_custom_client = new_client_manager.make_client("generative") - - google_model._client = new_custom_client + if not history: + raise InvokeError("The user prompt message is required. You only add a system prompt message.") + google_model = genai.GenerativeModel(model_name=model, system_instruction=system_instruction) response = google_model.generate_content( contents=history, generation_config=genai.types.GenerationConfig(**config_kwargs), @@ -317,8 +297,12 @@ def _handle_generate_stream_response( ) else: # calculate num tokens - prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) + if hasattr(response, "usage_metadata") and response.usage_metadata: + prompt_tokens = response.usage_metadata.prompt_token_count + completion_tokens = response.usage_metadata.candidates_token_count + else: + prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) + completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) @@ -346,7 +330,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: content = message.content if isinstance(content, list): - content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE) + content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT) if isinstance(message, UserPromptMessage): message_text = f"{human_prompt} {content}" @@ -359,6 +343,40 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: return message_text + def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File: + key = f"{message_content.type.value}:{hash(message_content.data)}" + if redis_client.exists(key): + try: + return genai.get_file(redis_client.get(key).decode()) + except: + pass + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + if message_content.base64_data: + file_content = base64.b64decode(message_content.base64_data) + temp_file.write(file_content) + else: + try: + response = requests.get(message_content.url) + response.raise_for_status() + temp_file.write(response.content) + except Exception as ex: + raise ValueError(f"Failed to fetch data from url {message_content.url}, {ex}") + temp_file.flush() + + file = genai.upload_file(path=temp_file.name, mime_type=message_content.mime_type) + while file.state.name == "PROCESSING": + time.sleep(5) + file = genai.get_file(file.name) + # google will delete your upload files in 2 days. + redis_client.setex(key, 47 * 60 * 60, file.name) + + try: + os.unlink(temp_file.name) + except PermissionError: + # windows may raise permission error + pass + return file + def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType: """ Format a single message into glm.Content for Google API @@ -374,28 +392,8 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType: for c in message.content: if c.type == PromptMessageContentType.TEXT: glm_content["parts"].append(to_part(c.data)) - elif c.type == PromptMessageContentType.IMAGE: - message_content = cast(ImagePromptMessageContent, c) - if message_content.data.startswith("data:"): - metadata, base64_data = c.data.split(",", 1) - mime_type = metadata.split(";", 1)[0].split(":")[1] - else: - # fetch image data from url - try: - image_content = requests.get(message_content.data).content - with Image.open(io.BytesIO(image_content)) as img: - mime_type = f"image/{img.format.lower()}" - base64_data = base64.b64encode(image_content).decode("utf-8") - except Exception as ex: - raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") - blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}} - glm_content["parts"].append(blob) - elif c.type == PromptMessageContentType.DOCUMENT: - message_content = cast(DocumentPromptMessageContent, c) - if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE: - raise ValueError(f"Unsupported mime type {message_content.mime_type}") - blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}} - glm_content["parts"].append(blob) + else: + glm_content["parts"].append(self._upload_file_content_to_google(c)) return glm_content elif isinstance(message, AssistantPromptMessage): @@ -413,7 +411,10 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType: ) return glm_content elif isinstance(message, SystemPromptMessage): - return {"role": "user", "parts": [to_part(message.content)]} + if isinstance(message.content, list): + text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content) + message.content = "".join(c.data for c in text_contents) + return {"role": "system", "parts": [to_part(message.content)]} elif isinstance(message, ToolPromptMessage): return { "role": "function", diff --git a/api/core/model_runtime/model_providers/groq/llm/_position.yaml b/api/core/model_runtime/model_providers/groq/llm/_position.yaml index 0613b19f87ee5e..279c1bcbe5ae92 100644 --- a/api/core/model_runtime/model_providers/groq/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/_position.yaml @@ -1,4 +1,5 @@ - llama-3.1-405b-reasoning +- llama-3.3-70b-versatile - llama-3.1-70b-versatile - llama-3.1-8b-instant - llama3-70b-8192 diff --git a/api/core/model_runtime/model_providers/groq/llm/gemma-7b-it.yaml b/api/core/model_runtime/model_providers/groq/llm/gemma-7b-it.yaml new file mode 100644 index 00000000000000..02f84e95f6e348 --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/gemma-7b-it.yaml @@ -0,0 +1,25 @@ +model: gemma-7b-it +label: + zh_Hans: Gemma 7B Instruction Tuned + en_US: Gemma 7B Instruction Tuned +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.05' + output: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/gemma2-9b-it.yaml b/api/core/model_runtime/model_providers/groq/llm/gemma2-9b-it.yaml new file mode 100644 index 00000000000000..dad496f668ab94 --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/gemma2-9b-it.yaml @@ -0,0 +1,25 @@ +model: gemma2-9b-it +label: + zh_Hans: Gemma 2 9B Instruction Tuned + en_US: Gemma 2 9B Instruction Tuned +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.05' + output: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.1-70b-versatile.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.1-70b-versatile.yaml index ab5f6ab05efe31..01323a1b8a74f4 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama-3.1-70b-versatile.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.1-70b-versatile.yaml @@ -1,7 +1,8 @@ model: llama-3.1-70b-versatile +deprecated: true label: - zh_Hans: Llama-3.1-70b-versatile - en_US: Llama-3.1-70b-versatile + zh_Hans: Llama-3.1-70b-versatile (DEPRECATED) + en_US: Llama-3.1-70b-versatile (DEPRECATED) model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-text-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-text-preview.yaml index 019d45372361d3..3f30d81ae4e26c 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-text-preview.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-text-preview.yaml @@ -1,4 +1,5 @@ model: llama-3.2-11b-text-preview +deprecated: true label: zh_Hans: Llama 3.2 11B Text (Preview) en_US: Llama 3.2 11B Text (Preview) diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-text-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-text-preview.yaml index 3b34e7c07996bd..0391a7c890cec4 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-text-preview.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-text-preview.yaml @@ -1,4 +1,5 @@ model: llama-3.2-90b-text-preview +depraceted: true label: zh_Hans: Llama 3.2 90B Text (Preview) en_US: Llama 3.2 90B Text (Preview) diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.3-70b-specdec.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.3-70b-specdec.yaml new file mode 100644 index 00000000000000..bda9ec530a65c8 --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.3-70b-specdec.yaml @@ -0,0 +1,25 @@ +model: llama-3.3-70b-specdec +label: + zh_Hans: Llama 3.3 70B Specdec + en_US: Llama 3.3 70B Specdec +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 32768 +pricing: + input: "0.05" + output: "0.1" + unit: "0.000001" + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.3-70b-versatile.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.3-70b-versatile.yaml new file mode 100644 index 00000000000000..eb609f4db79df1 --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.3-70b-versatile.yaml @@ -0,0 +1,25 @@ +model: llama-3.3-70b-versatile +label: + zh_Hans: Llama 3.3 70B Versatile + en_US: Llama 3.3 70B Versatile +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 32768 +pricing: + input: "0.05" + output: "0.1" + unit: "0.000001" + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/llama3-groq-70b-8192-tool-use-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama3-groq-70b-8192-tool-use-preview.yaml new file mode 100644 index 00000000000000..32ccbf1f4db29b --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/llama3-groq-70b-8192-tool-use-preview.yaml @@ -0,0 +1,25 @@ +model: llama3-groq-70b-8192-tool-use-preview +label: + zh_Hans: Llama3-groq-70b-8192-tool-use (PREVIEW) + en_US: Llama3-groq-70b-8192-tool-use (PREVIEW) +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.05' + output: '0.08' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/huggingface_hub/_common.py b/api/core/model_runtime/model_providers/huggingface_hub/_common.py index 3c4020b6eedf24..d8a09265e21059 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/_common.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/_common.py @@ -1,4 +1,4 @@ -from huggingface_hub.utils import BadRequestError, HfHubHTTPError +from huggingface_hub.utils import BadRequestError, HfHubHTTPError # type: ignore from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index 9d29237fdde573..cdb4103cd83712 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -1,9 +1,9 @@ from collections.abc import Generator from typing import Optional, Union -from huggingface_hub import InferenceClient -from huggingface_hub.hf_api import HfApi -from huggingface_hub.utils import BadRequestError +from huggingface_hub import InferenceClient # type: ignore +from huggingface_hub.hf_api import HfApi # type: ignore +from huggingface_hub.utils import BadRequestError # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE diff --git a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py index 8278d1e64def89..4ca5379405f4e6 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ import numpy as np import requests -from huggingface_hub import HfApi, InferenceClient +from huggingface_hub import HfApi, InferenceClient # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py index 284429b7417e32..a8a13313db946d 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py @@ -157,7 +157,6 @@ def validate_credentials(self, model: str, credentials: dict) -> None: headers["Authorization"] = f"Bearer {api_key}" extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers) - print(extra_args) if extra_args.model_type != "embedding": raise CredentialsValidateFailedError("Current model is not a embedding model") diff --git a/api/core/model_runtime/model_providers/hunyuan/hunyuan.yaml b/api/core/model_runtime/model_providers/hunyuan/hunyuan.yaml index 812b51ddcd176a..e0d95a830c6c37 100644 --- a/api/core/model_runtime/model_providers/hunyuan/hunyuan.yaml +++ b/api/core/model_runtime/model_providers/hunyuan/hunyuan.yaml @@ -3,8 +3,8 @@ label: zh_Hans: 腾讯混元 en_US: Hunyuan description: - en_US: Models provided by Tencent Hunyuan, such as hunyuan-standard, hunyuan-standard-256k, hunyuan-pro and hunyuan-lite. - zh_Hans: 腾讯混元提供的模型,例如 hunyuan-standard、 hunyuan-standard-256k, hunyuan-pro 和 hunyuan-lite。 + en_US: Models provided by Tencent Hunyuan, such as hunyuan-standard, hunyuan-standard-256k, hunyuan-pro, hunyuan-role, hunyuan-large, hunyuan-large-role, hunyuan-turbo-latest, hunyuan-large-longcontext, hunyuan-turbo, hunyuan-vision, hunyuan-turbo-vision, hunyuan-functioncall and hunyuan-lite. + zh_Hans: 腾讯混元提供的模型,例如 hunyuan-standard、 hunyuan-standard-256k, hunyuan-pro, hunyuan-role, hunyuan-large, hunyuan-large-role, hunyuan-turbo-latest, hunyuan-large-longcontext, hunyuan-turbo, hunyuan-vision, hunyuan-turbo-vision, hunyuan-functioncall 和 hunyuan-lite。 icon_small: en_US: icon_s_en.png icon_large: diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml index f494984443cb42..6f589b3094a99b 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml @@ -4,3 +4,10 @@ - hunyuan-pro - hunyuan-turbo - hunyuan-vision +- hunyuan-role +- hunyuan-large +- hunyuan-large-role +- hunyuan-large-longcontext +- hunyuan-turbo-latest +- hunyuan-turbo-vision +- hunyuan-functioncall diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-functioncall.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-functioncall.yaml new file mode 100644 index 00000000000000..eb8656917c1416 --- /dev/null +++ b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-functioncall.yaml @@ -0,0 +1,38 @@ +model: hunyuan-functioncall +label: + zh_Hans: hunyuan-functioncall + en_US: hunyuan-functioncall +model_type: llm +features: + - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 32000 + - name: enable_enhance + label: + zh_Hans: 功能增强 + en_US: Enable Enhancement + type: boolean + help: + zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。 + en_US: Allow the model to perform external search to enhance the generation results. + required: false + default: true +pricing: + input: '0.004' + output: '0.008' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-large-longcontext.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-large-longcontext.yaml new file mode 100644 index 00000000000000..c39724a3a9eceb --- /dev/null +++ b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-large-longcontext.yaml @@ -0,0 +1,38 @@ +model: hunyuan-large-longcontext +label: + zh_Hans: hunyuan-large-longcontext + en_US: hunyuan-large-longcontext +model_type: llm +features: + - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 134000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 134000 + - name: enable_enhance + label: + zh_Hans: 功能增强 + en_US: Enable Enhancement + type: boolean + help: + zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。 + en_US: Allow the model to perform external search to enhance the generation results. + required: false + default: true +pricing: + input: '0.006' + output: '0.018' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-large-role.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-large-role.yaml new file mode 100644 index 00000000000000..1b40b35ed5d6d1 --- /dev/null +++ b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-large-role.yaml @@ -0,0 +1,38 @@ +model: hunyuan-large-role +label: + zh_Hans: hunyuan-large-role + en_US: hunyuan-large-role +model_type: llm +features: + - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 32000 + - name: enable_enhance + label: + zh_Hans: 功能增强 + en_US: Enable Enhancement + type: boolean + help: + zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。 + en_US: Allow the model to perform external search to enhance the generation results. + required: false + default: true +pricing: + input: '0.004' + output: '0.008' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-large.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-large.yaml new file mode 100644 index 00000000000000..87dc104e116e8d --- /dev/null +++ b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-large.yaml @@ -0,0 +1,38 @@ +model: hunyuan-large +label: + zh_Hans: hunyuan-large + en_US: hunyuan-large +model_type: llm +features: + - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 32000 + - name: enable_enhance + label: + zh_Hans: 功能增强 + en_US: Enable Enhancement + type: boolean + help: + zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。 + en_US: Allow the model to perform external search to enhance the generation results. + required: false + default: true +pricing: + input: '0.004' + output: '0.012' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-role.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-role.yaml new file mode 100644 index 00000000000000..0f6d2c5c440fda --- /dev/null +++ b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-role.yaml @@ -0,0 +1,38 @@ +model: hunyuan-role +label: + zh_Hans: hunyuan-role + en_US: hunyuan-role +model_type: llm +features: + - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 32000 + - name: enable_enhance + label: + zh_Hans: 功能增强 + en_US: Enable Enhancement + type: boolean + help: + zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。 + en_US: Allow the model to perform external search to enhance the generation results. + required: false + default: true +pricing: + input: '0.004' + output: '0.008' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo-latest.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo-latest.yaml new file mode 100644 index 00000000000000..adfa3a4c1b8733 --- /dev/null +++ b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo-latest.yaml @@ -0,0 +1,38 @@ +model: hunyuan-turbo-latest +label: + zh_Hans: hunyuan-turbo-latest + en_US: hunyuan-turbo-latest +model_type: llm +features: + - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 32000 + - name: enable_enhance + label: + zh_Hans: 功能增强 + en_US: Enable Enhancement + type: boolean + help: + zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。 + en_US: Allow the model to perform external search to enhance the generation results. + required: false + default: true +pricing: + input: '0.015' + output: '0.05' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo-vision.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo-vision.yaml new file mode 100644 index 00000000000000..5b9b17cc506a1d --- /dev/null +++ b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo-vision.yaml @@ -0,0 +1,39 @@ +model: hunyuan-turbo-vision +label: + zh_Hans: hunyuan-turbo-vision + en_US: hunyuan-turbo-vision +model_type: llm +features: + - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 8000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 8000 + - name: enable_enhance + label: + zh_Hans: 功能增强 + en_US: Enable Enhancement + type: boolean + help: + zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。 + en_US: Allow the model to perform external search to enhance the generation results. + required: false + default: true +pricing: + input: '0.08' + output: '0.08' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py index 2014de8516bc11..2dd45f065d5e26 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py @@ -3,11 +3,11 @@ from collections.abc import Generator from typing import cast -from tencentcloud.common import credential -from tencentcloud.common.exception import TencentCloudSDKException -from tencentcloud.common.profile.client_profile import ClientProfile -from tencentcloud.common.profile.http_profile import HttpProfile -from tencentcloud.hunyuan.v20230901 import hunyuan_client, models +from tencentcloud.common import credential # type: ignore +from tencentcloud.common.exception import TencentCloudSDKException # type: ignore +from tencentcloud.common.profile.client_profile import ClientProfile # type: ignore +from tencentcloud.common.profile.http_profile import HttpProfile # type: ignore +from tencentcloud.hunyuan.v20230901 import hunyuan_client, models # type: ignore from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -305,7 +305,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: elif isinstance(message, ToolPromptMessage): message_text = f"{tool_prompt} {content}" elif isinstance(message, SystemPromptMessage): - message_text = content + message_text = content if isinstance(content, str) else "" else: raise ValueError(f"Got unknown type {message}") diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py index b6d857cb37cba0..856cda90d35a22 100644 --- a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py @@ -3,11 +3,11 @@ import time from typing import Optional -from tencentcloud.common import credential -from tencentcloud.common.exception import TencentCloudSDKException -from tencentcloud.common.profile.client_profile import ClientProfile -from tencentcloud.common.profile.http_profile import HttpProfile -from tencentcloud.hunyuan.v20230901 import hunyuan_client, models +from tencentcloud.common import credential # type: ignore +from tencentcloud.common.exception import TencentCloudSDKException # type: ignore +from tencentcloud.common.profile.client_profile import ClientProfile # type: ignore +from tencentcloud.common.profile.http_profile import HttpProfile # type: ignore +from tencentcloud.hunyuan.v20230901 import hunyuan_client, models # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py index d80cbfa83d6425..1fc0f8c028ba92 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py @@ -1,11 +1,11 @@ from os.path import abspath, dirname, join from threading import Lock -from transformers import AutoTokenizer +from transformers import AutoTokenizer # type: ignore class JinaTokenizer: - _tokenizer = None + _tokenizer: AutoTokenizer | None = None _lock = Lock() @classmethod diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py index 88cc0e8e0f32d0..357631b2dba0b9 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -40,7 +40,7 @@ def generate( url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}" - extra_kwargs = {} + extra_kwargs: dict[str, Any] = {} if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] @@ -117,19 +117,19 @@ def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ handle chat generate response """ - response = response.json() - if "base_resp" in response and response["base_resp"]["status_code"] != 0: - code = response["base_resp"]["status_code"] - msg = response["base_resp"]["status_msg"] + response_data = response.json() + if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0: + code = response_data["base_resp"]["status_code"] + msg = response_data["base_resp"]["status_msg"] self._handle_error(code, msg) - message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) + message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { "prompt_tokens": 0, - "completion_tokens": response["usage"]["total_tokens"], - "total_tokens": response["usage"]["total_tokens"], + "completion_tokens": response_data["usage"]["total_tokens"], + "total_tokens": response_data["usage"]["total_tokens"], } - message.stop_reason = response["choices"][0]["finish_reason"] + message.stop_reason = response_data["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: @@ -139,10 +139,10 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator for line in response.iter_lines(): if not line: continue - line: str = line.decode("utf-8") - if line.startswith("data: "): - line = line[6:].strip() - data = loads(line) + line_str: str = line.decode("utf-8") + if line_str.startswith("data: "): + line_str = line_str[6:].strip() + data = loads(line_str) if "base_resp" in data and data["base_resp"]["status_code"] != 0: code = data["base_resp"]["status_code"] @@ -162,5 +162,5 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator continue for choice in choices: - message = choice["delta"] - yield MinimaxMessage(content=message, role=MinimaxMessage.Role.ASSISTANT.value) + message_choice = choice["delta"] + yield MinimaxMessage(content=message_choice, role=MinimaxMessage.Role.ASSISTANT.value) diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 8b8fdbb6bdf558..284b61829f9729 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -41,7 +41,7 @@ def generate( url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}" - extra_kwargs = {} + extra_kwargs: dict[str, Any] = {} if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] @@ -122,19 +122,19 @@ def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ handle chat generate response """ - response = response.json() - if "base_resp" in response and response["base_resp"]["status_code"] != 0: - code = response["base_resp"]["status_code"] - msg = response["base_resp"]["status_msg"] + response_data = response.json() + if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0: + code = response_data["base_resp"]["status_code"] + msg = response_data["base_resp"]["status_msg"] self._handle_error(code, msg) - message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) + message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { "prompt_tokens": 0, - "completion_tokens": response["usage"]["total_tokens"], - "total_tokens": response["usage"]["total_tokens"], + "completion_tokens": response_data["usage"]["total_tokens"], + "total_tokens": response_data["usage"]["total_tokens"], } - message.stop_reason = response["choices"][0]["finish_reason"] + message.stop_reason = response_data["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: @@ -144,10 +144,10 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator for line in response.iter_lines(): if not line: continue - line: str = line.decode("utf-8") - if line.startswith("data: "): - line = line[6:].strip() - data = loads(line) + line_str: str = line.decode("utf-8") + if line_str.startswith("data: "): + line_str = line_str[6:].strip() + data = loads(line_str) if "base_resp" in data and data["base_resp"]["status_code"] != 0: code = data["base_resp"]["status_code"] diff --git a/api/core/model_runtime/model_providers/minimax/llm/llm.py b/api/core/model_runtime/model_providers/minimax/llm/llm.py index ce7c00f7dab84a..ca9b243c9262b3 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/llm.py +++ b/api/core/model_runtime/model_providers/minimax/llm/llm.py @@ -35,6 +35,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): model_apis = { "abab7-chat-preview": MinimaxChatCompletionPro, + "abab6.5t-chat": MinimaxChatCompletionPro, "abab6.5s-chat": MinimaxChatCompletionPro, "abab6.5-chat": MinimaxChatCompletionPro, "abab6-chat": MinimaxChatCompletionPro, diff --git a/api/core/model_runtime/model_providers/minimax/llm/types.py b/api/core/model_runtime/model_providers/minimax/llm/types.py index 88ebe5e2e00e7a..c248db374a2504 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/types.py +++ b/api/core/model_runtime/model_providers/minimax/llm/types.py @@ -11,9 +11,9 @@ class Role(Enum): role: str = Role.USER.value content: str - usage: dict[str, int] = None + usage: dict[str, int] | None = None stop_reason: str = "" - function_call: dict[str, Any] = None + function_call: dict[str, Any] | None = None def to_dict(self) -> dict[str, Any]: if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value: diff --git a/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml b/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml index bdb06b7fff6376..5702797ac447c8 100644 --- a/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml @@ -1,3 +1,5 @@ +- pixtral-large-latest +- pixtral-large-2411 - pixtral-12b-2409 - codestral-latest - mistral-embed diff --git a/api/core/model_runtime/model_providers/mistralai/llm/pixtral-12b-2409.yaml b/api/core/model_runtime/model_providers/mistralai/llm/pixtral-12b-2409.yaml index 0b002b49cac8e0..9eb663bc31350b 100644 --- a/api/core/model_runtime/model_providers/mistralai/llm/pixtral-12b-2409.yaml +++ b/api/core/model_runtime/model_providers/mistralai/llm/pixtral-12b-2409.yaml @@ -5,6 +5,7 @@ label: model_type: llm features: - agent-thought + - vision model_properties: mode: chat context_size: 128000 @@ -21,7 +22,7 @@ parameter_rules: max: 1 - name: max_tokens use_template: max_tokens - default: 1024 + default: 8192 min: 1 max: 8192 - name: safe_prompt diff --git a/api/core/model_runtime/model_providers/mistralai/llm/pixtral-large-2411.yaml b/api/core/model_runtime/model_providers/mistralai/llm/pixtral-large-2411.yaml new file mode 100644 index 00000000000000..606c9aa3319a8b --- /dev/null +++ b/api/core/model_runtime/model_providers/mistralai/llm/pixtral-large-2411.yaml @@ -0,0 +1,52 @@ +model: pixtral-large-2411 +label: + zh_Hans: pixtral-large-2411 + en_US: pixtral-large-2411 +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.7 + min: 0 + max: 1 + - name: top_p + use_template: top_p + default: 1 + min: 0 + max: 1 + - name: max_tokens + use_template: max_tokens + default: 8192 + min: 1 + max: 8192 + - name: safe_prompt + default: false + type: boolean + help: + en_US: Whether to inject a safety prompt before all conversations. + zh_Hans: 是否开启提示词审查 + label: + en_US: SafePrompt + zh_Hans: 提示词审查 + - name: random_seed + type: int + help: + en_US: The seed to use for random sampling. If set, different calls will generate deterministic results. + zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定 + label: + en_US: RandomSeed + zh_Hans: 随机数种子 + default: 0 + min: 0 + max: 2147483647 +pricing: + input: '0.008' + output: '0.024' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/mistralai/llm/pixtral-large-latest.yaml b/api/core/model_runtime/model_providers/mistralai/llm/pixtral-large-latest.yaml new file mode 100644 index 00000000000000..4f0ed5ae5d77f2 --- /dev/null +++ b/api/core/model_runtime/model_providers/mistralai/llm/pixtral-large-latest.yaml @@ -0,0 +1,52 @@ +model: pixtral-large-latest +label: + zh_Hans: pixtral-large-latest + en_US: pixtral-large-latest +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.7 + min: 0 + max: 1 + - name: top_p + use_template: top_p + default: 1 + min: 0 + max: 1 + - name: max_tokens + use_template: max_tokens + default: 8192 + min: 1 + max: 8192 + - name: safe_prompt + default: false + type: boolean + help: + en_US: Whether to inject a safety prompt before all conversations. + zh_Hans: 是否开启提示词审查 + label: + en_US: SafePrompt + zh_Hans: 提示词审查 + - name: random_seed + type: int + help: + en_US: The seed to use for random sampling. If set, different calls will generate deterministic results. + zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定 + label: + en_US: RandomSeed + zh_Hans: 随机数种子 + default: 0 + min: 0 + max: 2147483647 +pricing: + input: '0.008' + output: '0.024' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py index 56a707333c40e9..8a4c19d4d8f71b 100644 --- a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py @@ -2,8 +2,8 @@ from functools import wraps from typing import Optional -from nomic import embed -from nomic import login as nomic_login +from nomic import embed # type: ignore +from nomic import login as nomic_login # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType diff --git a/api/core/model_runtime/model_providers/oci/llm/llm.py b/api/core/model_runtime/model_providers/oci/llm/llm.py index 1e1fc5b3ea89aa..9f676573fc2ece 100644 --- a/api/core/model_runtime/model_providers/oci/llm/llm.py +++ b/api/core/model_runtime/model_providers/oci/llm/llm.py @@ -5,8 +5,8 @@ from collections.abc import Generator from typing import Optional, Union -import oci -from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse +import oci # type: ignore +from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse # type: ignore from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py index 50fa63768c241b..5a428c9fed0466 100644 --- a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional import numpy as np -import oci +import oci # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py index 094a6746454f4d..3ae728d4b36985 100644 --- a/api/core/model_runtime/model_providers/ollama/llm/llm.py +++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py @@ -181,9 +181,11 @@ def _generate( # prepare the payload for a simple ping to the model data = {"model": model, "stream": stream} - if "format" in model_parameters: - data["format"] = model_parameters["format"] - del model_parameters["format"] + if format_schema := model_parameters.pop("format", None): + try: + data["format"] = format_schema if format_schema == "json" else json.loads(format_schema) + except json.JSONDecodeError as e: + raise InvokeBadRequestError(f"Invalid format schema: {str(e)}") if "keep_alive" in model_parameters: data["keep_alive"] = model_parameters["keep_alive"] @@ -733,12 +735,12 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode ParameterRule( name="format", label=I18nObject(en_US="Format", zh_Hans="返回格式"), - type=ParameterType.STRING, + type=ParameterType.TEXT, + default="json", help=I18nObject( - en_US="the format to return a response in. Currently the only accepted value is json.", - zh_Hans="返回响应的格式。目前唯一接受的值是json。", + en_US="the format to return a response in. Format can be `json` or a JSON schema.", + zh_Hans="返回响应的格式。目前接受的值是字符串`json`或JSON schema.", ), - options=["json"], ), ], pricing=PriceConfig( diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index 83c4facc8db76c..3543fe58bb68d2 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -61,6 +61,7 @@ def _invoke( headers = {"Content-Type": "application/json"} endpoint_url = credentials.get("base_url") + assert endpoint_url is not None, "Base URL is required for Ollama API" if not endpoint_url.endswith("/"): endpoint_url += "/" diff --git a/api/core/model_runtime/model_providers/openai/_common.py b/api/core/model_runtime/model_providers/openai/_common.py index 2181bb4f08fd8f..ac2b3e6881c740 100644 --- a/api/core/model_runtime/model_providers/openai/_common.py +++ b/api/core/model_runtime/model_providers/openai/_common.py @@ -1,5 +1,3 @@ -from collections.abc import Mapping - import openai from httpx import Timeout @@ -14,7 +12,7 @@ class _CommonOpenAI: - def _to_credential_kwargs(self, credentials: Mapping) -> dict: + def _to_credential_kwargs(self, credentials: dict) -> dict: """ Transform credentials to kwargs for model instance diff --git a/api/core/model_runtime/model_providers/openai/llm/_position.yaml b/api/core/model_runtime/model_providers/openai/llm/_position.yaml index 099aae38a6566a..be279d95208690 100644 --- a/api/core/model_runtime/model_providers/openai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/_position.yaml @@ -1,4 +1,7 @@ -- gpt-4o-audio-preview +- o1 +- o1-2024-12-17 +- o1-mini +- o1-mini-2024-09-12 - gpt-4 - gpt-4o - gpt-4o-2024-05-13 @@ -7,10 +10,6 @@ - chatgpt-4o-latest - gpt-4o-mini - gpt-4o-mini-2024-07-18 -- o1-preview -- o1-preview-2024-09-12 -- o1-mini -- o1-mini-2024-09-12 - gpt-4-turbo - gpt-4-turbo-2024-04-09 - gpt-4-turbo-preview @@ -25,4 +24,7 @@ - gpt-3.5-turbo-1106 - gpt-3.5-turbo-0613 - gpt-3.5-turbo-instruct +- gpt-4o-audio-preview +- o1-preview +- o1-preview-2024-09-12 - text-davinci-003 diff --git a/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml b/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml index b47449a49abc2e..19a5399a735fa9 100644 --- a/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml @@ -22,7 +22,7 @@ parameter_rules: use_template: frequency_penalty - name: max_tokens use_template: max_tokens - default: 512 + default: 16384 min: 1 max: 16384 - name: response_format diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml index b630d6f63075c2..2c86ec9460719b 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml @@ -22,9 +22,9 @@ parameter_rules: use_template: frequency_penalty - name: max_tokens use_template: max_tokens - default: 512 + default: 16384 min: 1 - max: 4096 + max: 16384 - name: response_format label: zh_Hans: 回复格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml index 73b7f6970076c0..cabbe9871728b3 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml @@ -22,7 +22,7 @@ parameter_rules: use_template: frequency_penalty - name: max_tokens use_template: max_tokens - default: 512 + default: 16384 min: 1 max: 16384 - name: response_format diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-11-20.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-11-20.yaml index ebd5ab38c3b7df..2c7c1c6eb56d11 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-11-20.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-11-20.yaml @@ -22,7 +22,7 @@ parameter_rules: use_template: frequency_penalty - name: max_tokens use_template: max_tokens - default: 512 + default: 16384 min: 1 max: 16384 - name: response_format diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml index 6571cd094fc36b..e707acc507cb57 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml @@ -22,9 +22,9 @@ parameter_rules: use_template: frequency_penalty - name: max_tokens use_template: max_tokens - default: 512 + default: 16384 min: 1 - max: 4096 + max: 16384 - name: response_format label: zh_Hans: 回复格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml index df38270f79b1c3..0c1b74c513469c 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml @@ -22,7 +22,7 @@ parameter_rules: use_template: frequency_penalty - name: max_tokens use_template: max_tokens - default: 512 + default: 16384 min: 1 max: 16384 - name: response_format diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml index 5e3c94fbe255c0..0d52f06339ad45 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml @@ -22,7 +22,7 @@ parameter_rules: use_template: frequency_penalty - name: max_tokens use_template: max_tokens - default: 512 + default: 16384 min: 1 max: 16384 - name: response_format diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml index 3090a9e090c2c5..a4681fe18d24f7 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml @@ -22,9 +22,9 @@ parameter_rules: use_template: frequency_penalty - name: max_tokens use_template: max_tokens - default: 512 + default: 16384 min: 1 - max: 4096 + max: 16384 - name: response_format label: zh_Hans: 回复格式 @@ -38,7 +38,7 @@ parameter_rules: - text - json_object pricing: - input: '5.00' - output: '15.00' + input: '2.50' + output: '10.00' unit: '0.000001' currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 07cb1e2d1018f9..73cd7e3c341881 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -421,7 +421,11 @@ def _generate( # text completion model response = client.completions.create( - prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs + prompt=prompt_messages[0].content, + model=model, + stream=stream, + **model_parameters, + **extra_model_kwargs, ) if stream: @@ -593,6 +597,8 @@ def _chat_generate( model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema} else: model_parameters["response_format"] = {"type": response_format} + elif "json_schema" in model_parameters: + del model_parameters["json_schema"] extra_model_kwargs = {} @@ -920,10 +926,12 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: } sub_messages.append(sub_message_dict) elif isinstance(message_content, AudioPromptMessageContent): + data_split = message_content.data.split(";base64,") + base64_data = data_split[1] sub_message_dict = { "type": "input_audio", "input_audio": { - "data": message_content.data, + "data": base64_data, "format": message_content.format, }, } diff --git a/api/core/model_runtime/model_providers/openai/llm/o1-2024-12-17.yaml b/api/core/model_runtime/model_providers/openai/llm/o1-2024-12-17.yaml new file mode 100644 index 00000000000000..7acbd0e2b1166b --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/o1-2024-12-17.yaml @@ -0,0 +1,35 @@ +model: o1-2024-12-17 +label: + en_US: o1-2024-12-17 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 200000 +parameter_rules: + - name: max_tokens + use_template: max_tokens + default: 50000 + min: 1 + max: 50000 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '15.00' + output: '60.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/o1.yaml b/api/core/model_runtime/model_providers/openai/llm/o1.yaml new file mode 100644 index 00000000000000..3a84cf418ec873 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/o1.yaml @@ -0,0 +1,36 @@ +model: o1 +label: + zh_Hans: o1 + en_US: o1 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 200000 +parameter_rules: + - name: max_tokens + use_template: max_tokens + default: 50000 + min: 1 + max: 50000 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '15.00' + output: '60.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/moderation/moderation.py b/api/core/model_runtime/model_providers/openai/moderation/moderation.py index 619044d808cdf6..227e4b0c152a05 100644 --- a/api/core/model_runtime/model_providers/openai/moderation/moderation.py +++ b/api/core/model_runtime/model_providers/openai/moderation/moderation.py @@ -93,7 +93,8 @@ def _get_max_characters_per_chunk(self, model: str, credentials: dict) -> int: model_schema = self.get_model_schema(model, credentials) if model_schema and ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK] + max_characters_per_chunk: int = model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK] + return max_characters_per_chunk return 2000 @@ -108,6 +109,7 @@ def _get_max_chunks(self, model: str, credentials: dict) -> int: model_schema = self.get_model_schema(model, credentials) if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + return max_chunks return 1 diff --git a/api/core/model_runtime/model_providers/openai/openai.py b/api/core/model_runtime/model_providers/openai/openai.py index aa6f38ce9fae5a..c546441af61d9b 100644 --- a/api/core/model_runtime/model_providers/openai/openai.py +++ b/api/core/model_runtime/model_providers/openai/openai.py @@ -1,5 +1,4 @@ import logging -from collections.abc import Mapping from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -9,7 +8,7 @@ class OpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: Mapping) -> None: + def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials if validate failed, raise exception diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index bec01fe6797f52..9c8c8d5882ee4e 100644 --- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -97,7 +97,10 @@ def _invoke( average = embeddings_batch[0] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embedding = (average / np.linalg.norm(average)).tolist() + if np.isnan(embedding).any(): + raise ValueError("Normalized embedding is nan please try again") + embeddings[i] = embedding # calc usage usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 26c090d30efd1f..8e07d56f4592de 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -478,6 +478,10 @@ def get_tool_call(tool_call_id: str): usage=usage, ) break + # handle the error here. for issue #11629 + if chunk_json.get("error") and chunk_json.get("choices") is None: + raise ValueError(chunk_json.get("error")) + if chunk_json: if u := chunk_json.get("usage"): usage = u diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py index a490537e51a6ad..74229a089aa45e 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py @@ -33,6 +33,7 @@ def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional headers["Authorization"] = f"Bearer {api_key}" endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" endpoint_url = urljoin(endpoint_url, "audio/transcriptions") diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index 9da8f55d0a7ed9..b4d6c6c6ca9942 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -55,6 +55,7 @@ def _invoke( headers["Authorization"] = f"Bearer {api_key}" endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py b/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py index 8239c625f7ada8..53e895b0ecb376 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py @@ -44,6 +44,7 @@ def _invoke( # Construct endpoint URL endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" endpoint_url = urljoin(endpoint_url, "audio/speech") diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index 2789a9250a1d35..e9509b544d9f4e 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -1,7 +1,7 @@ from collections.abc import Generator from enum import Enum from json import dumps, loads -from typing import Any, Union +from typing import Any, Optional, Union from requests import Response, post from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema @@ -20,7 +20,7 @@ class Role(Enum): role: str = Role.USER.value content: str - usage: dict[str, int] = None + usage: Optional[dict[str, int]] = None stop_reason: str = "" def to_dict(self) -> dict[str, Any]: @@ -165,17 +165,17 @@ def _handle_chat_stream_generate_response( if not line: continue - line: str = line.decode("utf-8") - if line.startswith("data: "): - line = line[6:].strip() + line_str: str = line.decode("utf-8") + if line_str.startswith("data: "): + line_str = line_str[6:].strip() - if line == "[DONE]": + if line_str == "[DONE]": return try: - data = loads(line) + data = loads(line_str) except Exception as e: - raise InternalServerError(f"Failed to convert response to json: {e} with text: {line}") + raise InternalServerError(f"Failed to convert response to json: {e} with text: {line_str}") output = data["outputs"] diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index 7bbd31e87c595d..40ea4dc0118026 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -53,14 +53,16 @@ def _invoke( api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - + endpoint_url: Optional[str] if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": endpoint_url = "https://cloud.perfxlab.cn/v1/" else: endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" + assert isinstance(endpoint_url, str) endpoint_url = urljoin(endpoint_url, "embeddings") extra_model_kwargs = {} @@ -142,13 +144,16 @@ def validate_credentials(self, model: str, credentials: dict) -> None: if api_key: headers["Authorization"] = f"Bearer {api_key}" + endpoint_url: Optional[str] if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": endpoint_url = "https://cloud.perfxlab.cn/v1/" else: endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" + assert isinstance(endpoint_url, str) endpoint_url = urljoin(endpoint_url, "embeddings") payload = {"input": "ping", "model": model} diff --git a/api/core/model_runtime/model_providers/replicate/_common.py b/api/core/model_runtime/model_providers/replicate/_common.py index 915f6e0eefcd08..3e2cf2adb306db 100644 --- a/api/core/model_runtime/model_providers/replicate/_common.py +++ b/api/core/model_runtime/model_providers/replicate/_common.py @@ -1,4 +1,4 @@ -from replicate.exceptions import ModelError, ReplicateError +from replicate.exceptions import ModelError, ReplicateError # type: ignore from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index 3641b35dc02a39..1e7858100b0429 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -1,9 +1,9 @@ from collections.abc import Generator from typing import Optional, Union -from replicate import Client as ReplicateClient -from replicate.exceptions import ReplicateError -from replicate.prediction import Prediction +from replicate import Client as ReplicateClient # type: ignore +from replicate.exceptions import ReplicateError # type: ignore +from replicate.prediction import Prediction # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index c4e9d0b9c6ceb2..aaf825388a9043 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -2,11 +2,11 @@ import time from typing import Optional -from replicate import Client as ReplicateClient +from replicate import Client as ReplicateClient # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel @@ -86,7 +86,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> Option label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, - model_properties={"context_size": 4096, "max_chunks": 1}, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096, ModelPropertyKey.MAX_CHUNKS: 1}, ) return entity @@ -119,7 +119,7 @@ def _generate_embeddings_by_text_input_key( embeddings.append(result[0].get("embedding")) return [list(map(float, e)) for e in embeddings] - elif "texts" == text_input_key: + elif text_input_key == "texts": result = client.run( replicate_model_version, input={ diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py index 5ff00f008eb621..b8c979b1f53ce9 100644 --- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -4,7 +4,7 @@ from collections.abc import Generator, Iterator from typing import Any, Optional, Union, cast -import boto3 +import boto3 # type: ignore from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -83,7 +83,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): sagemaker_session: Any = None predictor: Any = None - sagemaker_endpoint: str = None + sagemaker_endpoint: str | None = None def _handle_chat_generate_response( self, @@ -209,8 +209,8 @@ def _invoke( :param user: unique user id :return: full response or stream response chunk generator result """ - from sagemaker import Predictor, serializers - from sagemaker.session import Session + from sagemaker import Predictor, serializers # type: ignore + from sagemaker.session import Session # type: ignore if not self.sagemaker_session: access_key = credentials.get("aws_access_key_id") diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py index df797bae265825..7daab6d8653d33 100644 --- a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py @@ -3,7 +3,7 @@ import operator from typing import Any, Optional -import boto3 +import boto3 # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType @@ -114,6 +114,7 @@ def _invoke( except Exception as e: logger.exception(f"Failed to invoke rerank model, model: {model}") + raise InvokeError(f"Failed to invoke rerank model, model: {model}, error: {str(e)}") def validate_credentials(self, model: str, credentials: dict) -> None: """ diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py index 2d50e9c7b4c28a..a6aca130456063 100644 --- a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py @@ -2,7 +2,7 @@ import logging from typing import IO, Any, Optional -import boto3 +import boto3 # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType @@ -67,6 +67,7 @@ def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional s3_prefix = "dify/speech2text/" sagemaker_endpoint = credentials.get("sagemaker_endpoint") bucket = credentials.get("audio_s3_cache_bucket") + assert bucket is not None, "audio_s3_cache_bucket is required in credentials" s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix) payload = {"audio_s3_presign_uri": s3_presign_url} diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py index ef4ddcd6a72847..e7eccd997d11c1 100644 --- a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ import time from typing import Any, Optional -import boto3 +import boto3 # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject @@ -118,6 +118,7 @@ def _invoke( except Exception as e: logger.exception(f"Failed to invoke text embedding model, model: {model}, line: {line}") + raise InvokeError(str(e)) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ diff --git a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py index 6a5946453be07f..62231c518deef1 100644 --- a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py +++ b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Any, Optional -import boto3 +import boto3 # type: ignore import requests from core.model_runtime.entities.common_entities import I18nObject diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml index b52df3e4e3fdee..8703a97edd1133 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml @@ -1,4 +1,5 @@ - Tencent/Hunyuan-A52B-Instruct +- Qwen/QwQ-32B-Preview - Qwen/Qwen2.5-72B-Instruct - Qwen/Qwen2.5-32B-Instruct - Qwen/Qwen2.5-14B-Instruct @@ -19,6 +20,7 @@ - 01-ai/Yi-1.5-6B-Chat - internlm/internlm2_5-20b-chat - internlm/internlm2_5-7b-chat +- meta-llama/Llama-3.3-70B-Instruct - meta-llama/Meta-Llama-3.1-405B-Instruct - meta-llama/Meta-Llama-3.1-70B-Instruct - meta-llama/Meta-Llama-3.1-8B-Instruct diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py index e3a323a4965bc7..f61e8b82e4db99 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py +++ b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py @@ -43,7 +43,7 @@ def _add_custom_parameters(cls, credentials: dict) -> None: credentials["mode"] = "chat" credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: return AIModelEntity( model=model, label=I18nObject(en_US=model, zh_Hans=model), diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-llama-3.3-70b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-llama-3.3-70b-instruct.yaml new file mode 100644 index 00000000000000..9373a8f4ca9f4a --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-llama-3.3-70b-instruct.yaml @@ -0,0 +1,53 @@ +model: meta-llama/Llama-3.3-70B-Instruct +label: + en_US: meta-llama/Llama-3.3-70B-Instruct +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: frequency_penalty + use_template: frequency_penalty + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '4.13' + output: '4.13' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen-qwq-32B-preview.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen-qwq-32B-preview.yaml new file mode 100644 index 00000000000000..c949de4d75604c --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen-qwq-32B-preview.yaml @@ -0,0 +1,53 @@ +model: Qwen/QwQ-32B-Preview +label: + en_US: Qwen/QwQ-32B-Preview +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: frequency_penalty + use_template: frequency_penalty + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '1.26' + output: '1.26' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py index e121ab8c7e4e2f..03c4306144a651 100644 --- a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py +++ b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py @@ -18,7 +18,7 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials=credentials) + model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2.5", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py index 1181ba699af886..cb6f28b6c27fa9 100644 --- a/api/core/model_runtime/model_providers/spark/llm/llm.py +++ b/api/core/model_runtime/model_providers/spark/llm/llm.py @@ -1,6 +1,6 @@ import threading from collections.abc import Generator -from typing import Optional, Union +from typing import Optional, Union, cast from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -270,7 +270,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" elif isinstance(message, SystemPromptMessage): - message_text = content + message_text = cast(str, content) else: raise ValueError(f"Got unknown type {message}") diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py index b96d43979ef54a..03eac194235e83 100644 --- a/api/core/model_runtime/model_providers/togetherai/llm/llm.py +++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py @@ -12,6 +12,7 @@ AIModelEntity, DefaultParameterName, FetchFrom, + ModelFeature, ModelPropertyKey, ModelType, ParameterRule, @@ -67,7 +68,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode cred_with_endpoint = self._update_endpoint_url(credentials=credentials) REPETITION_PENALTY = "repetition_penalty" TOP_K = "top_k" - features = [] + features: list[ModelFeature] = [] entity = AIModelEntity( model=model, diff --git a/api/core/model_runtime/model_providers/tongyi/_common.py b/api/core/model_runtime/model_providers/tongyi/_common.py index 8a50c7aa05f38c..bb68319555007f 100644 --- a/api/core/model_runtime/model_providers/tongyi/_common.py +++ b/api/core/model_runtime/model_providers/tongyi/_common.py @@ -1,4 +1,4 @@ -from dashscope.common.error import ( +from dashscope.common.error import ( # type: ignore AuthenticationError, InvalidParameter, RequestFailure, diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index cde5d214d04d97..61ebd45ed64a6d 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -7,9 +7,9 @@ from pathlib import Path from typing import Optional, Union, cast -from dashscope import Generation, MultiModalConversation, get_tokenizer -from dashscope.api_entities.dashscope_response import GenerationResponse -from dashscope.common.error import ( +from dashscope import Generation, MultiModalConversation, get_tokenizer # type: ignore +from dashscope.api_entities.dashscope_response import GenerationResponse # type: ignore +from dashscope.common.error import ( # type: ignore AuthenticationError, InvalidParameter, RequestFailure, @@ -434,9 +434,9 @@ def _convert_prompt_messages_to_tongyi_messages( sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.VIDEO: message_content = cast(VideoPromptMessageContent, message_content) - video_url = message_content.data - if message_content.data.startswith("data:"): - raise InvokeError("not support base64, please set MULTIMODAL_SEND_VIDEO_FORMAT to url") + video_url = message_content.url + if not video_url: + raise InvokeError("not support base64, please set MULTIMODAL_SEND_FORMAT to url") sub_message_dict = {"video": video_url} sub_messages.append(sub_message_dict) diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max-0809.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max-0809.yaml index 94b6666d0569fe..5970fec5e6d4fa 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max-0809.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max-0809.yaml @@ -59,8 +59,6 @@ parameter_rules: help: zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. - - name: response_format - use_template: response_format - name: repetition_penalty required: false type: float diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max.yaml index b6172c1cbc3d06..c5c8dc5f7c4963 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max.yaml @@ -59,8 +59,6 @@ parameter_rules: help: zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. - - name: response_format - use_template: response_format - name: repetition_penalty required: false type: float diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0201.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0201.yaml index 03cb039d15a7dd..9d9d6c6d11141f 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0201.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0201.yaml @@ -58,8 +58,6 @@ parameter_rules: help: zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. - - name: response_format - use_template: response_format - name: repetition_penalty required: false type: float diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0809.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0809.yaml index 0be4b68f4f93ad..2fab6db648e722 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0809.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0809.yaml @@ -59,8 +59,6 @@ parameter_rules: help: zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. - - name: response_format - use_template: response_format - name: repetition_penalty required: false type: float diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus.yaml index 6c8a8121c6d243..61820ca8538d29 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus.yaml @@ -59,8 +59,6 @@ parameter_rules: help: zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. - - name: response_format - use_template: response_format - name: repetition_penalty required: false type: float diff --git a/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py b/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py index a5ce9ead6ee3be..ed682cb0f3c1e4 100644 --- a/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py @@ -1,7 +1,7 @@ from typing import Optional -import dashscope -from dashscope.common.error import ( +import dashscope # type: ignore +from dashscope.common.error import ( # type: ignore AuthenticationError, InvalidParameter, RequestFailure, @@ -51,7 +51,7 @@ def _invoke( :return: rerank result """ if len(docs) == 0: - return RerankResult(model=model, docs=docs) + return RerankResult(model=model, docs=[]) # initialize client dashscope.api_key = credentials["dashscope_api_key"] @@ -64,7 +64,7 @@ def _invoke( return_documents=True, ) - rerank_documents = [] + rerank_documents: list[RerankDocument] = [] if not response.output: return RerankResult(model=model, docs=rerank_documents) for _, result in enumerate(response.output.results): diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index 2ef7f3f5774481..8c53be413002a9 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import time from typing import Optional -import dashscope +import dashscope # type: ignore import numpy as np from core.entities.embedding_type import EmbeddingInputType diff --git a/api/core/model_runtime/model_providers/tongyi/tts/tts.py b/api/core/model_runtime/model_providers/tongyi/tts/tts.py index ca3b9fbc1c3c00..a654e2d760d7c4 100644 --- a/api/core/model_runtime/model_providers/tongyi/tts/tts.py +++ b/api/core/model_runtime/model_providers/tongyi/tts/tts.py @@ -2,10 +2,10 @@ from queue import Queue from typing import Any, Optional -import dashscope -from dashscope import SpeechSynthesizer -from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse -from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult +import dashscope # type: ignore +from dashscope import SpeechSynthesizer # type: ignore +from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse # type: ignore +from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult # type: ignore from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/upstage/_common.py b/api/core/model_runtime/model_providers/upstage/_common.py index 47ebaccd84ab8a..f6609bba77129b 100644 --- a/api/core/model_runtime/model_providers/upstage/_common.py +++ b/api/core/model_runtime/model_providers/upstage/_common.py @@ -1,5 +1,3 @@ -from collections.abc import Mapping - import openai from httpx import Timeout @@ -14,7 +12,7 @@ class _CommonUpstage: - def _to_credential_kwargs(self, credentials: Mapping) -> dict: + def _to_credential_kwargs(self, credentials: dict) -> dict: """ Transform credentials to kwargs for model instance diff --git a/api/core/model_runtime/model_providers/upstage/llm/llm.py b/api/core/model_runtime/model_providers/upstage/llm/llm.py index a18ee906248a49..2bf6796ca5cf45 100644 --- a/api/core/model_runtime/model_providers/upstage/llm/llm.py +++ b/api/core/model_runtime/model_providers/upstage/llm/llm.py @@ -6,7 +6,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_message import FunctionCall -from tokenizers import Tokenizer +from tokenizers import Tokenizer # type: ignore from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta diff --git a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py index 7dd495b55ef4e6..87693eca768dfd 100644 --- a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py @@ -1,11 +1,10 @@ import base64 import time -from collections.abc import Mapping from typing import Union import numpy as np from openai import OpenAI -from tokenizers import Tokenizer +from tokenizers import Tokenizer # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType @@ -100,7 +99,10 @@ def _invoke( average = embeddings_batch[0] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embedding = (average / np.linalg.norm(average)).tolist() + if np.isnan(embedding).any(): + raise ValueError("Normalized embedding is nan please try again") + embeddings[i] = embedding usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) @@ -129,7 +131,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int return total_num_tokens - def validate_credentials(self, model: str, credentials: Mapping) -> None: + def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials diff --git a/api/core/model_runtime/model_providers/vertex_ai/_common.py b/api/core/model_runtime/model_providers/vertex_ai/_common.py index 8f7c859e3803c0..4e3df7574e9ce8 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/_common.py +++ b/api/core/model_runtime/model_providers/vertex_ai/_common.py @@ -12,4 +12,4 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ - pass + raise NotImplementedError diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-2.0-flash-exp.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-2.0-flash-exp.yaml new file mode 100644 index 00000000000000..bcd59623a78e43 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-2.0-flash-exp.yaml @@ -0,0 +1,39 @@ +model: gemini-2.0-flash-exp +label: + en_US: Gemini 2.0 Flash Exp +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call + - document +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: max_output_tokens + use_template: max_tokens + default: 8192 + min: 1 + max: 8192 + - name: json_schema + use_template: json_schema +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index 1469de605525ef..85be34f3f0fe7f 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -4,11 +4,10 @@ import logging import time from collections.abc import Generator -from typing import Optional, Union, cast +from typing import TYPE_CHECKING, Optional, Union, cast -import google.auth.transport.requests +import google.auth.transport.requests # type: ignore import requests -import vertexai.generative_models as glm from anthropic import AnthropicVertex, Stream from anthropic.types import ( ContentBlockDeltaEvent, @@ -19,8 +18,6 @@ MessageStreamEvent, ) from google.api_core import exceptions -from google.cloud import aiplatform -from google.oauth2 import service_account from PIL import Image from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -47,6 +44,9 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +if TYPE_CHECKING: + import vertexai.generative_models as glm + logger = logging.getLogger(__name__) @@ -102,15 +102,18 @@ def _generate_anthropic( :param stream: is stream response :return: full response or stream response chunk generator result """ + from google.oauth2 import service_account + # use Anthropic official SDK references # - https://github.com/anthropics/anthropic-sdk-python - service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) + service_account_key = credentials.get("vertex_service_account_key", "") project_id = credentials["vertex_project_id"] SCOPES = ["https://www.googleapis.com/auth/cloud-platform"] token = "" # get access token from service account credential - if service_account_info: + if service_account_key: + service_account_info = json.loads(base64.b64decode(service_account_key)) credentials = service_account.Credentials.from_service_account_info(service_account_info, scopes=SCOPES) request = google.auth.transport.requests.Request() credentials.refresh(request) @@ -405,13 +408,15 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: return text.rstrip() - def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool: + def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> "glm.Tool": """ Convert tool messages to glm tools :param tools: tool messages :return: glm tools """ + import vertexai.generative_models as glm + return glm.Tool( function_declarations=[ glm.FunctionDeclaration( @@ -472,16 +477,21 @@ def _generate( :param user: unique user id :return: full response or stream response chunk generator result """ + import vertexai.generative_models as glm + from google.cloud import aiplatform + from google.oauth2 import service_account + config_kwargs = model_parameters.copy() config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None) if stop: config_kwargs["stop_sequences"] = stop - service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) + service_account_key = credentials.get("vertex_service_account_key", "") project_id = credentials["vertex_project_id"] location = credentials["vertex_location"] - if service_account_info: + if service_account_key: + service_account_info = json.loads(base64.b64decode(service_account_key)) service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) aiplatform.init(credentials=service_accountSA, project=project_id, location=location) else: @@ -520,7 +530,7 @@ def _generate( return self._handle_generate_response(model, credentials, response, prompt_messages) def _handle_generate_response( - self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage] + self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage] ) -> LLMResult: """ Handle llm response @@ -552,7 +562,7 @@ def _handle_generate_response( return result def _handle_generate_stream_response( - self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage] + self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage] ) -> Generator: """ Handle llm stream response @@ -636,13 +646,15 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: return message_text - def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content: + def _format_message_to_glm_content(self, message: PromptMessage) -> "glm.Content": """ Format a single message into glm.Content for Google API :param message: one PromptMessage :return: glm Content representation of message """ + import vertexai.generative_models as glm + if isinstance(message, UserPromptMessage): glm_content = glm.Content(role="user", parts=[]) diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py index 9cd0c78d99df24..b8b0e5f15acb44 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py @@ -2,12 +2,9 @@ import json import time from decimal import Decimal -from typing import Optional +from typing import TYPE_CHECKING, Optional import tiktoken -from google.cloud import aiplatform -from google.oauth2 import service_account -from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject @@ -24,6 +21,11 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi +if TYPE_CHECKING: + from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel +else: + VertexTextEmbeddingModel = None + class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): """ @@ -48,10 +50,15 @@ def _invoke( :param input_type: input type :return: embeddings result """ - service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) + from google.cloud import aiplatform + from google.oauth2 import service_account + from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel + + service_account_key = credentials.get("vertex_service_account_key", "") project_id = credentials["vertex_project_id"] location = credentials["vertex_location"] - if service_account_info: + if service_account_key: + service_account_info = json.loads(base64.b64decode(service_account_key)) service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) aiplatform.init(credentials=service_accountSA, project=project_id, location=location) else: @@ -99,11 +106,16 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :param credentials: model credentials :return: """ + from google.cloud import aiplatform + from google.oauth2 import service_account + from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel + try: - service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) + service_account_key = credentials.get("vertex_service_account_key", "") project_id = credentials["vertex_project_id"] location = credentials["vertex_location"] - if service_account_info: + if service_account_key: + service_account_info = json.loads(base64.b64decode(service_account_key)) service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) aiplatform.init(credentials=service_accountSA, project=project_id, location=location) else: diff --git a/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py index 034c066ab5f071..782e4fd6232a3b 100644 --- a/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py @@ -17,14 +17,12 @@ class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: - features = [] - entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - features=features, + features=[], model_properties={ ModelPropertyKey.MODE: credentials.get("mode"), }, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py index cfe21e4b9f4617..a8a015167e3227 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py @@ -1,9 +1,8 @@ -import re from collections.abc import Generator from typing import Optional, cast -from volcenginesdkarkruntime import Ark -from volcenginesdkarkruntime.types.chat import ( +from volcenginesdkarkruntime import Ark # type: ignore +from volcenginesdkarkruntime.types.chat import ( # type: ignore ChatCompletion, ChatCompletionAssistantMessageParam, ChatCompletionChunk, @@ -16,10 +15,10 @@ ChatCompletionToolParam, ChatCompletionUserMessageParam, ) -from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL -from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function -from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse -from volcenginesdkarkruntime.types.shared_params import FunctionDefinition +from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL # type: ignore +from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function # type: ignore +from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse # type: ignore +from volcenginesdkarkruntime.types.shared_params import FunctionDefinition # type: ignore from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -104,17 +103,16 @@ def convert_prompt_message(message: PromptMessage) -> ChatCompletionMessageParam if message_content.type == PromptMessageContentType.TEXT: content.append( ChatCompletionContentPartTextParam( - text=message_content.text, + text=message_content.data, type="text", ) ) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) - image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data) content.append( ChatCompletionContentPartImageParam( image_url=ImageURL( - url=image_data, + url=message_content.data, detail=message_content.detail.value, ), type="image_url", diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py index 266f1216f82b29..0c61e19f066254 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py @@ -68,7 +68,12 @@ def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict: content = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - raise ValueError("Content object type only support image_url") + content.append( + { + "type": "text", + "text": message_content.data, + } + ) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py index 91dbe21a616195..aa837b8318873d 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py @@ -152,5 +152,6 @@ class ServiceNotOpenError(MaasError): def wrap_error(e: MaasError) -> Exception: if ErrorCodeMap.get(e.code): - return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) + # FIXME: mypy type error, try to fix it instead of using type: ignore + return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) # type: ignore return e diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py index 1c776cec7e3096..f0b2b101b7be9d 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py @@ -2,7 +2,7 @@ from collections.abc import Generator from typing import Optional -from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk +from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta @@ -132,6 +132,14 @@ def _get_num_tokens_v3(self, messages: list[PromptMessage]) -> int: messages_dict = [ArkClientV3.convert_prompt_message(m) for m in messages] for message in messages_dict: for key, value in message.items(): + # Ignore tokens for image type + if isinstance(value, list): + text = "" + for item in value: + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] + + value = text num_tokens += self._get_num_tokens_by_gpt2(str(key)) num_tokens += self._get_num_tokens_by_gpt2(str(value)) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py index d8be14b0247698..7c37368086e0e6 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMMode @@ -16,6 +18,14 @@ class ModelConfig(BaseModel): configs: dict[str, ModelConfig] = { + "Doubao-vision-pro-32k": ModelConfig( + properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.VISION], + ), + "Doubao-vision-lite-32k": ModelConfig( + properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.VISION], + ), "Doubao-pro-4k": ModelConfig( properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), features=[ModelFeature.TOOL_CALL], @@ -32,6 +42,10 @@ class ModelConfig(BaseModel): properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), features=[ModelFeature.TOOL_CALL], ), + "Doubao-pro-256k": ModelConfig( + properties=ModelProperties(context_size=262144, max_tokens=4096, mode=LLMMode.CHAT), + features=[], + ), "Doubao-pro-128k": ModelConfig( properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), features=[ModelFeature.TOOL_CALL], @@ -90,7 +104,7 @@ def get_model_config(credentials: dict) -> ModelConfig: def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): - req_params = {} + req_params: dict[str, Any] = {} # predefined properties model_configs = get_model_config(credentials) if model_configs: @@ -118,7 +132,7 @@ def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str] def get_v3_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): - req_params = {} + req_params: dict[str, Any] = {} # predefined properties model_configs = get_model_config(credentials) if model_configs: diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py index 4a6f5b6f7bc7cd..be9bba5f2450d2 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py @@ -12,6 +12,7 @@ class ModelConfig(BaseModel): ModelConfigs = { "Doubao-embedding": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)), + "Doubao-embedding-large": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)), } @@ -21,7 +22,7 @@ def get_model_config(credentials: dict) -> ModelConfig: if not model_configs: return ModelConfig( properties=ModelProperties( - context_size=int(credentials.get("context_size", 0)), + context_size=int(credentials.get("context_size", 4096)), max_chunks=int(credentials.get("max_chunks", 1)), ) ) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml index 13e00da76fb149..2ddb612546690c 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml +++ b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml @@ -118,6 +118,18 @@ model_credential_schema: type: select required: true options: + - label: + en_US: Doubao-vision-pro-32k + value: Doubao-vision-pro-32k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-vision-lite-32k + value: Doubao-vision-lite-32k + show_on: + - variable: __model_type + value: llm - label: en_US: Doubao-pro-4k value: Doubao-pro-4k @@ -154,6 +166,12 @@ model_credential_schema: show_on: - variable: __model_type value: llm + - label: + en_US: Doubao-pro-256k + value: Doubao-pro-256k + show_on: + - variable: __model_type + value: llm - label: en_US: Llama3-8B value: Llama3-8B @@ -208,6 +226,12 @@ model_credential_schema: show_on: - variable: __model_type value: text-embedding + - label: + en_US: Doubao-embedding-large + value: Doubao-embedding-large + show_on: + - variable: __model_type + value: text-embedding - label: en_US: Custom zh_Hans: 自定义 diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index 07b970f8104c8f..d2899795696aa4 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -1,7 +1,7 @@ from collections.abc import Generator from enum import Enum from json import dumps, loads -from typing import Any, Union +from typing import Any, Optional, Union from requests import Response, post @@ -22,7 +22,7 @@ class Role(Enum): role: str = Role.USER.value content: str - usage: dict[str, int] = None + usage: Optional[dict[str, int]] = None stop_reason: str = "" def to_dict(self) -> dict[str, Any]: @@ -135,6 +135,7 @@ def _build_function_calling_request_body( """ TODO: implement function calling """ + raise NotImplementedError("Function calling is not supported yet.") def _build_chat_request_body( self, diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py index 19135deb27380d..816b3b98c4b8c5 100644 --- a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py @@ -1,6 +1,5 @@ import time from abc import abstractmethod -from collections.abc import Mapping from json import dumps from typing import Any, Optional @@ -23,12 +22,12 @@ class TextEmbedding: @abstractmethod - def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]: raise NotImplementedError class WenxinTextEmbedding(_CommonWenxin, TextEmbedding): - def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]: access_token = self._get_access_token() url = f"{self.api_bases[model]}?access_token={access_token}" body = self._build_embed_request_body(model, texts, user) @@ -50,7 +49,7 @@ def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> } return body - def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int): + def _handle_embed_response(self, model: str, response: Response) -> tuple[list[list[float]], int, int]: data = response.json() if "error_code" in data: code = data["error_code"] @@ -147,7 +146,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int return total_num_tokens - def validate_credentials(self, model: str, credentials: Mapping) -> None: + def validate_credentials(self, model: str, credentials: dict) -> None: api_key = credentials["api_key"] secret_key = credentials["secret_key"] try: diff --git a/api/core/model_runtime/model_providers/x/llm/grok-2-1212.yaml b/api/core/model_runtime/model_providers/x/llm/grok-2-1212.yaml new file mode 100644 index 00000000000000..24ea716a985d32 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/llm/grok-2-1212.yaml @@ -0,0 +1,66 @@ +model: grok-2-1212 +label: + en_US: grok-2-1212 +model_type: llm +features: + - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 2.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: 0 + max: 2.0 + precision: 1 + required: false + help: + en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim." + zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/x/llm/grok-2-vision-1212.yaml b/api/core/model_runtime/model_providers/x/llm/grok-2-vision-1212.yaml new file mode 100644 index 00000000000000..f224fa57573c42 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/llm/grok-2-vision-1212.yaml @@ -0,0 +1,64 @@ +model: grok-2-vision-1212 +label: + en_US: grok-2-vision-1212 +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 2.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: 0 + max: 2.0 + precision: 1 + required: false + help: + en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim." + zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml index bb71de2badb335..7f722539d9cfed 100644 --- a/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml +++ b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml @@ -1,6 +1,6 @@ model: grok-beta label: - en_US: Grok Beta + en_US: grok-beta model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/x/llm/grok-vision-beta.yaml b/api/core/model_runtime/model_providers/x/llm/grok-vision-beta.yaml index 844f0520bc64fb..1d8128253400d7 100644 --- a/api/core/model_runtime/model_providers/x/llm/grok-vision-beta.yaml +++ b/api/core/model_runtime/model_providers/x/llm/grok-vision-beta.yaml @@ -1,6 +1,6 @@ model: grok-vision-beta label: - en_US: Grok Vision Beta + en_US: grok-vision-beta model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 8d86d6937d8ac9..7db1203641cad2 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -17,7 +17,7 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_message import FunctionCall from openai.types.completion import Completion -from xinference_client.client.restful.restful_client import ( +from xinference_client.client.restful.restful_client import ( # type: ignore Client, RESTfulChatModelHandle, RESTfulGenerateModelHandle, diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index efaf114854b5c1..078ec0537a37f4 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -1,6 +1,6 @@ from typing import Optional -from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle +from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index 3d7aefeb6dd89a..5f330ece1a5750 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -1,6 +1,6 @@ from typing import IO, Optional -from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle +from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index e51e6a941c5413..9054aabab2dd05 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import time from typing import Optional -from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle +from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject @@ -134,7 +134,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: handle = client.get_model(model_uid=model_uid) except RuntimeError as e: - raise InvokeAuthorizationError(e) + raise InvokeAuthorizationError(str(e)) if not isinstance(handle, RESTfulEmbeddingModelHandle): raise InvokeBadRequestError( diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index ad7b64efb5d2e7..8aa39d4de0d2cb 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -1,7 +1,7 @@ import concurrent.futures from typing import Any, Optional -from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle +from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType @@ -74,11 +74,14 @@ def validate_credentials(self, model: str, credentials: dict) -> None: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") credentials["server_url"] = credentials["server_url"].removesuffix("/") + api_key = credentials.get("api_key") + if api_key is None: + raise CredentialsValidateFailedError("api_key is required") extra_param = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials["server_url"], model_uid=credentials["model_uid"], - api_key=credentials.get("api_key"), + api_key=api_key, ) if "text-to-audio" not in extra_param.model_ability: diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index baa3ccbe8adbc0..b51423f4eda2e6 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -1,6 +1,6 @@ from threading import Lock from time import time -from typing import Optional +from typing import Any, Optional from requests.adapters import HTTPAdapter from requests.exceptions import ConnectionError, MissingSchema, Timeout @@ -39,13 +39,15 @@ def __init__( self.model_family = model_family -cache = {} +cache: dict[str, dict[str, Any]] = {} cache_lock = Lock() class XinferenceHelper: @staticmethod - def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: + def get_xinference_extra_parameter( + server_url: str, model_uid: str, api_key: str | None + ) -> XinferenceModelExtraParameter: XinferenceHelper._clean_cache() with cache_lock: if model_uid not in cache: @@ -66,7 +68,9 @@ def _clean_cache() -> None: pass @staticmethod - def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: + def _get_xinference_extra_parameter( + server_url: str, model_uid: str, api_key: str | None + ) -> XinferenceModelExtraParameter: """ get xinference model extra parameter like model_format and model_handle_type """ diff --git a/api/core/model_runtime/model_providers/yi/llm/llm.py b/api/core/model_runtime/model_providers/yi/llm/llm.py index 0642e72ed500e1..f5b61e207635bc 100644 --- a/api/core/model_runtime/model_providers/yi/llm/llm.py +++ b/api/core/model_runtime/model_providers/yi/llm/llm.py @@ -136,7 +136,7 @@ def _add_custom_parameters(credentials: dict) -> None: parsed_url = urlparse(credentials["endpoint_url"]) credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: return AIModelEntity( model=model, label=I18nObject(en_US=model, zh_Hans=model), diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_flash.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_flash.yaml new file mode 100644 index 00000000000000..c2047b2cd32b83 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_flash.yaml @@ -0,0 +1,52 @@ +model: glm-4v-flash +label: + en_US: glm-4v-flash +model_type: llm +model_properties: + mode: chat + context_size: 2048 +features: + - vision +parameter_rules: + - name: temperature + use_template: temperature + default: 0.95 + min: 0.0 + max: 1.0 + help: + zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: top_p + use_template: top_p + default: 0.6 + help: + zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: do_sample + label: + zh_Hans: 采样策略 + en_US: Sampling strategy + type: boolean + help: + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 1024 + - name: web_search + type: boolean + label: + zh_Hans: 联网搜索 + en_US: Web Search + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index e0601d681cbf74..eef86cc52c36e8 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -1,9 +1,9 @@ from collections.abc import Generator from typing import Optional, Union -from zhipuai import ZhipuAI -from zhipuai.types.chat.chat_completion import Completion -from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk +from zhipuai import ZhipuAI # type: ignore +from zhipuai.types.chat.chat_completion import Completion # type: ignore +from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk # type: ignore from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -144,7 +144,7 @@ def _generate( if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}: if isinstance(copy_prompt_message.content, list): # check if model is 'glm-4v' - if model not in {"glm-4v", "glm-4v-plus"}: + if not model.startswith("glm-4v"): # not support list message continue # get image and @@ -188,7 +188,7 @@ def _generate( else: model_parameters["tools"] = [web_search_params] - if model in {"glm-4v", "glm-4v-plus"}: + if model.startswith("glm-4v"): params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters) else: params = {"model": model, "messages": [], **model_parameters} @@ -412,6 +412,8 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: human_prompt = "\n\nHuman:" ai_prompt = "\n\nAssistant:" content = message.content + if isinstance(content, list): + content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT) if isinstance(message, UserPromptMessage): message_text = f"{human_prompt} {content}" diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 2428284ba9a8ff..a700304db7b6f3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import time from typing import Optional -from zhipuai import ZhipuAI +from zhipuai import ZhipuAI # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py index 029ec1a581b2e9..8cc8adfc3656ea 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/core/model_runtime/schema_validators/common_validator.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Union, cast from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType @@ -38,7 +38,7 @@ def _validate_and_filter_credential_form_schemas( def _validate_credential_form_schema( self, credential_form_schema: CredentialFormSchema, credentials: dict - ) -> Optional[str]: + ) -> Union[str, bool, None]: """ Validate credential form schema @@ -47,6 +47,7 @@ def _validate_credential_form_schema( :return: validated credential form schema value """ # If the variable does not exist in credentials + value: Union[str, bool, None] = None if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: # If required is True, an exception is thrown if credential_form_schema.required: @@ -61,7 +62,7 @@ def _validate_credential_form_schema( return None # Get the value corresponding to the variable from credentials - value = credentials[credential_form_schema.variable] + value = cast(str, credentials[credential_form_schema.variable]) # If max_length=0, no validation is performed if credential_form_schema.max_length: diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index ec1bad5698f2eb..03e350627140cf 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -129,7 +129,8 @@ def jsonable_encoder( sqlalchemy_safe=sqlalchemy_safe, ) if dataclasses.is_dataclass(obj): - obj_dict = dataclasses.asdict(obj) + # FIXME: mypy error, try to fix it instead of using type: ignore + obj_dict = dataclasses.asdict(obj) # type: ignore return jsonable_encoder( obj_dict, by_alias=by_alias, diff --git a/api/core/model_runtime/utils/helper.py b/api/core/model_runtime/utils/helper.py index 2067092d80f582..5e8a723ec7c510 100644 --- a/api/core/model_runtime/utils/helper.py +++ b/api/core/model_runtime/utils/helper.py @@ -4,6 +4,7 @@ def dump_model(model: BaseModel) -> dict: if hasattr(pydantic, "model_dump"): - return pydantic.model_dump(model) + # FIXME mypy error, try to fix it instead of using type: ignore + return pydantic.model_dump(model) # type: ignore else: return model.model_dump() diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 094ad7863603dc..c65a3885fd1eb9 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,3 +1,5 @@ +from typing import Optional + from pydantic import BaseModel from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor @@ -43,6 +45,8 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["inputs_config"]["enabled"]: params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query) @@ -57,6 +61,8 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["outputs_config"]["enabled"]: params = ModerationOutputParams(app_id=self.app_id, text=text) @@ -69,14 +75,18 @@ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: ) def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict: - extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id")) + if self.config is None: + raise ValueError("The config is not set.") + extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", "")) + if not extension: + raise ValueError("API-based Extension not found. Please check it again.") requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key)) result = requestor.request(extension_point, params) return result @staticmethod - def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: extension = ( db.session.query(APIBasedExtension) .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 60898d5547ae3b..d8c392d0970e19 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -100,14 +100,14 @@ def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_re if not inputs_config.get("preset_response"): raise ValueError("inputs_config.preset_response is required") - if len(inputs_config.get("preset_response")) > 100: + if len(inputs_config.get("preset_response", 0)) > 100: raise ValueError("inputs_config.preset_response must be less than 100 characters") if outputs_config_enabled: if not outputs_config.get("preset_response"): raise ValueError("outputs_config.preset_response is required") - if len(outputs_config.get("preset_response")) > 100: + if len(outputs_config.get("preset_response", 0)) > 100: raise ValueError("outputs_config.preset_response must be less than 100 characters") diff --git a/api/core/moderation/factory.py b/api/core/moderation/factory.py index 96bf2ab54b41eb..0ad4438c143870 100644 --- a/api/core/moderation/factory.py +++ b/api/core/moderation/factory.py @@ -22,7 +22,8 @@ def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: """ code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config) extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) - extension_class.validate_config(tenant_id, config) + # FIXME: mypy error, try to fix it instead of using type: ignore + extension_class.validate_config(tenant_id, config) # type: ignore def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: """ diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 46d3963bd07f5a..3ac33966cb14bf 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -1,5 +1,6 @@ import logging -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional from core.app.app_config.entities import AppConfig from core.moderation.base import ModerationAction, ModerationError @@ -17,11 +18,11 @@ def check( app_id: str, tenant_id: str, app_config: AppConfig, - inputs: dict, + inputs: Mapping[str, Any], query: str, message_id: str, trace_manager: Optional[TraceQueueManager] = None, - ) -> tuple[bool, dict, str]: + ) -> tuple[bool, Mapping[str, Any], str]: """ Process sensitive_word_avoidance. :param app_id: app id @@ -33,6 +34,7 @@ def check( :param trace_manager: trace manager :return: """ + inputs = dict(inputs) if not app_config.sensitive_word_avoidance: return False, inputs, query diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 00b3c56c03602d..9dd2665c3bf3d3 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -21,7 +21,7 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: if not config.get("keywords"): raise ValueError("keywords is required") - if len(config.get("keywords")) > 10000: + if len(config.get("keywords", [])) > 10000: raise ValueError("keywords length must be less than 10000") keywords_row_len = config["keywords"].split("\n") @@ -31,6 +31,8 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["inputs_config"]["enabled"]: preset_response = self.config["inputs_config"]["preset_response"] @@ -50,6 +52,8 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["outputs_config"]["enabled"]: # Filter out empty values diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 6465de23b9a2de..d64f17b383e0b5 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -20,6 +20,8 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["inputs_config"]["enabled"]: preset_response = self.config["inputs_config"]["preset_response"] @@ -35,6 +37,8 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["outputs_config"]["enabled"]: flagged = self._is_violated({"text": text}) diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 4635bd9c251851..e595be126c7824 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -70,7 +70,7 @@ def start_thread(self) -> threading.Thread: thread = threading.Thread( target=self.worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE, }, ) diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 71ff03b6ef5160..f0e34c0cd71241 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from datetime import datetime from enum import StrEnum from typing import Any, Optional, Union @@ -38,8 +39,8 @@ class WorkflowTraceInfo(BaseTraceInfo): workflow_run_id: str workflow_run_elapsed_time: Union[int, float] workflow_run_status: str - workflow_run_inputs: dict[str, Any] - workflow_run_outputs: dict[str, Any] + workflow_run_inputs: Mapping[str, Any] + workflow_run_outputs: Mapping[str, Any] workflow_run_version: str error: Optional[str] = None total_tokens: int diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 0cba40c51a0d19..b9ba068b19936d 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta from typing import Optional -from langfuse import Langfuse +from langfuse import Langfuse # type: ignore from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import LangfuseConfig @@ -65,8 +65,11 @@ def trace(self, trace_info: BaseTraceInfo): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - trace_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id + trace_id = trace_info.workflow_run_id user_id = trace_info.metadata.get("user_id") + metadata = trace_info.metadata + metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id + if trace_info.message_id: trace_id = trace_info.message_id name = TraceTaskName.MESSAGE_TRACE.value @@ -74,24 +77,22 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): id=trace_id, user_id=user_id, name=name, - input=trace_info.workflow_run_inputs, - output=trace_info.workflow_run_outputs, - metadata=trace_info.metadata, + input=dict(trace_info.workflow_run_inputs), + output=dict(trace_info.workflow_run_outputs), + metadata=metadata, session_id=trace_info.conversation_id, tags=["message", "workflow"], - created_at=trace_info.start_time, - updated_at=trace_info.end_time, ) self.add_trace(langfuse_trace_data=trace_data) workflow_span_data = LangfuseSpan( - id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id), + id=trace_info.workflow_run_id, name=TraceTaskName.WORKFLOW_TRACE.value, - input=trace_info.workflow_run_inputs, - output=trace_info.workflow_run_outputs, + input=dict(trace_info.workflow_run_inputs), + output=dict(trace_info.workflow_run_outputs), trace_id=trace_id, start_time=trace_info.start_time, end_time=trace_info.end_time, - metadata=trace_info.metadata, + metadata=metadata, level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR, status_message=trace_info.error or "", ) @@ -101,9 +102,9 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): id=trace_id, user_id=user_id, name=TraceTaskName.WORKFLOW_TRACE.value, - input=trace_info.workflow_run_inputs, - output=trace_info.workflow_run_outputs, - metadata=trace_info.metadata, + input=dict(trace_info.workflow_run_inputs), + output=dict(trace_info.workflow_run_outputs), + metadata=metadata, session_id=trace_info.conversation_id, tags=["workflow"], ) @@ -192,7 +193,7 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): metadata=metadata, level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR), status_message=trace_info.error or "", - parent_observation_id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id), + parent_observation_id=trace_info.workflow_run_id, ) else: span_data = LangfuseSpan( @@ -239,11 +240,13 @@ def message_trace(self, trace_info: MessageTraceInfo, **kwargs): file_list = trace_info.file_list metadata = trace_info.metadata message_data = trace_info.message_data + if message_data is None: + return message_id = message_data.id user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: EndUser = ( + end_user_data: Optional[EndUser] = ( db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: @@ -300,6 +303,8 @@ def message_trace(self, trace_info: MessageTraceInfo, **kwargs): self.add_generation(langfuse_generation_data) def moderation_trace(self, trace_info: ModerationTraceInfo): + if trace_info.message_data is None: + return span_data = LangfuseSpan( name=TraceTaskName.MODERATION_TRACE.value, input=trace_info.inputs, @@ -319,9 +324,11 @@ def moderation_trace(self, trace_info: ModerationTraceInfo): def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): message_data = trace_info.message_data + if message_data is None: + return generation_usage = GenerationUsage( total=len(str(trace_info.suggested_question)), - input=len(trace_info.inputs), + input=len(trace_info.inputs) if trace_info.inputs else 0, output=len(trace_info.suggested_question), unit=UnitEnum.CHARACTERS, ) @@ -342,6 +349,8 @@ def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): self.add_generation(langfuse_generation_data=generation_data) def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): + if trace_info.message_data is None: + return dataset_retrieval_span_data = LangfuseSpan( name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, input=trace_info.inputs, diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index 99221d669b3193..348b7ba5012b6b 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -49,7 +49,6 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run") input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") - dotted_order: Optional[str] = Field(None, description="Dotted order of the run") @field_validator("inputs", "outputs") @classmethod diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index c15b132abd30a2..4ffd888bddf8a3 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -3,6 +3,7 @@ import os import uuid from datetime import datetime, timedelta +from typing import Optional, cast from langsmith import Client from langsmith.schemas import RunBase @@ -62,53 +63,73 @@ def trace(self, trace_info: BaseTraceInfo): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - trace_id = trace_info.message_id or trace_info.workflow_app_log_id or trace_info.workflow_run_id + trace_id = trace_info.message_id or trace_info.workflow_run_id + if trace_info.start_time is None: + trace_info.start_time = datetime.now() message_dotted_order = ( generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None ) workflow_dotted_order = generate_dotted_order( - trace_info.workflow_app_log_id or trace_info.workflow_run_id, + trace_info.workflow_run_id, trace_info.workflow_data.created_at, message_dotted_order, ) + metadata = trace_info.metadata + metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id if trace_info.message_id: message_run = LangSmithRunModel( id=trace_info.message_id, name=TraceTaskName.MESSAGE_TRACE.value, - inputs=trace_info.workflow_run_inputs, - outputs=trace_info.workflow_run_outputs, + inputs=dict(trace_info.workflow_run_inputs), + outputs=dict(trace_info.workflow_run_outputs), run_type=LangSmithRunType.chain, start_time=trace_info.start_time, end_time=trace_info.end_time, extra={ - "metadata": trace_info.metadata, + "metadata": metadata, }, tags=["message", "workflow"], error=trace_info.error, trace_id=trace_id, dotted_order=message_dotted_order, + file_list=[], + serialized=None, + parent_run_id=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, ) self.add_run(message_run) langsmith_run = LangSmithRunModel( file_list=trace_info.file_list, total_tokens=trace_info.total_tokens, - id=trace_info.workflow_app_log_id or trace_info.workflow_run_id, + id=trace_info.workflow_run_id, name=TraceTaskName.WORKFLOW_TRACE.value, - inputs=trace_info.workflow_run_inputs, + inputs=dict(trace_info.workflow_run_inputs), run_type=LangSmithRunType.tool, start_time=trace_info.workflow_data.created_at, end_time=trace_info.workflow_data.finished_at, - outputs=trace_info.workflow_run_outputs, + outputs=dict(trace_info.workflow_run_outputs), extra={ - "metadata": trace_info.metadata, + "metadata": metadata, }, error=trace_info.error, tags=["workflow"], parent_run_id=trace_info.message_id or None, trace_id=trace_id, dotted_order=workflow_dotted_order, + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, ) self.add_run(langsmith_run) @@ -204,30 +225,40 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): extra={ "metadata": metadata, }, - parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id, + parent_run_id=trace_info.workflow_run_id, tags=["node_execution"], id=node_execution_id, trace_id=trace_id, dotted_order=node_dotted_order, + error="", + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, ) self.add_run(langsmith_run) def message_trace(self, trace_info: MessageTraceInfo): # get message file data - file_list = trace_info.file_list - message_file_data: MessageFile = trace_info.message_file_data + file_list = cast(list[str], trace_info.file_list) or [] + message_file_data: Optional[MessageFile] = trace_info.message_file_data file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" file_list.append(file_url) metadata = trace_info.metadata message_data = trace_info.message_data + if message_data is None: + return message_id = message_data.id user_id = message_data.from_account_id metadata["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: EndUser = ( + end_user_data: Optional[EndUser] = ( db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: @@ -245,12 +276,20 @@ def message_trace(self, trace_info: MessageTraceInfo): start_time=trace_info.start_time, end_time=trace_info.end_time, outputs=message_data.answer, - extra={ - "metadata": metadata, - }, + extra={"metadata": metadata}, tags=["message", str(trace_info.conversation_mode)], error=trace_info.error, file_list=file_list, + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + parent_run_id=None, ) self.add_run(message_run) @@ -265,17 +304,27 @@ def message_trace(self, trace_info: MessageTraceInfo): start_time=trace_info.start_time, end_time=trace_info.end_time, outputs=message_data.answer, - extra={ - "metadata": metadata, - }, + extra={"metadata": metadata}, parent_run_id=message_id, tags=["llm", str(trace_info.conversation_mode)], error=trace_info.error, file_list=file_list, + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + id=str(uuid.uuid4()), ) self.add_run(llm_run) def moderation_trace(self, trace_info: ModerationTraceInfo): + if trace_info.message_data is None: + return langsmith_run = LangSmithRunModel( name=TraceTaskName.MODERATION_TRACE.value, inputs=trace_info.inputs, @@ -286,48 +335,82 @@ def moderation_trace(self, trace_info: ModerationTraceInfo): "inputs": trace_info.inputs, }, run_type=LangSmithRunType.tool, - extra={ - "metadata": trace_info.metadata, - }, + extra={"metadata": trace_info.metadata}, tags=["moderation"], parent_run_id=trace_info.message_id, start_time=trace_info.start_time or trace_info.message_data.created_at, end_time=trace_info.end_time or trace_info.message_data.updated_at, + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error="", + file_list=[], ) self.add_run(langsmith_run) def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): message_data = trace_info.message_data + if message_data is None: + return suggested_question_run = LangSmithRunModel( name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value, inputs=trace_info.inputs, outputs=trace_info.suggested_question, run_type=LangSmithRunType.tool, - extra={ - "metadata": trace_info.metadata, - }, + extra={"metadata": trace_info.metadata}, tags=["suggested_question"], parent_run_id=trace_info.message_id, start_time=trace_info.start_time or message_data.created_at, end_time=trace_info.end_time or message_data.updated_at, + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error="", + file_list=[], ) self.add_run(suggested_question_run) def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): + if trace_info.message_data is None: + return dataset_retrieval_run = LangSmithRunModel( name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, inputs=trace_info.inputs, outputs={"documents": trace_info.documents}, run_type=LangSmithRunType.retriever, - extra={ - "metadata": trace_info.metadata, - }, + extra={"metadata": trace_info.metadata}, tags=["dataset_retrieval"], parent_run_id=trace_info.message_id, start_time=trace_info.start_time or trace_info.message_data.created_at, end_time=trace_info.end_time or trace_info.message_data.updated_at, + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error="", + file_list=[], ) self.add_run(dataset_retrieval_run) @@ -345,7 +428,18 @@ def tool_trace(self, trace_info: ToolTraceInfo): parent_run_id=trace_info.message_id, start_time=trace_info.start_time, end_time=trace_info.end_time, - file_list=[trace_info.file_url], + file_list=[cast(str, trace_info.file_url)], + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error=trace_info.error or "", ) self.add_run(tool_run) @@ -356,12 +450,23 @@ def generate_name_trace(self, trace_info: GenerateNameTraceInfo): inputs=trace_info.inputs, outputs=trace_info.outputs, run_type=LangSmithRunType.tool, - extra={ - "metadata": trace_info.metadata, - }, + extra={"metadata": trace_info.metadata}, tags=["generate_name"], start_time=trace_info.start_time or datetime.now(), end_time=trace_info.end_time or datetime.now(), + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error="", + file_list=[], + parent_run_id=None, ) self.add_run(name_run) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index b7799ce1fbdd5e..f538eaef5bd570 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -33,11 +33,11 @@ from core.ops.utils import get_message_data from extensions.ext_database import db from extensions.ext_storage import storage -from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig +from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig from models.workflow import WorkflowAppLog, WorkflowRun from tasks.ops_trace_task import process_trace_tasks -provider_config_map = { +provider_config_map: dict[str, dict[str, Any]] = { TracingProviderEnum.LANGFUSE.value: { "config_class": LangfuseConfig, "secret_keys": ["public_key", "secret_key"], @@ -145,7 +145,7 @@ def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str): :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = ( + trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -155,7 +155,11 @@ def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str): return None # decrypt_token - tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id + app = db.session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError("App not found") + + tenant_id = app.tenant_id decrypt_tracing_config = cls.decrypt_tracing_config( tenant_id, tracing_provider, trace_config_data.tracing_config ) @@ -178,7 +182,7 @@ def get_ops_trace_instance( if app_id is None: return None - app: App = db.session.query(App).filter(App.id == app_id).first() + app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() if app is None: return None @@ -209,8 +213,12 @@ def get_ops_trace_instance( def get_app_config_through_message_id(cls, message_id: str): app_model_config = None message_data = db.session.query(Message).filter(Message.id == message_id).first() + if not message_data: + return None conversation_id = message_data.conversation_id conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() + if not conversation_data: + return None if conversation_data.app_model_config_id: app_model_config = ( @@ -236,7 +244,9 @@ def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: if tracing_provider not in provider_config_map and tracing_provider is not None: raise ValueError(f"Invalid tracing provider: {tracing_provider}") - app_config: App = db.session.query(App).filter(App.id == app_id).first() + app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first() + if not app_config: + raise ValueError("App not found") app_config.tracing = json.dumps( { "enabled": enabled, @@ -252,7 +262,9 @@ def get_app_tracing_config(cls, app_id: str): :param app_id: app id :return: """ - app: App = db.session.query(App).filter(App.id == app_id).first() + app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError("App not found") if not app.tracing: return {"enabled": False, "tracing_provider": None} app_trace_config = json.loads(app.tracing) @@ -355,7 +367,13 @@ def preprocess(self): def conversation_trace(self, **kwargs): return kwargs - def workflow_trace(self, workflow_run: WorkflowRun, conversation_id, user_id): + def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id): + if not workflow_run: + raise ValueError("Workflow run not found") + + db.session.merge(workflow_run) + db.session.refresh(workflow_run) + workflow_id = workflow_run.workflow_id tenant_id = workflow_run.tenant_id workflow_run_id = workflow_run.id @@ -477,6 +495,8 @@ def message_trace(self, message_id): def moderation_trace(self, message_id, timer, **kwargs): moderation_result = kwargs.get("moderation_result") + if not moderation_result: + return {} inputs = kwargs.get("inputs") message_data = get_message_data(message_id) if not message_data: @@ -512,7 +532,7 @@ def moderation_trace(self, message_id, timer, **kwargs): return moderation_trace_info def suggested_question_trace(self, message_id, timer, **kwargs): - suggested_question = kwargs.get("suggested_question") + suggested_question = kwargs.get("suggested_question", []) message_data = get_message_data(message_id) if not message_data: return {} @@ -580,7 +600,7 @@ def dataset_retrieval_trace(self, message_id, timer, **kwargs): dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( message_id=message_id, inputs=message_data.query or message_data.inputs, - documents=[doc.model_dump() for doc in documents], + documents=[doc.model_dump() for doc in documents] if documents else [], start_time=timer.get("start"), end_time=timer.get("end"), metadata=metadata, @@ -590,9 +610,9 @@ def dataset_retrieval_trace(self, message_id, timer, **kwargs): return dataset_retrieval_trace_info def tool_trace(self, message_id, timer, **kwargs): - tool_name = kwargs.get("tool_name") - tool_inputs = kwargs.get("tool_inputs") - tool_outputs = kwargs.get("tool_outputs") + tool_name = kwargs.get("tool_name", "") + tool_inputs = kwargs.get("tool_inputs", {}) + tool_outputs = kwargs.get("tool_outputs", {}) message_data = get_message_data(message_id) if not message_data: return {} @@ -602,7 +622,7 @@ def tool_trace(self, message_id, timer, **kwargs): tool_parameters = {} created_time = message_data.created_at end_time = message_data.updated_at - agent_thoughts: list[MessageAgentThought] = message_data.agent_thoughts + agent_thoughts = message_data.agent_thoughts for agent_thought in agent_thoughts: if tool_name in agent_thought.tools: created_time = agent_thought.created_at @@ -666,6 +686,8 @@ def generate_name_trace(self, conversation_id, timer, **kwargs): generate_conversation_name = kwargs.get("generate_conversation_name") inputs = kwargs.get("inputs") tenant_id = kwargs.get("tenant_id") + if not tenant_id: + return {} start_time = timer.get("start") end_time = timer.get("end") @@ -687,8 +709,8 @@ def generate_name_trace(self, conversation_id, timer, **kwargs): return generate_name_trace_info -trace_manager_timer = None -trace_manager_queue = queue.Queue() +trace_manager_timer: Optional[threading.Timer] = None +trace_manager_queue: queue.Queue = queue.Queue() trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5)) trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100)) @@ -700,7 +722,7 @@ def __init__(self, app_id=None, user_id=None): self.app_id = app_id self.user_id = user_id self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id) - self.flask_app = current_app._get_current_object() + self.flask_app = current_app._get_current_object() # type: ignore if trace_manager_timer is None: self.start_timer() @@ -717,7 +739,7 @@ def add_trace_task(self, trace_task: TraceTask): def collect_tasks(self): global trace_manager_queue - tasks = [] + tasks: list[TraceTask] = [] while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty(): task = trace_manager_queue.get_nowait() tasks.append(task) @@ -743,6 +765,8 @@ def start_timer(self): def send_to_celery(self, tasks: list[TraceTask]): with self.flask_app.app_context(): for task in tasks: + if task.app_id is None: + continue file_id = uuid4().hex trace_info = task.execute() task_data = TaskData( diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 0f3f8249661bf0..87c7a79fb01201 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,5 +1,5 @@ -from collections.abc import Sequence -from typing import Optional +from collections.abc import Mapping, Sequence +from typing import Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import file_manager @@ -39,7 +39,7 @@ def get_prompt( self, *, prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate, - inputs: dict[str, str], + inputs: Mapping[str, str], query: str, files: Sequence[File], context: Optional[str], @@ -77,7 +77,7 @@ def get_prompt( def _get_completion_model_prompt_messages( self, prompt_template: CompletionModelPromptTemplate, - inputs: dict, + inputs: Mapping[str, str], query: Optional[str], files: Sequence[File], context: Optional[str], @@ -90,15 +90,15 @@ def _get_completion_model_prompt_messages( """ raw_prompt = prompt_template.text - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] if prompt_template.edition_type == "basic" or not prompt_template.edition_type: parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} + prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs} prompt_inputs = self._set_context_variable(context, parser, prompt_inputs) - if memory and memory_config: + if memory and memory_config and memory_config.role_prefix: role_prefix = memory_config.role_prefix prompt_inputs = self._set_histories_variable( memory=memory, @@ -135,7 +135,7 @@ def _get_completion_model_prompt_messages( def _get_chat_model_prompt_messages( self, prompt_template: list[ChatModelMessage], - inputs: dict, + inputs: Mapping[str, str], query: Optional[str], files: Sequence[File], context: Optional[str], @@ -146,7 +146,7 @@ def _get_chat_model_prompt_messages( """ Get chat model prompt messages. """ - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] for prompt_item in prompt_template: raw_prompt = prompt_item.text @@ -160,7 +160,7 @@ def _get_chat_model_prompt_messages( prompt = vp.convert_template(raw_prompt).text else: parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} + prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs} prompt_inputs = self._set_context_variable( context=context, parser=parser, prompt_inputs=prompt_inputs ) @@ -207,7 +207,7 @@ def _get_chat_model_prompt_messages( last_message = prompt_messages[-1] if prompt_messages else None if last_message and last_message.role == PromptMessageRole.USER: # get last user message content and add files - prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] + prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))] for file in files: prompt_message_contents.append(file_manager.to_prompt_message_content(file)) @@ -229,7 +229,10 @@ def _get_chat_model_prompt_messages( return prompt_messages - def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict: + def _set_context_variable( + self, context: str | None, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str] + ) -> Mapping[str, str]: + prompt_inputs = dict(prompt_inputs) if "#context#" in parser.variable_keys: if context: prompt_inputs["#context#"] = context @@ -238,7 +241,10 @@ def _set_context_variable(self, context: str | None, parser: PromptTemplateParse return prompt_inputs - def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict: + def _set_query_variable( + self, query: str, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str] + ) -> Mapping[str, str]: + prompt_inputs = dict(prompt_inputs) if "#query#" in parser.variable_keys: if query: prompt_inputs["#query#"] = query @@ -254,9 +260,10 @@ def _set_histories_variable( raw_prompt: str, role_prefix: MemoryConfig.RolePrefix, parser: PromptTemplateParser, - prompt_inputs: dict, + prompt_inputs: Mapping[str, str], model_config: ModelConfigWithCredentialsEntity, - ) -> dict: + ) -> Mapping[str, str]: + prompt_inputs = dict(prompt_inputs) if "#histories#" in parser.variable_keys: if memory: inputs = {"#histories#": "", **prompt_inputs} diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index caa1793ea8c039..09f017a7db0d3a 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -31,7 +31,7 @@ def __init__( self.memory = memory def get_prompt(self) -> list[PromptMessage]: - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] num_system = 0 for prompt_message in self.history_messages: if isinstance(prompt_message, SystemPromptMessage): diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 87acdb3c49cc01..1f040599be6dac 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -42,7 +42,7 @@ def _calculate_rest_token( ): max_tokens = ( model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens @@ -59,7 +59,7 @@ def _get_history_messages_from_memory( ai_prefix: Optional[str] = None, ) -> str: """Get memory messages.""" - kwargs = {"max_token_limit": max_token_limit} + kwargs: dict[str, Any] = {"max_token_limit": max_token_limit} if human_prefix: kwargs["human_prefix"] = human_prefix @@ -76,11 +76,15 @@ def _get_history_messages_list_from_memory( self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int ) -> list[PromptMessage]: """Get memory messages.""" - return memory.get_history_prompt_messages( - max_token_limit=max_token_limit, - message_limit=memory_config.window.size - if ( - memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0 + return list( + memory.get_history_prompt_messages( + max_token_limit=max_token_limit, + message_limit=memory_config.window.size + if ( + memory_config.window.enabled + and memory_config.window.size is not None + and memory_config.window.size > 0 + ) + else None, ) - else None, ) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 93dd92f188a9c6..e75877de9b695c 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -1,7 +1,8 @@ import enum import json import os -from typing import TYPE_CHECKING, Optional +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any, Optional, cast from core.app.app_config.entities import PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -41,7 +42,7 @@ def value_of(cls, value: str) -> "ModelMode": raise ValueError(f"invalid mode value {value}") -prompt_file_contents = {} +prompt_file_contents: dict[str, Any] = {} class SimplePromptTransform(PromptTransform): @@ -53,9 +54,9 @@ def get_prompt( self, app_mode: AppMode, prompt_template_entity: PromptTemplateEntity, - inputs: dict, + inputs: Mapping[str, str], query: str, - files: list["File"], + files: Sequence["File"], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, @@ -66,7 +67,7 @@ def get_prompt( if model_mode == ModelMode.CHAT: prompt_messages, stops = self._get_chat_model_prompt_messages( app_mode=app_mode, - pre_prompt=prompt_template_entity.simple_prompt_template, + pre_prompt=prompt_template_entity.simple_prompt_template or "", inputs=inputs, query=query, files=files, @@ -77,7 +78,7 @@ def get_prompt( else: prompt_messages, stops = self._get_completion_model_prompt_messages( app_mode=app_mode, - pre_prompt=prompt_template_entity.simple_prompt_template, + pre_prompt=prompt_template_entity.simple_prompt_template or "", inputs=inputs, query=query, files=files, @@ -171,11 +172,11 @@ def _get_chat_model_prompt_messages( inputs: dict, query: str, context: Optional[str], - files: list["File"], + files: Sequence["File"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, ) -> tuple[list[PromptMessage], Optional[list[str]]]: - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] # get prompt prompt, _ = self.get_prompt_str_and_rules( @@ -216,7 +217,7 @@ def _get_completion_model_prompt_messages( inputs: dict, query: str, context: Optional[str], - files: list["File"], + files: Sequence["File"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, ) -> tuple[list[PromptMessage], Optional[list[str]]]: @@ -263,7 +264,7 @@ def _get_completion_model_prompt_messages( return [self.get_last_user_message(prompt, files)], stops - def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage: + def get_last_user_message(self, prompt: str, files: Sequence["File"]) -> UserPromptMessage: if files: prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents.append(TextPromptMessageContent(data=prompt)) @@ -288,7 +289,7 @@ def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict # Check if the prompt file is already loaded if prompt_file_name in prompt_file_contents: - return prompt_file_contents[prompt_file_name] + return cast(dict, prompt_file_contents[prompt_file_name]) # Get the absolute path of the subdirectory prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates") @@ -301,7 +302,7 @@ def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict # Store the content of the prompt file prompt_file_contents[prompt_file_name] = content - return content + return cast(dict, content) def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: # baichuan diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index aa175153bc633f..2f4e65146131be 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import cast +from typing import Any, cast from core.model_runtime.entities import ( AssistantPromptMessage, @@ -72,7 +72,7 @@ def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Seque } ) else: - text = prompt_message.content + text = cast(str, prompt_message.content) prompt = {"role": role, "text": text, "files": files} @@ -99,9 +99,9 @@ def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Seque } ) else: - text = prompt_message.content + text = cast(str, prompt_message.content) - params = { + params: dict[str, Any] = { "role": "user", "text": text, } diff --git a/api/core/prompt/utils/prompt_template_parser.py b/api/core/prompt/utils/prompt_template_parser.py index 0fd08c5d3c1a3e..8e40674bc193e0 100644 --- a/api/core/prompt/utils/prompt_template_parser.py +++ b/api/core/prompt/utils/prompt_template_parser.py @@ -1,4 +1,5 @@ import re +from collections.abc import Mapping REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#histories#|#query#|#context#)\}\}") WITH_VARIABLE_TMPL_REGEX = re.compile( @@ -28,7 +29,7 @@ def extract(self) -> list: # Regular expression to match the template rules return re.findall(self.regex, self.template) - def format(self, inputs: dict, remove_template_variables: bool = True) -> str: + def format(self, inputs: Mapping[str, str], remove_template_variables: bool = True) -> str: def replacer(match): key = match.group(1) value = inputs.get(key, match.group(0)) # return original matched string if key not found diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 3a1fe300dfd311..010abd12d275cd 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1,7 +1,7 @@ import json from collections import defaultdict from json import JSONDecodeError -from typing import Optional +from typing import Optional, cast from sqlalchemy.exc import IntegrityError @@ -15,6 +15,7 @@ ModelLoadBalancingConfiguration, ModelSettings, QuotaConfiguration, + QuotaUnit, SystemConfiguration, ) from core.helper import encrypter @@ -116,8 +117,8 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations: for provider_entity in provider_entities: # handle include, exclude if is_filtered( - include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, - exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, + include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET), + exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET), data=provider_entity, name_func=lambda x: x.provider, ): @@ -490,12 +491,13 @@ def _init_trial_provider_records( # Init trial provider records if not exists if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: try: + # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic provider_record = Provider( tenant_id=tenant_id, provider_name=provider_name, provider_type=ProviderType.SYSTEM.value, quota_type=ProviderQuotaType.TRIAL.value, - quota_limit=quota.quota_limit, + quota_limit=quota.quota_limit, # type: ignore quota_used=0, is_valid=True, ) @@ -589,7 +591,9 @@ def _to_custom_configuration( if variable in provider_credentials: try: provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa + provider_credentials.get(variable) or "", # type: ignore + self.decoding_rsa_key, + self.decoding_cipher_rsa, ) except ValueError: pass @@ -671,13 +675,9 @@ def _to_system_configuration( # Get hosting configuration hosting_configuration = ext_hosting_provider.hosting_configuration - if ( - provider_entity.provider not in hosting_configuration.provider_map - or not hosting_configuration.provider_map.get(provider_entity.provider).enabled - ): - return SystemConfiguration(enabled=False) - provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider) + if provider_hosting_configuration is None or not provider_hosting_configuration.enabled: + return SystemConfiguration(enabled=False) # Convert provider_records to dict quota_type_to_provider_records_dict = {} @@ -688,14 +688,13 @@ def _to_system_configuration( quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( provider_record ) - quota_configurations = [] for provider_quota in provider_hosting_configuration.quotas: if provider_quota.quota_type not in quota_type_to_provider_records_dict: if provider_quota.quota_type == ProviderQuotaType.FREE: quota_configuration = QuotaConfiguration( quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, quota_used=0, quota_limit=0, is_valid=False, @@ -708,7 +707,7 @@ def _to_system_configuration( quota_configuration = QuotaConfiguration( quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, quota_used=provider_record.quota_used, quota_limit=provider_record.quota_limit, is_valid=provider_record.quota_limit > provider_record.quota_used @@ -725,12 +724,12 @@ def _to_system_configuration( current_using_credentials = provider_hosting_configuration.credentials if current_quota_type == ProviderQuotaType.FREE: - provider_record = quota_type_to_provider_records_dict.get(current_quota_type) + provider_record_quota_free = quota_type_to_provider_records_dict.get(current_quota_type) - if provider_record: + if provider_record_quota_free: provider_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, - identity_id=provider_record.id, + identity_id=provider_record_quota_free.id, cache_type=ProviderCredentialsCacheType.PROVIDER, ) @@ -763,7 +762,7 @@ def _to_system_configuration( except ValueError: pass - current_using_credentials = provider_credentials + current_using_credentials = provider_credentials or {} # cache provider credentials provider_credentials_cache.set(credentials=current_using_credentials) @@ -842,7 +841,7 @@ def _to_model_settings( else [] ) - model_settings = [] + model_settings: list[ModelSettings] = [] if not provider_model_settings: return model_settings diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 992415657eced2..d17d76333ee705 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -83,11 +83,15 @@ def _get_rerank_model_instance(self, tenant_id: str, reranking_model: Optional[d if reranking_model: try: model_manager = ModelManager() + reranking_provider_name = reranking_model.get("reranking_provider_name") + reranking_model_name = reranking_model.get("reranking_model_name") + if not reranking_provider_name or not reranking_model_name: + return None rerank_model_instance = model_manager.get_model_instance( tenant_id=tenant_id, - provider=reranking_model["reranking_provider_name"], + provider=reranking_provider_name, model_type=ModelType.RERANK, - model=reranking_model["reranking_model_name"], + model=reranking_model_name, ) return rerank_model_instance except InvokeAuthorizationError: diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index a0153c1e58a1a8..95a2316f1da4dd 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -32,8 +32,11 @@ def create(self, texts: list[Document], **kwargs) -> BaseKeyword: keywords = keyword_table_handler.extract_keywords( text.page_content, self._config.max_keywords_per_chunk ) - self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) + if text.metadata is not None: + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table( + keyword_table or {}, text.metadata["doc_id"], list(keywords) + ) self._save_dataset_keyword_table(keyword_table) @@ -58,20 +61,26 @@ def add_texts(self, texts: list[Document], **kwargs): keywords = keyword_table_handler.extract_keywords( text.page_content, self._config.max_keywords_per_chunk ) - self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) + if text.metadata is not None: + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table( + keyword_table or {}, text.metadata["doc_id"], list(keywords) + ) self._save_dataset_keyword_table(keyword_table) def text_exists(self, id: str) -> bool: keyword_table = self._get_dataset_keyword_table() + if keyword_table is None: + return False return id in set.union(*keyword_table.values()) def delete_by_ids(self, ids: list[str]) -> None: lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table = self._get_dataset_keyword_table() - keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) + if keyword_table is not None: + keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) self._save_dataset_keyword_table(keyword_table) @@ -80,7 +89,7 @@ def search(self, query: str, **kwargs: Any) -> list[Document]: k = kwargs.get("top_k", 4) - sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k) + sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k) documents = [] for chunk_index in sorted_chunk_indices: @@ -137,7 +146,7 @@ def _get_dataset_keyword_table(self) -> Optional[dict]: if dataset_keyword_table: keyword_table_dict = dataset_keyword_table.keyword_table_dict if keyword_table_dict: - return keyword_table_dict["__data__"]["table"] + return dict(keyword_table_dict["__data__"]["table"]) else: keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE dataset_keyword_table = DatasetKeywordTable( @@ -188,8 +197,8 @@ def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4): # go through text chunks in order of most matching keywords chunk_indices_count: dict[str, int] = defaultdict(int) - keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())] - for keyword in keywords: + keywords_list = [keyword for keyword in keywords if keyword in set(keyword_table.keys())] + for keyword in keywords_list: for node_id in keyword_table[keyword]: chunk_indices_count[node_id] += 1 @@ -215,7 +224,7 @@ def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list def create_segment_keywords(self, node_id: str, keywords: list[str]): keyword_table = self._get_dataset_keyword_table() self._update_segment_keywords(self.dataset.id, node_id, keywords) - keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) + keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords) self._save_dataset_keyword_table(keyword_table) def multi_create_segment_keywords(self, pre_segment_data_list: list): @@ -226,17 +235,19 @@ def multi_create_segment_keywords(self, pre_segment_data_list: list): if pre_segment_data["keywords"]: segment.keywords = pre_segment_data["keywords"] keyword_table = self._add_text_to_keyword_table( - keyword_table, segment.index_node_id, pre_segment_data["keywords"] + keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"] ) else: keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk) segment.keywords = list(keywords) - keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords)) + keyword_table = self._add_text_to_keyword_table( + keyword_table or {}, segment.index_node_id, list(keywords) + ) self._save_dataset_keyword_table(keyword_table) def update_segment_keywords_index(self, node_id: str, keywords: list[str]): keyword_table = self._get_dataset_keyword_table() - keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) + keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords) self._save_dataset_keyword_table(keyword_table) diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index 4b1ade8e3fa095..8b17e8dc0a3762 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -1,18 +1,19 @@ import re from typing import Optional -import jieba -from jieba.analyse import default_tfidf - -from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS - class JiebaKeywordTableHandler: def __init__(self): - default_tfidf.stop_words = STOPWORDS + import jieba.analyse # type: ignore + + from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS + + jieba.analyse.default_tfidf.stop_words = STOPWORDS def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]: """Extract keywords with JIEBA tfidf.""" + import jieba # type: ignore + keywords = jieba.analyse.extract_tags( sentence=text, topK=max_keywords_per_chunk, @@ -22,6 +23,8 @@ def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10 def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]: """Get subtokens from a list of tokens., filtering for stopwords.""" + from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS + results = set() for token in tokens: results.add(token) diff --git a/api/core/rag/datasource/keyword/keyword_base.py b/api/core/rag/datasource/keyword/keyword_base.py index be00687abd5025..b261b40b728692 100644 --- a/api/core/rag/datasource/keyword/keyword_base.py +++ b/api/core/rag/datasource/keyword/keyword_base.py @@ -37,6 +37,8 @@ def search(self, query: str, **kwargs: Any) -> list[Document]: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts.copy(): + if text.metadata is None: + continue doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: @@ -45,4 +47,4 @@ def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata["doc_id"] for text in texts] + return [text.metadata["doc_id"] for text in texts if text.metadata] diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index b2141396d6dcc4..34343ad60ea4c1 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -6,6 +6,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db @@ -31,7 +32,7 @@ def retrieve( top_k: int, score_threshold: Optional[float] = 0.0, reranking_model: Optional[dict] = None, - reranking_mode: Optional[str] = "reranking_model", + reranking_mode: str = "reranking_model", weights: Optional[dict] = None, ): if not query: @@ -42,15 +43,15 @@ def retrieve( if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: return [] - all_documents = [] - threads = [] - exceptions = [] + all_documents: list[Document] = [] + threads: list[threading.Thread] = [] + exceptions: list[str] = [] # retrieval_model source with keyword if retrieval_method == "keyword_search": keyword_thread = threading.Thread( target=RetrievalService.keyword_search, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset_id, "query": query, "top_k": top_k, @@ -65,7 +66,7 @@ def retrieve( embedding_thread = threading.Thread( target=RetrievalService.embedding_search, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset_id, "query": query, "top_k": top_k, @@ -84,7 +85,7 @@ def retrieve( full_text_index_thread = threading.Thread( target=RetrievalService.full_text_index_search, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset_id, "query": query, "retrieval_method": retrieval_method, @@ -103,7 +104,7 @@ def retrieve( if exceptions: exception_message = ";\n".join(exceptions) - raise Exception(exception_message) + raise ValueError(exception_message) if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: data_post_processor = DataPostProcessor( @@ -124,7 +125,7 @@ def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model if not dataset: return [] all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( - dataset.tenant_id, dataset_id, query, external_retrieval_model + dataset.tenant_id, dataset_id, query, external_retrieval_model or {} ) return all_documents @@ -135,6 +136,8 @@ def keyword_search( with flask_app.app_context(): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("dataset not found") keyword = Keyword(dataset=dataset) @@ -159,6 +162,8 @@ def embedding_search( with flask_app.app_context(): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("dataset not found") vector = Vector(dataset=dataset) @@ -209,6 +214,8 @@ def full_text_index_search( with flask_app.app_context(): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("dataset not found") vector_processor = Vector( dataset=dataset, diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index 09104ae4223443..603d3fdbcdf1ab 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -17,12 +17,19 @@ class AnalyticdbVector(BaseVector): def __init__( - self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig + self, + collection_name: str, + api_config: AnalyticdbVectorOpenAPIConfig | None, + sql_config: AnalyticdbVectorBySqlConfig | None, ): super().__init__(collection_name) if api_config is not None: - self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config) + self.analyticdb_vector: AnalyticdbVectorOpenAPI | AnalyticdbVectorBySql = AnalyticdbVectorOpenAPI( + collection_name, api_config + ) else: + if sql_config is None: + raise ValueError("Either api_config or sql_config must be provided") self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config) def get_type(self) -> str: @@ -33,8 +40,8 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self.analyticdb_vector._create_collection_if_not_exists(dimension) self.analyticdb_vector.add_texts(texts, embeddings) - def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - self.analyticdb_vector.add_texts(texts, embeddings) + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + self.analyticdb_vector.add_texts(documents, embeddings) def text_exists(self, id: str) -> bool: return self.analyticdb_vector.text_exists(id) @@ -68,13 +75,13 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings if dify_config.ANALYTICDB_HOST is None: # implemented through OpenAPI apiConfig = AnalyticdbVectorOpenAPIConfig( - access_key_id=dify_config.ANALYTICDB_KEY_ID, - access_key_secret=dify_config.ANALYTICDB_KEY_SECRET, - region_id=dify_config.ANALYTICDB_REGION_ID, - instance_id=dify_config.ANALYTICDB_INSTANCE_ID, - account=dify_config.ANALYTICDB_ACCOUNT, - account_password=dify_config.ANALYTICDB_PASSWORD, - namespace=dify_config.ANALYTICDB_NAMESPACE, + access_key_id=dify_config.ANALYTICDB_KEY_ID or "", + access_key_secret=dify_config.ANALYTICDB_KEY_SECRET or "", + region_id=dify_config.ANALYTICDB_REGION_ID or "", + instance_id=dify_config.ANALYTICDB_INSTANCE_ID or "", + account=dify_config.ANALYTICDB_ACCOUNT or "", + account_password=dify_config.ANALYTICDB_PASSWORD or "", + namespace=dify_config.ANALYTICDB_NAMESPACE or "", namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD, ) sqlConfig = None @@ -83,11 +90,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings sqlConfig = AnalyticdbVectorBySqlConfig( host=dify_config.ANALYTICDB_HOST, port=dify_config.ANALYTICDB_PORT, - account=dify_config.ANALYTICDB_ACCOUNT, - account_password=dify_config.ANALYTICDB_PASSWORD, + account=dify_config.ANALYTICDB_ACCOUNT or "", + account_password=dify_config.ANALYTICDB_PASSWORD or "", min_connection=dify_config.ANALYTICDB_MIN_CONNECTION, max_connection=dify_config.ANALYTICDB_MAX_CONNECTION, - namespace=dify_config.ANALYTICDB_NAMESPACE, + namespace=dify_config.ANALYTICDB_NAMESPACE or "", ) apiConfig = None return AnalyticdbVector( diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index 05e0ebc54f7c4c..095752ea8eaa42 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -1,5 +1,5 @@ import json -from typing import Any +from typing import Any, Optional from pydantic import BaseModel, model_validator @@ -20,7 +20,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel): account: str account_password: str namespace: str = "dify" - namespace_password: str = (None,) + namespace_password: Optional[str] = None metrics: str = "cosine" read_timeout: int = 60000 @@ -55,8 +55,8 @@ def to_analyticdb_client_params(self): class AnalyticdbVectorOpenAPI: def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig): try: - from alibabacloud_gpdb20160503.client import Client - from alibabacloud_tea_openapi import models as open_api_models + from alibabacloud_gpdb20160503.client import Client # type: ignore + from alibabacloud_tea_openapi import models as open_api_models # type: ignore except: raise ImportError(_import_err_msg) self._collection_name = collection_name.lower() @@ -77,7 +77,7 @@ def _initialize(self) -> None: redis_client.set(database_exist_cache_key, 1, ex=3600) def _initialize_vector_database(self) -> None: - from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + from alibabacloud_gpdb20160503 import models as gpdb_20160503_models # type: ignore request = gpdb_20160503_models.InitVectorDatabaseRequest( dbinstance_id=self.config.instance_id, @@ -89,7 +89,7 @@ def _initialize_vector_database(self) -> None: def _create_namespace_if_not_exists(self) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - from Tea.exceptions import TeaException + from Tea.exceptions import TeaException # type: ignore try: request = gpdb_20160503_models.DescribeNamespaceRequest( @@ -159,17 +159,18 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = [] for doc, embedding in zip(documents, embeddings, strict=True): - metadata = { - "ref_doc_id": doc.metadata["doc_id"], - "page_content": doc.page_content, - "metadata_": json.dumps(doc.metadata), - } - rows.append( - gpdb_20160503_models.UpsertCollectionDataRequestRows( - vector=embedding, - metadata=metadata, + if doc.metadata is not None: + metadata = { + "ref_doc_id": doc.metadata["doc_id"], + "page_content": doc.page_content, + "metadata_": json.dumps(doc.metadata), + } + rows.append( + gpdb_20160503_models.UpsertCollectionDataRequestRows( + vector=embedding, + metadata=metadata, + ) ) - ) request = gpdb_20160503_models.UpsertCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -258,7 +259,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc metadata=metadata, ) documents.append(doc) - documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) + documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -290,7 +291,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: metadata=metadata, ) documents.append(doc) - documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) + documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents def delete(self) -> None: diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py index e474db5cb21971..4d8f7929413cf2 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py @@ -3,8 +3,8 @@ from contextlib import contextmanager from typing import Any -import psycopg2.extras -import psycopg2.pool +import psycopg2.extras # type: ignore +import psycopg2.pool # type: ignore from pydantic import BaseModel, model_validator from core.rag.models.document import Document @@ -75,6 +75,7 @@ def _create_connection_pool(self): @contextmanager def _get_cursor(self): + assert self.pool is not None, "Connection pool is not initialized" conn = self.pool.getconn() cur = conn.cursor() try: @@ -156,16 +157,17 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s)); """ for i, doc in enumerate(documents): - values.append( - ( - id_prefix + str(i), - doc.metadata.get("doc_id", str(uuid.uuid4())), - embeddings[i], - doc.page_content, - json.dumps(doc.metadata), - doc.page_content, + if doc.metadata is not None: + values.append( + ( + id_prefix + str(i), + doc.metadata.get("doc_id", str(uuid.uuid4())), + embeddings[i], + doc.page_content, + json.dumps(doc.metadata), + doc.page_content, + ) ) - ) with self._get_cursor() as cur: psycopg2.extras.execute_batch(cur, sql, values) diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index eb78e8aa698b9b..85596ad20e099a 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -5,13 +5,13 @@ import numpy as np from pydantic import BaseModel, model_validator -from pymochow import MochowClient -from pymochow.auth.bce_credentials import BceCredentials -from pymochow.configuration import Configuration -from pymochow.exception import ServerError -from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState -from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex -from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row +from pymochow import MochowClient # type: ignore +from pymochow.auth.bce_credentials import BceCredentials # type: ignore +from pymochow.configuration import Configuration # type: ignore +from pymochow.exception import ServerError # type: ignore +from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore +from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex # type: ignore +from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row # type: ignore from configs import dify_config from core.rag.datasource.vdb.vector_base import BaseVector @@ -75,7 +75,7 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): texts = [doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents] + metadatas = [doc.metadata for doc in documents if doc.metadata is not None] total_count = len(documents) batch_size = 1000 @@ -84,6 +84,8 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** for start in range(0, total_count, batch_size): end = min(start + batch_size, total_count) rows = [] + assert len(metadatas) == total_count, "metadatas length should be equal to total_count" + # FIXME do you need this assert? for i in range(start, end, 1): row = Row( id=metadatas[i].get("doc_id", str(uuid.uuid4())), @@ -136,7 +138,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: # baidu vector database doesn't support bm25 search on current version return [] - def _get_search_res(self, res, score_threshold): + def _get_search_res(self, res, score_threshold) -> list[Document]: docs = [] for row in res.rows: row_data = row.get("row", {}) @@ -276,11 +278,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return BaiduVector( collection_name=collection_name, config=BaiduConfig( - endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT, + endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT or "", connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS, - account=dify_config.BAIDU_VECTOR_DB_ACCOUNT, - api_key=dify_config.BAIDU_VECTOR_DB_API_KEY, - database=dify_config.BAIDU_VECTOR_DB_DATABASE, + account=dify_config.BAIDU_VECTOR_DB_ACCOUNT or "", + api_key=dify_config.BAIDU_VECTOR_DB_API_KEY or "", + database=dify_config.BAIDU_VECTOR_DB_DATABASE or "", shard=dify_config.BAIDU_VECTOR_DB_SHARD, replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS, ), diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index a9e1486edd25f1..0eab01b507dc94 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -71,11 +71,13 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** metadatas = [d.metadata for d in documents] collection = self._client.get_or_create_collection(self._collection_name) - collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) + # FIXME: chromadb using numpy array, fix the type error later + collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore def delete_by_metadata_field(self, key: str, value: str): collection = self._client.get_or_create_collection(self._collection_name) - collection.delete(where={key: {"$eq": value}}) + # FIXME: fix the type error later + collection.delete(where={key: {"$eq": value}}) # type: ignore def delete(self): self._client.delete_collection(self._collection_name) @@ -94,15 +96,19 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) score_threshold = float(kwargs.get("score_threshold") or 0.0) - ids: list[str] = results["ids"][0] - documents: list[str] = results["documents"][0] - metadatas: dict[str, Any] = results["metadatas"][0] - distances: list[float] = results["distances"][0] + # Check if results contain data + if not results["ids"] or not results["documents"] or not results["metadatas"] or not results["distances"]: + return [] + + ids = results["ids"][0] + documents = results["documents"][0] + metadatas = results["metadatas"][0] + distances = results["distances"][0] docs = [] for index in range(len(ids)): distance = distances[index] - metadata = metadatas[index] + metadata = dict(metadatas[index]) if distance >= score_threshold: metadata["score"] = distance doc = Document( @@ -111,7 +117,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -133,7 +139,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return ChromaVector( collection_name=collection_name, config=ChromaConfig( - host=dify_config.CHROMA_HOST, + host=dify_config.CHROMA_HOST or "", port=dify_config.CHROMA_PORT, tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, diff --git a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py index d26726e86438bd..68a9952789e5b6 100644 --- a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py +++ b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py @@ -5,14 +5,14 @@ from datetime import timedelta from typing import Any -from couchbase import search -from couchbase.auth import PasswordAuthenticator -from couchbase.cluster import Cluster -from couchbase.management.search import SearchIndex +from couchbase import search # type: ignore +from couchbase.auth import PasswordAuthenticator # type: ignore +from couchbase.cluster import Cluster # type: ignore +from couchbase.management.search import SearchIndex # type: ignore # needed for options -- cluster, timeout, SQL++ (N1QL) query, etc. -from couchbase.options import ClusterOptions, SearchOptions -from couchbase.vector_search import VectorQuery, VectorSearch +from couchbase.options import ClusterOptions, SearchOptions # type: ignore +from couchbase.vector_search import VectorQuery, VectorSearch # type: ignore from flask import current_app from pydantic import BaseModel, model_validator @@ -231,7 +231,7 @@ def text_exists(self, id: str) -> bool: # Pass the id as a parameter to the query result = self._cluster.query(query, named_parameters={"doc_id": id}).execute() for row in result: - return row["count"] > 0 + return bool(row["count"] > 0) return False # Return False if no rows are returned def delete_by_ids(self, ids: list[str]) -> None: @@ -369,10 +369,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return CouchbaseVector( collection_name=collection_name, config=CouchbaseConfig( - connection_string=config.get("COUCHBASE_CONNECTION_STRING"), - user=config.get("COUCHBASE_USER"), - password=config.get("COUCHBASE_PASSWORD"), - bucket_name=config.get("COUCHBASE_BUCKET_NAME"), - scope_name=config.get("COUCHBASE_SCOPE_NAME"), + connection_string=config.get("COUCHBASE_CONNECTION_STRING", ""), + user=config.get("COUCHBASE_USER", ""), + password=config.get("COUCHBASE_PASSWORD", ""), + bucket_name=config.get("COUCHBASE_BUCKET_NAME", ""), + scope_name=config.get("COUCHBASE_SCOPE_NAME", ""), ), ) diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index b08811a02181d2..8661828dc2aa52 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -1,7 +1,7 @@ import json import logging import math -from typing import Any, Optional +from typing import Any, Optional, cast from urllib.parse import urlparse import requests @@ -70,7 +70,7 @@ def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: def _get_version(self) -> str: info = self._client.info() - return info["version"]["number"] + return cast(str, info["version"]["number"]) def _check_version(self): if self._version < "8.0.0": @@ -135,7 +135,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc for doc, score in docs_and_scores: score_threshold = float(kwargs.get("score_threshold") or 0.0) if score > score_threshold: - doc.metadata["score"] = score + if doc.metadata is not None: + doc.metadata["score"] = score docs.append(doc) return docs @@ -156,12 +157,15 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return docs def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - metadatas = [d.metadata for d in texts] + metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas) self.add_texts(texts, embeddings, **kwargs) def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, + embeddings: list[list[float]], + metadatas: Optional[list[dict[Any, Any]]] = None, + index_params: Optional[dict] = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): @@ -208,10 +212,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return ElasticSearchVector( index_name=collection_name, config=ElasticSearchConfig( - host=config.get("ELASTICSEARCH_HOST"), - port=config.get("ELASTICSEARCH_PORT"), - username=config.get("ELASTICSEARCH_USERNAME"), - password=config.get("ELASTICSEARCH_PASSWORD"), + host=config.get("ELASTICSEARCH_HOST", "localhost"), + port=config.get("ELASTICSEARCH_PORT", 9200), + username=config.get("ELASTICSEARCH_USERNAME", ""), + password=config.get("ELASTICSEARCH_PASSWORD", ""), ), attributes=[], ) diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index 60a1a89f1a0b1e..d7a14207e9375a 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -1,13 +1,10 @@ import copy import json import logging -from collections.abc import Iterable from typing import Any, Optional from opensearchpy import OpenSearch -from opensearchpy.helpers import bulk from pydantic import BaseModel, model_validator -from tenacity import retry, stop_after_attempt, wait_fixed from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -23,11 +20,15 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.getLogger("lindorm").setLevel(logging.WARN) +ROUTING_FIELD = "routing_field" +UGC_INDEX_PREFIX = "ugc_index" + class LindormVectorStoreConfig(BaseModel): hosts: str username: Optional[str] = None password: Optional[str] = None + using_ugc: Optional[bool] = False @model_validator(mode="before") @classmethod @@ -41,19 +42,29 @@ def validate_config(cls, values: dict) -> dict: return values def to_opensearch_params(self) -> dict[str, Any]: - params = { - "hosts": self.hosts, - } + params: dict[str, Any] = {"hosts": self.hosts} if self.username and self.password: params["http_auth"] = (self.username, self.password) return params class LindormVectorStore(BaseVector): - def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs): - super().__init__(collection_name.lower()) + def __init__(self, collection_name: str, config: LindormVectorStoreConfig, using_ugc: bool, **kwargs): + self._routing = None + self._routing_field = None + if using_ugc: + routing_value: str | None = kwargs.get("routing_value") + if routing_value is None: + raise ValueError("UGC index should init vector with valid 'routing_value' parameter value") + self._routing = routing_value.lower() + self._routing_field = ROUTING_FIELD + ugc_index_name = collection_name + super().__init__(ugc_index_name.lower()) + else: + super().__init__(collection_name.lower()) self._client_config = config self._client = OpenSearch(**config.to_opensearch_params()) + self._using_ugc = using_ugc self.kwargs = kwargs def get_type(self) -> str: @@ -66,89 +77,40 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) def refresh(self): self._client.indices.refresh(index=self._collection_name) - def __filter_existed_ids( - self, - texts: list[str], - metadatas: list[dict], - ids: list[str], - bulk_size: int = 1024, - ) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]: - @retry(stop=stop_after_attempt(3), wait=wait_fixed(60)) - def __fetch_existing_ids(batch_ids: list[str]) -> set[str]: - try: - existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False) - return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]} - except Exception as e: - logger.exception(f"Error fetching batch {batch_ids}") - return set() - - @retry(stop=stop_after_attempt(3), wait=wait_fixed(60)) - def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]: - try: - existing_docs = self._client.mget( - body={ - "docs": [ - {"_index": self._collection_name, "_id": id, "routing": routing} - for id, routing in zip(batch_ids, route_ids) - ] - }, - _source=False, - ) - return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]} - except Exception as e: - logger.exception(f"Error fetching batch ids: {batch_ids}") - return set() - - if ids is None: - return texts, metadatas, ids - - if len(texts) != len(ids): - raise RuntimeError(f"texts {len(texts)} != {ids}") - - filtered_texts = [] - filtered_metadatas = [] - filtered_ids = [] - - def batch(iterable, n): - length = len(iterable) - for idx in range(0, length, n): - yield iterable[idx : min(idx + n, length)] - - for ids_batch, texts_batch, metadatas_batch in zip( - batch(ids, bulk_size), - batch(texts, bulk_size), - batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size), - ): - existing_ids_set = __fetch_existing_ids(ids_batch) - for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch): - if doc_id not in existing_ids_set: - filtered_texts.append(text) - filtered_ids.append(doc_id) - if metadatas is not None: - filtered_metadatas.append(metadata) - - return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids - def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): actions = [] uuids = self._get_uuids(documents) for i in range(len(documents)): - action = { - "_op_type": "index", - "_index": self._collection_name.lower(), - "_id": uuids[i], - "_source": { - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i], # Make sure you pass an array here - Field.METADATA_KEY.value: documents[i].metadata, - }, + action_header = { + "index": { + "_index": self.collection_name.lower(), + "_id": uuids[i], + } } - actions.append(action) - bulk(self._client, actions) - self.refresh() + action_values: dict[str, Any] = { + Field.CONTENT_KEY.value: documents[i].page_content, + Field.VECTOR.value: embeddings[i], # Make sure you pass an array here + Field.METADATA_KEY.value: documents[i].metadata, + } + if self._using_ugc: + action_header["index"]["routing"] = self._routing + if self._routing_field is not None: + action_values[self._routing_field] = self._routing + actions.append(action_header) + actions.append(action_values) + response = self._client.bulk(actions) + if response["errors"]: + for item in response["items"]: + print(f"{item['index']['status']}: {item['index']['error']['type']}") + else: + self.refresh() def get_ids_by_metadata_field(self, key: str, value: str): - query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}} + query: dict[str, Any] = { + "query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}} + } + if self._using_ugc: + query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}}) response = self._client.search(index=self._collection_name, body=query) if response["hits"]["hits"]: return [hit["_id"] for hit in response["hits"]["hits"]] @@ -156,50 +118,62 @@ def get_ids_by_metadata_field(self, key: str, value: str): return None def delete_by_metadata_field(self, key: str, value: str): - query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}} - results = self._client.search(index=self._collection_name, body=query_str) - ids = [hit["_id"] for hit in results["hits"]["hits"]] + ids = self.get_ids_by_metadata_field(key, value) if ids: self.delete_by_ids(ids) def delete_by_ids(self, ids: list[str]) -> None: + params = {} + if self._using_ugc: + params["routing"] = self._routing for id in ids: - if self._client.exists(index=self._collection_name, id=id): - self._client.delete(index=self._collection_name, id=id) + if self._client.exists(index=self._collection_name, id=id, params=params): + params = {} + if self._using_ugc: + params["routing"] = self._routing + self._client.delete(index=self._collection_name, id=id, params=params) + self.refresh() else: logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.") def delete(self) -> None: - try: + if self._using_ugc: + routing_filter_query = { + "query": {"bool": {"must": [{"term": {f"{self._routing_field}.keyword": self._routing}}]}} + } + self._client.delete_by_query(self._collection_name, body=routing_filter_query) + self.refresh() + else: if self._client.indices.exists(index=self._collection_name): self._client.indices.delete(index=self._collection_name, params={"timeout": 60}) logger.info("Delete index success") else: logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.") - except Exception as e: - logger.exception(f"Error occurred while deleting the index: {self._collection_name}") - raise e def text_exists(self, id: str) -> bool: try: - self._client.get(index=self._collection_name, id=id) + params = {} + if self._using_ugc: + params["routing"] = self._routing + self._client.get(index=self._collection_name, id=id, params=params) return True except: return False def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - # Make sure query_vector is a list if not isinstance(query_vector, list): raise ValueError("query_vector should be a list of floats") - # Check whether query_vector is a floating-point number list if not all(isinstance(x, float) for x in query_vector): raise ValueError("All elements in query_vector should be floats") top_k = kwargs.get("top_k", 10) query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs) try: - response = self._client.search(index=self._collection_name, body=query) + params = {} + if self._using_ugc: + params["routing"] = self._routing + response = self._client.search(index=self._collection_name, body=query, params=params) except Exception as e: logger.exception(f"Error executing vector search, query: {query}") raise @@ -220,7 +194,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc for doc, score in docs_and_scores: score_threshold = kwargs.get("score_threshold", 0.0) or 0.0 if score > score_threshold: - doc.metadata["score"] = score + if doc.metadata is not None: + doc.metadata["score"] = score docs.append(doc) return docs @@ -232,7 +207,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: minimum_should_match = kwargs.get("minimum_should_match", 0) top_k = kwargs.get("top_k", 10) filters = kwargs.get("filter") - routing = kwargs.get("routing") + routing = self._routing full_text_query = default_text_search_query( query_text=query, k=top_k, @@ -243,6 +218,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: minimum_should_match=minimum_should_match, filters=filters, routing=routing, + routing_field=self._routing_field, ) response = self._client.search(index=self._collection_name, body=full_text_query) docs = [] @@ -265,17 +241,18 @@ def create_collection(self, dimension: int, **kwargs): logger.info(f"Collection {self._collection_name} already exists.") return if self._client.indices.exists(index=self._collection_name): - logger.info("{self._collection_name.lower()} already exists.") + logger.info(f"{self._collection_name.lower()} already exists.") + redis_client.set(collection_exist_cache_key, 1, ex=3600) return if len(self.kwargs) == 0 and len(kwargs) != 0: self.kwargs = copy.deepcopy(kwargs) vector_field = kwargs.pop("vector_field", Field.VECTOR.value) - shards = kwargs.pop("shards", 2) + shards = kwargs.pop("shards", 4) engine = kwargs.pop("engine", "lvector") - method_name = kwargs.pop("method_name", "hnsw") + method_name = kwargs.pop("method_name", dify_config.DEFAULT_INDEX_TYPE) + space_type = kwargs.pop("space_type", dify_config.DEFAULT_DISTANCE_TYPE) data_type = kwargs.pop("data_type", "float") - space_type = kwargs.pop("space_type", "cosinesimil") hnsw_m = kwargs.pop("hnsw_m", 24) hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500) @@ -288,10 +265,10 @@ def create_collection(self, dimension: int, **kwargs): mapping = default_text_mapping( dimension, method_name, + space_type=space_type, shards=shards, engine=engine, data_type=data_type, - space_type=space_type, vector_field=vector_field, hnsw_m=hnsw_m, hnsw_ef_construction=hnsw_ef_construction, @@ -301,6 +278,7 @@ def create_collection(self, dimension: int, **kwargs): centroids_hnsw_m=centroids_hnsw_m, centroids_hnsw_ef_construct=centroids_hnsw_ef_construct, centroids_hnsw_ef_search=centroids_hnsw_ef_search, + using_ugc=self._using_ugc, **kwargs, ) self._client.indices.create(index=self._collection_name.lower(), body=mapping) @@ -309,15 +287,20 @@ def create_collection(self, dimension: int, **kwargs): def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict: - routing_field = kwargs.get("routing_field") excludes_from_source = kwargs.get("excludes_from_source") analyzer = kwargs.get("analyzer", "ik_max_word") text_field = kwargs.get("text_field", Field.CONTENT_KEY.value) engine = kwargs["engine"] shard = kwargs["shards"] - space_type = kwargs["space_type"] + space_type = kwargs.get("space_type") + if space_type is None: + if method_name == "hnsw": + space_type = "l2" + else: + space_type = "cosine" data_type = kwargs["data_type"] vector_field = kwargs.get("vector_field", Field.VECTOR.value) + using_ugc = kwargs.get("using_ugc", False) if method_name == "ivfpq": ivfpq_m = kwargs["ivfpq_m"] @@ -366,13 +349,11 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dic if excludes_from_source: mapping["mappings"]["_source"] = {"excludes": excludes_from_source} # e.g. {"excludes": ["vector_field"]} - if method_name == "ivfpq" and routing_field is not None: + if using_ugc and method_name == "ivfpq": mapping["settings"]["index"]["knn_routing"] = True mapping["settings"]["index"]["knn.offline.construction"] = True - - if method_name == "flat" and routing_field is not None: + elif using_ugc and method_name == "hnsw" or using_ugc and method_name == "flat": mapping["settings"]["index"]["knn_routing"] = True - return mapping @@ -386,14 +367,13 @@ def default_text_search_query( minimum_should_match: int = 0, filters: Optional[list[dict]] = None, routing: Optional[str] = None, + routing_field: Optional[str] = None, **kwargs, ) -> dict: + query_clause: dict[str, Any] = {} if routing is not None: - routing_field = kwargs.get("routing_field", "routing_field") query_clause = { - "bool": { - "must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}] - } + "bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]} } else: query_clause = {"match": {text_field: query_text}} @@ -411,7 +391,7 @@ def default_text_search_query( else: must = [query_clause] - boolean_query = {"must": must} + boolean_query: dict[str, Any] = {"must": must} if must_not: if not isinstance(must_not, list): @@ -449,9 +429,9 @@ def default_vector_search_query( ) -> dict: if filters is not None: filter_type = "post_filter" if filter_type is None else filter_type - if not isinstance(filter, list): + if not isinstance(filters, list): raise RuntimeError(f"unexpected filter with {type(filters)}") - final_ext = {"lvector": {}} + final_ext: dict[str, Any] = {"lvector": {}} if min_score != "0.0": final_ext["lvector"]["min_score"] = min_score if ef_search: @@ -463,7 +443,7 @@ def default_vector_search_query( if client_refactor: final_ext["lvector"]["client_refactor"] = client_refactor - search_query = { + search_query: dict[str, Any] = { "size": k, "_source": True, # force return '_source' "query": {"knn": {vector_field: {"vector": query_vector, "k": k}}}, @@ -471,8 +451,8 @@ def default_vector_search_query( if filters is not None: # when using filter, transform filter from List[Dict] to Dict as valid format - filters = {"bool": {"must": filters}} if len(filters) > 1 else filters[0] - search_query["query"]["knn"][vector_field]["filter"] = filters # filter should be Dict + filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0] + search_query["query"]["knn"][vector_field]["filter"] = filter_dict # filter should be Dict if filter_type: final_ext["lvector"]["filter_type"] = filter_type @@ -483,16 +463,47 @@ def default_vector_search_query( class LindormVectorStoreFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore: - if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] - collection_name = class_prefix - else: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name)) lindorm_config = LindormVectorStoreConfig( - hosts=dify_config.LINDORM_URL, + hosts=dify_config.LINDORM_URL or "", username=dify_config.LINDORM_USERNAME, password=dify_config.LINDORM_PASSWORD, + using_ugc=dify_config.USING_UGC_INDEX, ) - return LindormVectorStore(collection_name, lindorm_config) + using_ugc = dify_config.USING_UGC_INDEX + if using_ugc is None: + raise ValueError("USING_UGC_INDEX is not set") + routing_value = None + if dataset.index_struct: + # if an existed record's index_struct_dict doesn't contain using_ugc field, + # it actually stores in the normal index format + stored_in_ugc: bool = dataset.index_struct_dict.get("using_ugc", False) + using_ugc = stored_in_ugc + if stored_in_ugc: + dimension = dataset.index_struct_dict["dimension"] + index_type = dataset.index_struct_dict["index_type"] + distance_type = dataset.index_struct_dict["distance_type"] + routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"] + index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}" + else: + index_name = dataset.index_struct_dict["vector_store"]["class_prefix"] + else: + embedding_vector = embeddings.embed_query("hello word") + dimension = len(embedding_vector) + index_type = dify_config.DEFAULT_INDEX_TYPE + distance_type = dify_config.DEFAULT_DISTANCE_TYPE + class_prefix = Dataset.gen_collection_name_by_id(dataset.id) + index_struct_dict = { + "type": VectorType.LINDORM, + "vector_store": {"class_prefix": class_prefix}, + "index_type": index_type, + "dimension": dimension, + "distance_type": distance_type, + "using_ugc": using_ugc, + } + dataset.index_struct = json.dumps(index_struct_dict) + if using_ugc: + index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}" + routing_value = class_prefix + else: + index_name = class_prefix + return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value, using_ugc=using_ugc) diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 5a263d6e78c3bd..9b029ffc193cc0 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -3,8 +3,8 @@ from typing import Any, Optional from pydantic import BaseModel, model_validator -from pymilvus import MilvusClient, MilvusException -from pymilvus.milvus_client import IndexParams +from pymilvus import MilvusClient, MilvusException # type: ignore +from pymilvus.milvus_client import IndexParams # type: ignore from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -54,14 +54,14 @@ def __init__(self, collection_name: str, config: MilvusConfig): self._client_config = config self._client = self._init_client(config) self._consistency_level = "Session" - self._fields = [] + self._fields: list[str] = [] def get_type(self) -> str: return VectorType.MILVUS def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}} - metadatas = [d.metadata for d in texts] + metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas, index_params) self.add_texts(texts, embeddings) @@ -161,8 +161,8 @@ def create_collection( return # Grab the existing collection if it exists if not self._client.has_collection(self._collection_name): - from pymilvus import CollectionSchema, DataType, FieldSchema - from pymilvus.orm.types import infer_dtype_bydata + from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore + from pymilvus.orm.types import infer_dtype_bydata # type: ignore # Determine embedding dim dim = len(embeddings[0]) @@ -217,10 +217,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return MilvusVector( collection_name=collection_name, config=MilvusConfig( - uri=dify_config.MILVUS_URI, - token=dify_config.MILVUS_TOKEN, - user=dify_config.MILVUS_USER, - password=dify_config.MILVUS_PASSWORD, - database=dify_config.MILVUS_DATABASE, + uri=dify_config.MILVUS_URI or "", + token=dify_config.MILVUS_TOKEN or "", + user=dify_config.MILVUS_USER or "", + password=dify_config.MILVUS_PASSWORD or "", + database=dify_config.MILVUS_DATABASE or "", ), ) diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index b7b6b803ad20af..e63e1f522b3812 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -74,15 +74,16 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** columns = ["id", "text", "vector", "metadata"] values = [] for i, doc in enumerate(documents): - doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) - row = ( - doc_id, - self.escape_str(doc.page_content), - embeddings[i], - json.dumps(doc.metadata) if doc.metadata else {}, - ) - values.append(str(row)) - ids.append(doc_id) + if doc.metadata is not None: + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + row = ( + doc_id, + self.escape_str(doc.page_content), + embeddings[i], + json.dumps(doc.metadata) if doc.metadata else {}, + ) + values.append(str(row)) + ids.append(doc_id) sql = f""" INSERT INTO {self._config.database}.{self._collection_name} ({",".join(columns)}) VALUES {",".join(values)} diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index c44338d42a591a..957c799a60cbfe 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -4,7 +4,7 @@ from typing import Any from pydantic import BaseModel, model_validator -from pyobvector import VECTOR, ObVecClient +from pyobvector import VECTOR, ObVecClient # type: ignore from sqlalchemy import JSON, Column, String, func from sqlalchemy.dialects.mysql import LONGTEXT @@ -131,7 +131,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** def text_exists(self, id: str) -> bool: cur = self._client.get(table_name=self._collection_name, id=id) - return cur.rowcount != 0 + return bool(cur.rowcount != 0) def delete_by_ids(self, ids: list[str]) -> None: self._client.delete(table_name=self._collection_name, ids=ids) diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 7a976d7c3c8955..72a15022052f0a 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -66,7 +66,7 @@ def get_type(self) -> str: return VectorType.OPENSEARCH def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - metadatas = [d.metadata for d in texts] + metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas) self.add_texts(texts, embeddings) @@ -244,7 +244,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name)) open_search_config = OpenSearchConfig( - host=dify_config.OPENSEARCH_HOST, + host=dify_config.OPENSEARCH_HOST or "localhost", port=dify_config.OPENSEARCH_PORT, user=dify_config.OPENSEARCH_USER, password=dify_config.OPENSEARCH_PASSWORD, diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 71c58c9d0c37d5..dfff3563c3bb28 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -5,11 +5,9 @@ from contextlib import contextmanager from typing import Any -import jieba.posseg as pseg -import nltk +import jieba.posseg as pseg # type: ignore import numpy import oracledb -from nltk.corpus import stopwords from pydantic import BaseModel, model_validator from configs import dify_config @@ -90,12 +88,11 @@ def input_type_handler(self, cursor, value, arraysize): def numpy_converter_out(self, value): if value.typecode == "b": - dtype = numpy.int8 + return numpy.array(value, copy=False, dtype=numpy.int8) elif value.typecode == "f": - dtype = numpy.float32 + return numpy.array(value, copy=False, dtype=numpy.float32) else: - dtype = numpy.float64 - return numpy.array(value, copy=False, dtype=dtype) + return numpy.array(value, copy=False, dtype=numpy.float64) def output_type_handler(self, cursor, metadata): if metadata.type_code is oracledb.DB_TYPE_VECTOR: @@ -137,17 +134,18 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** values = [] pks = [] for i, doc in enumerate(documents): - doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) - pks.append(doc_id) - values.append( - ( - doc_id, - doc.page_content, - json.dumps(doc.metadata), - # array.array("f", embeddings[i]), - numpy.array(embeddings[i]), + if doc.metadata is not None: + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + pks.append(doc_id) + values.append( + ( + doc_id, + doc.page_content, + json.dumps(doc.metadata), + # array.array("f", embeddings[i]), + numpy.array(embeddings[i]), + ) ) - ) # print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") with self._get_cursor() as cur: cur.executemany( @@ -202,6 +200,10 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + # lazy import + import nltk # type: ignore + from nltk.corpus import stopwords # type: ignore + top_k = kwargs.get("top_k", 5) # just not implement fetch by score_threshold now, may be later score_threshold = float(kwargs.get("score_threshold") or 0.0) @@ -283,10 +285,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return OracleVector( collection_name=collection_name, config=OracleVectorConfig( - host=dify_config.ORACLE_HOST, + host=dify_config.ORACLE_HOST or "localhost", port=dify_config.ORACLE_PORT, - user=dify_config.ORACLE_USER, - password=dify_config.ORACLE_PASSWORD, - database=dify_config.ORACLE_DATABASE, + user=dify_config.ORACLE_USER or "system", + password=dify_config.ORACLE_PASSWORD or "oracle", + database=dify_config.ORACLE_DATABASE or "orcl", ), ) diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 7cbbdcc81f6039..221bc68d68a6f7 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -4,7 +4,7 @@ from uuid import UUID, uuid4 from numpy import ndarray -from pgvecto_rs.sqlalchemy import VECTOR +from pgvecto_rs.sqlalchemy import VECTOR # type: ignore from pydantic import BaseModel, model_validator from sqlalchemy import Float, String, create_engine, insert, select, text from sqlalchemy import text as sql_text @@ -58,7 +58,7 @@ def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int): with Session(self._client) as session: session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors")) session.commit() - self._fields = [] + self._fields: list[str] = [] class _Table(CollectionORM): __tablename__ = collection_name @@ -222,11 +222,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return PGVectoRS( collection_name=collection_name, config=PgvectoRSConfig( - host=dify_config.PGVECTO_RS_HOST, - port=dify_config.PGVECTO_RS_PORT, - user=dify_config.PGVECTO_RS_USER, - password=dify_config.PGVECTO_RS_PASSWORD, - database=dify_config.PGVECTO_RS_DATABASE, + host=dify_config.PGVECTO_RS_HOST or "localhost", + port=dify_config.PGVECTO_RS_PORT or 5432, + user=dify_config.PGVECTO_RS_USER or "postgres", + password=dify_config.PGVECTO_RS_PASSWORD or "", + database=dify_config.PGVECTO_RS_DATABASE or "postgres", ), dim=dim, ) diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 40a9cdd136b404..271281ca7e939f 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -3,8 +3,8 @@ from contextlib import contextmanager from typing import Any -import psycopg2.extras -import psycopg2.pool +import psycopg2.extras # type: ignore +import psycopg2.pool # type: ignore from pydantic import BaseModel, model_validator from configs import dify_config @@ -98,16 +98,17 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** values = [] pks = [] for i, doc in enumerate(documents): - doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) - pks.append(doc_id) - values.append( - ( - doc_id, - doc.page_content, - json.dumps(doc.metadata), - embeddings[i], + if doc.metadata is not None: + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + pks.append(doc_id) + values.append( + ( + doc_id, + doc.page_content, + json.dumps(doc.metadata), + embeddings[i], + ) ) - ) with self._get_cursor() as cur: psycopg2.extras.execute_values( cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values @@ -216,11 +217,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return PGVector( collection_name=collection_name, config=PGVectorConfig( - host=dify_config.PGVECTOR_HOST, + host=dify_config.PGVECTOR_HOST or "localhost", port=dify_config.PGVECTOR_PORT, - user=dify_config.PGVECTOR_USER, - password=dify_config.PGVECTOR_PASSWORD, - database=dify_config.PGVECTOR_DATABASE, + user=dify_config.PGVECTOR_USER or "postgres", + password=dify_config.PGVECTOR_PASSWORD or "", + database=dify_config.PGVECTOR_DATABASE or "postgres", min_connection=dify_config.PGVECTOR_MIN_CONNECTION, max_connection=dify_config.PGVECTOR_MAX_CONNECTION, ), diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 3811458e02957c..6e94cb69db309d 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -51,6 +51,8 @@ def to_qdrant_params(self): if self.endpoint and self.endpoint.startswith("path:"): path = self.endpoint.replace("path:", "") if not os.path.isabs(path): + if not self.root_path: + raise ValueError("Root path is not set") path = os.path.join(self.root_path, path) return {"path": path} @@ -149,9 +151,12 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] metadatas = [d.metadata for d in documents] - added_ids = [] - for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): + # Filter out None values from metadatas list to match expected type + filtered_metadatas = [m for m in metadatas if m is not None] + for batch_ids, points in self._generate_rest_batches( + texts, embeddings, filtered_metadatas, uuids, 64, self._group_id + ): self._client.upsert(collection_name=self._collection_name, points=points) added_ids.extend(batch_ids) @@ -194,7 +199,7 @@ def _generate_rest_batches( batch_metadatas, Field.CONTENT_KEY.value, Field.METADATA_KEY.value, - group_id, + group_id or "", # Ensure group_id is never None Field.GROUP_KEY.value, ), ) @@ -337,18 +342,20 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc ) docs = [] for result in results: + if result.payload is None: + continue metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold score_threshold = float(kwargs.get("score_threshold") or 0.0) if result.score > score_threshold: metadata["score"] = result.score doc = Document( - page_content=result.payload.get(Field.CONTENT_KEY.value), + page_content=result.payload.get(Field.CONTENT_KEY.value, ""), metadata=metadata, ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -432,9 +439,9 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings collection_name=collection_name, group_id=dataset.id, config=QdrantConfig( - endpoint=dify_config.QDRANT_URL, + endpoint=dify_config.QDRANT_URL or "", api_key=dify_config.QDRANT_API_KEY, - root_path=current_app.config.root_path, + root_path=str(current_app.config.root_path), timeout=dify_config.QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.QDRANT_GRPC_PORT, prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index f373dcfeabef92..a3a20448ff7a0a 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -3,7 +3,7 @@ from typing import Any, Optional from pydantic import BaseModel, model_validator -from sqlalchemy import Column, Sequence, String, Table, create_engine, insert +from sqlalchemy import Column, String, Table, create_engine, insert from sqlalchemy import text as sql_text from sqlalchemy.dialects.postgresql import JSON, TEXT from sqlalchemy.orm import Session @@ -58,14 +58,14 @@ def __init__(self, collection_name: str, config: RelytConfig, group_id: str): f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" ) self.client = create_engine(self._url) - self._fields = [] + self._fields: list[str] = [] self._group_id = group_id def get_type(self) -> str: return VectorType.RELYT - def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - index_params = {} + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None: + index_params: dict[str, Any] = {} metadatas = [d.metadata for d in texts] self.create_collection(len(embeddings[0])) self.embedding_dimension = len(embeddings[0]) @@ -107,10 +107,10 @@ def create_collection(self, dimension: int): redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): - from pgvecto_rs.sqlalchemy import VECTOR + from pgvecto_rs.sqlalchemy import VECTOR # type: ignore ids = [str(uuid.uuid1()) for _ in documents] - metadatas = [d.metadata for d in documents] + metadatas = [d.metadata for d in documents if d.metadata is not None] for metadata in metadatas: metadata["group_id"] = self._group_id texts = [d.page_content for d in documents] @@ -242,10 +242,6 @@ def similarity_search_with_score_by_vector( filter: Optional[dict] = None, ) -> list[tuple[Document, float]]: # Add the filter if provided - try: - from sqlalchemy.engine import Row - except ImportError: - raise ImportError("Could not import Row from sqlalchemy.engine. Please 'pip install sqlalchemy>=1.4'.") filter_condition = "" if filter is not None: @@ -275,7 +271,7 @@ def similarity_search_with_score_by_vector( # Execute the query and fetch the results with self.client.connect() as conn: - results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall() + results = conn.execute(sql_text(sql_query), params).fetchall() documents_with_scores = [ ( @@ -307,11 +303,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return RelytVector( collection_name=collection_name, config=RelytConfig( - host=dify_config.RELYT_HOST, + host=dify_config.RELYT_HOST or "localhost", port=dify_config.RELYT_PORT, - user=dify_config.RELYT_USER, - password=dify_config.RELYT_PASSWORD, - database=dify_config.RELYT_DATABASE, + user=dify_config.RELYT_USER or "", + password=dify_config.RELYT_PASSWORD or "", + database=dify_config.RELYT_DATABASE or "default", ), group_id=dataset.id, ) diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index f971a9c5eb1696..c15f4b229f81c3 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -2,10 +2,10 @@ from typing import Any, Optional from pydantic import BaseModel -from tcvectordb import VectorDBClient -from tcvectordb.model import document, enum -from tcvectordb.model import index as vdb_index -from tcvectordb.model.document import Filter +from tcvectordb import VectorDBClient # type: ignore +from tcvectordb.model import document, enum # type: ignore +from tcvectordb.model import index as vdb_index # type: ignore +from tcvectordb.model.document import Filter # type: ignore from configs import dify_config from core.rag.datasource.vdb.vector_base import BaseVector @@ -25,8 +25,8 @@ class TencentConfig(BaseModel): database: Optional[str] index_type: str = "HNSW" metric_type: str = "L2" - shard: int = (1,) - replicas: int = (2,) + shard: int = 1 + replicas: int = 2 def to_tencent_params(self): return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout} @@ -120,15 +120,15 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** metadatas = [doc.metadata for doc in documents] total_count = len(embeddings) docs = [] - for id in range(0, total_count): + for i in range(0, total_count): if metadatas is None: continue - metadata = json.dumps(metadatas[id]) + metadata = metadatas[i] or {} doc = document.Document( - id=metadatas[id]["doc_id"], - vector=embeddings[id], - text=texts[id], - metadata=metadata, + id=metadata.get("doc_id"), + vector=embeddings[i], + text=texts[i], + metadata=json.dumps(metadata), ) docs.append(doc) self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout) @@ -159,8 +159,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return [] - def _get_search_res(self, res, score_threshold): - docs = [] + def _get_search_res(self, res: list | None, score_threshold: float) -> list[Document]: + docs: list[Document] = [] if res is None or len(res) == 0: return docs @@ -193,7 +193,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return TencentVector( collection_name=collection_name, config=TencentConfig( - url=dify_config.TENCENT_VECTOR_DB_URL, + url=dify_config.TENCENT_VECTOR_DB_URL or "", api_key=dify_config.TENCENT_VECTOR_DB_API_KEY, timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT, username=dify_config.TENCENT_VECTOR_DB_USERNAME, diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index cfd47aac5ba05b..19c5579a688f5a 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -54,7 +54,10 @@ def to_qdrant_params(self): if self.endpoint and self.endpoint.startswith("path:"): path = self.endpoint.replace("path:", "") if not os.path.isabs(path): - path = os.path.join(self.root_path, path) + if self.root_path: + path = os.path.join(self.root_path, path) + else: + raise ValueError("root_path is required") return {"path": path} else: @@ -157,7 +160,7 @@ def create_collection(self, collection_name: str, vector_size: int): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] - metadatas = [d.metadata for d in documents] + metadatas = [d.metadata for d in documents if d.metadata is not None] added_ids = [] for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): @@ -203,7 +206,7 @@ def _generate_rest_batches( batch_metadatas, Field.CONTENT_KEY.value, Field.METADATA_KEY.value, - group_id, + group_id or "", Field.GROUP_KEY.value, ), ) @@ -334,18 +337,20 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc ) docs = [] for result in results: + if result.payload is None: + continue metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold score_threshold = kwargs.get("score_threshold") or 0.0 if result.score > score_threshold: metadata["score"] = result.score doc = Document( - page_content=result.payload.get(Field.CONTENT_KEY.value), + page_content=result.payload.get(Field.CONTENT_KEY.value, ""), metadata=metadata, ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -427,12 +432,12 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings else: new_cluster = TidbService.create_tidb_serverless_cluster( - dify_config.TIDB_PROJECT_ID, - dify_config.TIDB_API_URL, - dify_config.TIDB_IAM_API_URL, - dify_config.TIDB_PUBLIC_KEY, - dify_config.TIDB_PRIVATE_KEY, - dify_config.TIDB_REGION, + dify_config.TIDB_PROJECT_ID or "", + dify_config.TIDB_API_URL or "", + dify_config.TIDB_IAM_API_URL or "", + dify_config.TIDB_PUBLIC_KEY or "", + dify_config.TIDB_PRIVATE_KEY or "", + dify_config.TIDB_REGION or "", ) new_tidb_auth_binding = TidbAuthBinding( cluster_id=new_cluster["cluster_id"], @@ -464,9 +469,9 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings collection_name=collection_name, group_id=dataset.id, config=TidbOnQdrantConfig( - endpoint=dify_config.TIDB_ON_QDRANT_URL, + endpoint=dify_config.TIDB_ON_QDRANT_URL or "", api_key=TIDB_ON_QDRANT_API_KEY, - root_path=config.root_path, + root_path=str(config.root_path), timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT, prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED, diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index a6f3ad7fef0c45..0a48c79511bf26 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -146,7 +146,7 @@ def batch_update_tidb_serverless_cluster_status( iam_url: str, public_key: str, private_key: str, - ) -> list[dict]: + ): """ Update the status of a new TiDB Serverless cluster. :param project_id: The project ID of the TiDB Cloud project (required). @@ -159,17 +159,15 @@ def batch_update_tidb_serverless_cluster_status( :return: The response from the API. """ - clusters = [] tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list} cluster_ids = [item.cluster_id for item in tidb_serverless_list] - params = {"clusterIds": cluster_ids, "view": "FULL"} + params = {"clusterIds": cluster_ids, "view": "BASIC"} response = requests.get( f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key) ) if response.status_code == 200: response_data = response.json() - cluster_infos = [] for item in response_data["clusters"]: state = item["state"] userPrefix = item["userPrefix"] @@ -236,16 +234,17 @@ def batch_create_tidb_serverless_cluster( cluster_infos = [] for item in response_data["clusters"]: cache_key = f"tidb_serverless_cluster_password:{item['displayName']}" - password = redis_client.get(cache_key) - if not password: + cached_password = redis_client.get(cache_key) + if not cached_password: continue cluster_info = { "cluster_id": item["clusterId"], "cluster_name": item["displayName"], "account": "root", - "password": password.decode("utf-8"), + "password": cached_password.decode("utf-8"), } cluster_infos.append(cluster_info) return cluster_infos else: response.raise_for_status() + return [] # FIXME for mypy, This line will not be reached as raise_for_status() will raise an exception diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 1147e35ce8fa55..be3a417390e802 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -37,8 +37,6 @@ def validate_config(cls, values: dict) -> dict: raise ValueError("config TIDB_VECTOR_PORT is required") if not values["user"]: raise ValueError("config TIDB_VECTOR_USER is required") - if not values["password"]: - raise ValueError("config TIDB_VECTOR_PASSWORD is required") if not values["database"]: raise ValueError("config TIDB_VECTOR_DATABASE is required") if not values["program_name"]: @@ -51,7 +49,7 @@ def get_type(self) -> str: return VectorType.TIDB_VECTOR def _table(self, dim: int) -> Table: - from tidb_vector.sqlalchemy import VectorType + from tidb_vector.sqlalchemy import VectorType # type: ignore return Table( self._collection_name, @@ -243,11 +241,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return TiDBVector( collection_name=collection_name, config=TiDBVectorConfig( - host=dify_config.TIDB_VECTOR_HOST, - port=dify_config.TIDB_VECTOR_PORT, - user=dify_config.TIDB_VECTOR_USER, - password=dify_config.TIDB_VECTOR_PASSWORD, - database=dify_config.TIDB_VECTOR_DATABASE, + host=dify_config.TIDB_VECTOR_HOST or "", + port=dify_config.TIDB_VECTOR_PORT or 0, + user=dify_config.TIDB_VECTOR_USER or "", + password=dify_config.TIDB_VECTOR_PASSWORD or "", + database=dify_config.TIDB_VECTOR_DATABASE or "", program_name=dify_config.APPLICATION_NAME, ), ) diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py index 22e191340d3a47..edfce2edd896ee 100644 --- a/api/core/rag/datasource/vdb/vector_base.py +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -51,15 +51,16 @@ def delete(self) -> None: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts.copy(): - doc_id = text.metadata["doc_id"] - exists_duplicate_node = self.text_exists(doc_id) - if exists_duplicate_node: - texts.remove(text) + if text.metadata and "doc_id" in text.metadata: + doc_id = text.metadata["doc_id"] + exists_duplicate_node = self.text_exists(doc_id) + if exists_duplicate_node: + texts.remove(text) return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata["doc_id"] for text in texts] + return [text.metadata["doc_id"] for text in texts if text.metadata and "doc_id" in text.metadata] @property def collection_name(self): diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 6d2e04fc020ab5..523fa80f124b0c 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -193,10 +193,13 @@ def _get_embeddings(self) -> Embeddings: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts.copy(): + if text.metadata is None: + continue doc_id = text.metadata["doc_id"] - exists_duplicate_node = self.text_exists(doc_id) - if exists_duplicate_node: - texts.remove(text) + if doc_id: + exists_duplicate_node = self.text_exists(doc_id) + if exists_duplicate_node: + texts.remove(text) return texts diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index 4f927f28995613..9de8761a91ca68 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -2,7 +2,7 @@ from typing import Any from pydantic import BaseModel -from volcengine.viking_db import ( +from volcengine.viking_db import ( # type: ignore Data, DistanceType, Field, @@ -121,11 +121,12 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** for i, page_content in enumerate(page_contents): metadata = {} if metadatas is not None: - for key, val in metadatas[i].items(): + for key, val in (metadatas[i] or {}).items(): metadata[key] = val + # FIXME: fix the type of metadata later doc = Data( { - vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], + vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore vdb_Field.VECTOR.value: embeddings[i] if embeddings else None, vdb_Field.CONTENT_KEY.value: page_content, vdb_Field.METADATA_KEY.value: json.dumps(metadata), @@ -178,7 +179,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._get_search_res(results, score_threshold) - def _get_search_res(self, results, score_threshold): + def _get_search_res(self, results, score_threshold) -> list[Document]: if len(results) == 0: return [] @@ -191,7 +192,7 @@ def _get_search_res(self, results, score_threshold): metadata["score"] = result.score doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 649cfbfea8253c..68d043a19f171f 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -3,7 +3,7 @@ from typing import Any, Optional import requests -import weaviate +import weaviate # type: ignore from pydantic import BaseModel, model_validator from configs import dify_config @@ -107,7 +107,8 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** for i, text in enumerate(texts): data_properties = {Field.TEXT_KEY.value: text} if metadatas is not None: - for key, val in metadatas[i].items(): + # metadata maybe None + for key, val in (metadatas[i] or {}).items(): data_properties[key] = self._json_serializable(val) batch.add_data_object( @@ -208,10 +209,11 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc score_threshold = float(kwargs.get("score_threshold") or 0.0) # check score threshold if score > score_threshold: - doc.metadata["score"] = score - docs.append(doc) + if doc.metadata is not None: + doc.metadata["score"] = score + docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -275,7 +277,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return WeaviateVector( collection_name=collection_name, config=WeaviateConfig( - endpoint=dify_config.WEAVIATE_ENDPOINT, + endpoint=dify_config.WEAVIATE_ENDPOINT or "", api_key=dify_config.WEAVIATE_API_KEY, batch_size=dify_config.WEAVIATE_BATCH_SIZE, ), diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 319a2612c7ecb8..35becaa0c7bea7 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -83,6 +83,9 @@ def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> if not isinstance(doc, Document): raise ValueError("doc must be a Document") + if doc.metadata is None: + raise ValueError("doc.metadata must be a dict") + segment_document = self.get_document_segment(doc_id=doc.metadata["doc_id"]) # NOTE: doc could already exist in the store, but we overwrite it @@ -179,10 +182,10 @@ def get_document_hash(self, doc_id: str) -> Optional[str]: if document_segment is None: return None + data: Optional[str] = document_segment.index_node_hash + return data - return document_segment.index_node_hash - - def get_document_segment(self, doc_id: str) -> DocumentSegment: + def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: document_segment = ( db.session.query(DocumentSegment) .filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index fc8e0440c332c3..a2c8737da79198 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -1,6 +1,6 @@ import base64 import logging -from typing import Optional, cast +from typing import Any, Optional, cast import numpy as np from sqlalchemy.exc import IntegrityError @@ -27,7 +27,7 @@ def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs in batches of 10.""" # use doc embedding cache or store if not exists - text_embeddings = [None for _ in range(len(texts))] + text_embeddings: list[Any] = [None for _ in range(len(texts))] embedding_queue_indices = [] for i, text in enumerate(texts): hash = helper.generate_text_hash(text) @@ -64,7 +64,13 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: for vector in embedding_result.embeddings: try: - normalized_embedding = (vector / np.linalg.norm(vector)).tolist() + # FIXME: type ignore for numpy here + normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore + # stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan + if np.isnan(normalized_embedding).any(): + # for issue #11827 float values are not json compliant + logger.warning(f"Normalized embedding is nan: {normalized_embedding}") + continue embedding_queue_embeddings.append(normalized_embedding) except IntegrityError: db.session.rollback() @@ -72,8 +78,8 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: logging.exception("Failed transform embedding") cache_embeddings = [] try: - for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): - text_embeddings[i] = embedding + for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings): + text_embeddings[i] = n_embedding hash = helper.generate_text_hash(texts[i]) if hash not in cache_embeddings: embedding_cache = Embedding( @@ -81,7 +87,7 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: hash=hash, provider_name=self._model_instance.provider, ) - embedding_cache.set_embedding(embedding) + embedding_cache.set_embedding(n_embedding) db.session.add(embedding_cache) cache_embeddings.append(hash) db.session.commit() @@ -110,7 +116,10 @@ def embed_query(self, text: str) -> list[float]: ) embedding_results = embedding_result.embeddings[0] - embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() + # FIXME: type ignore for numpy here + embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore + if np.isnan(embedding_results).any(): + raise ValueError("Normalized embedding is nan please try again") except Exception as ex: if dify_config.DEBUG: logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'") diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index 3692b5d19dfb65..7c00c668dd49a3 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -14,7 +14,7 @@ class NotionInfo(BaseModel): notion_workspace_id: str notion_obj_id: str notion_page_type: str - document: Document = None + document: Optional[Document] = None tenant_id: str model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index fc331657195454..c444105bb59443 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -1,7 +1,7 @@ """Abstract interface for document loader implementations.""" import os -from typing import Optional +from typing import Optional, cast import pandas as pd from openpyxl import load_workbook @@ -47,7 +47,7 @@ def extract(self) -> list[Document]: for col_index, (k, v) in enumerate(row.items()): if pd.notna(v): cell = sheet.cell( - row=index + 2, column=col_index + 1 + row=cast(int, index) + 2, column=col_index + 1 ) # +2 to account for header and 1-based index if cell.hyperlink: value = f"[{v}]({cell.hyperlink.target})" @@ -60,8 +60,8 @@ def extract(self) -> list[Document]: elif file_extension == ".xls": excel_file = pd.ExcelFile(self._file_path, engine="xlrd") - for sheet_name in excel_file.sheet_names: - df = excel_file.parse(sheet_name=sheet_name) + for excel_sheet_name in excel_file.sheet_names: + df = excel_file.parse(sheet_name=excel_sheet_name) df.dropna(how="all", inplace=True) for _, row in df.iterrows(): diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index a0b1aa4cefbd1f..a473b3dfa78a90 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -10,6 +10,7 @@ from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.excel_extractor import ExcelExtractor +from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor from core.rag.extractor.html_extractor import HtmlExtractor from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor @@ -66,9 +67,13 @@ def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Docume filename_match = re.search(r'filename="([^"]+)"', content_disposition) if filename_match: filename = unquote(filename_match.group(1)) - suffix = "." + re.search(r"\.(\w+)$", filename).group(1) - - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" + match = re.search(r"\.(\w+)$", filename) + if match: + suffix = "." + match.group(1) + else: + suffix = "" + # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore Path(file_path).write_bytes(response.content) extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model") if return_text: @@ -89,21 +94,26 @@ def extract( if extract_setting.datasource_type == DatasourceType.FILE.value: with tempfile.TemporaryDirectory() as temp_dir: if not file_path: + assert extract_setting.upload_file is not None, "upload_file is required" upload_file: UploadFile = extract_setting.upload_file suffix = Path(upload_file.key).suffix - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" + # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore storage.download(upload_file.key, file_path) input_file = Path(file_path) file_extension = input_file.suffix.lower() etl_type = dify_config.ETL_TYPE unstructured_api_url = dify_config.UNSTRUCTURED_API_URL unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY + assert unstructured_api_url is not None, "unstructured_api_url is required" + assert unstructured_api_key is not None, "unstructured_api_key is required" + extractor: Optional[BaseExtractor] = None if etl_type == "Unstructured": if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": extractor = PdfExtractor(file_path) - elif file_extension in {".md", ".markdown"}: + elif file_extension in {".md", ".markdown", ".mdx"}: extractor = ( UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key) if is_automatic @@ -141,7 +151,7 @@ def extract( extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": extractor = PdfExtractor(file_path) - elif file_extension in {".md", ".markdown"}: + elif file_extension in {".md", ".markdown", ".mdx"}: extractor = MarkdownExtractor(file_path, autodetect_encoding=True) elif file_extension in {".htm", ".html"}: extractor = HtmlExtractor(file_path) @@ -156,6 +166,7 @@ def extract( extractor = TextExtractor(file_path, autodetect_encoding=True) return extractor.extract() elif extract_setting.datasource_type == DatasourceType.NOTION.value: + assert extract_setting.notion_info is not None, "notion_info is required" extractor = NotionExtractor( notion_workspace_id=extract_setting.notion_info.notion_workspace_id, notion_obj_id=extract_setting.notion_info.notion_obj_id, @@ -165,6 +176,7 @@ def extract( ) return extractor.extract() elif extract_setting.datasource_type == DatasourceType.WEBSITE.value: + assert extract_setting.website_info is not None, "website_info is required" if extract_setting.website_info.provider == "firecrawl": extractor = FirecrawlWebExtractor( url=extract_setting.website_info.url, diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 17c2087a0ab575..8ae4579c7cf93f 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -1,5 +1,6 @@ import json import time +from typing import cast import requests @@ -20,9 +21,9 @@ def scrape_url(self, url, params=None) -> dict: json_data.update(params) response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data) if response.status_code == 200: - response = response.json() - if response["success"] == True: - data = response["data"] + response_data = response.json() + if response_data["success"] == True: + data = response_data["data"] return { "title": data.get("metadata").get("title"), "description": data.get("metadata").get("description"), @@ -30,7 +31,7 @@ def scrape_url(self, url, params=None) -> dict: "markdown": data.get("markdown"), } else: - raise Exception(f'Failed to scrape URL. Error: {response["error"]}') + raise Exception(f'Failed to scrape URL. Error: {response_data["error"]}') elif response.status_code in {402, 409, 500}: error_message = response.json().get("error", "Unknown error occurred") @@ -46,9 +47,11 @@ def crawl_url(self, url, params=None) -> str: response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers) if response.status_code == 200: job_id = response.json().get("jobId") - return job_id + return cast(str, job_id) else: self._handle_error(response, "start crawl job") + # FIXME: unreachable code for mypy + return "" # unreachable def check_crawl_status(self, job_id) -> dict: headers = self._prepare_headers() @@ -64,9 +67,9 @@ def check_crawl_status(self, job_id) -> dict: for item in data: if isinstance(item, dict) and "metadata" in item and "markdown" in item: url_data = { - "title": item.get("metadata").get("title"), - "description": item.get("metadata").get("description"), - "source_url": item.get("metadata").get("sourceURL"), + "title": item.get("metadata", {}).get("title"), + "description": item.get("metadata", {}).get("description"), + "source_url": item.get("metadata", {}).get("sourceURL"), "markdown": item.get("markdown"), } url_data_list.append(url_data) @@ -92,6 +95,8 @@ def check_crawl_status(self, job_id) -> dict: else: self._handle_error(response, "check crawl status") + # FIXME: unreachable code for mypy + return {} # unreachable def _prepare_headers(self): return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} diff --git a/api/core/rag/extractor/html_extractor.py b/api/core/rag/extractor/html_extractor.py index 560c2d1d84b04e..350b522347b09d 100644 --- a/api/core/rag/extractor/html_extractor.py +++ b/api/core/rag/extractor/html_extractor.py @@ -1,6 +1,6 @@ """Abstract interface for document loader implementations.""" -from bs4 import BeautifulSoup +from bs4 import BeautifulSoup # type: ignore from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -23,6 +23,7 @@ def extract(self) -> list[Document]: return [Document(page_content=self._load_as_text())] def _load_as_text(self) -> str: + text: str = "" with open(self._file_path, "rb") as fp: soup = BeautifulSoup(fp, "html.parser") text = soup.get_text() diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 87a4ce08bf3f89..fdc2e46d141d07 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Optional +from typing import Any, Optional, cast import requests @@ -78,6 +78,7 @@ def _load_data_as_documents(self, notion_obj_id: str, notion_page_type: str) -> def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]: """Get all the pages from a Notion database.""" + assert self._notion_access_token is not None, "Notion access token is required" res = requests.post( DATABASE_URL_TMPL.format(database_id=database_id), headers={ @@ -96,6 +97,7 @@ def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] for result in data["results"]: properties = result["properties"] data = {} + value: Any for property_name, property_value in properties.items(): type = property_value["type"] if type == "multi_select": @@ -130,6 +132,7 @@ def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] return [Document(page_content="\n".join(database_content))] def _get_notion_block_data(self, page_id: str) -> list[str]: + assert self._notion_access_token is not None, "Notion access token is required" result_lines_arr = [] start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id) @@ -184,6 +187,7 @@ def _get_notion_block_data(self, page_id: str) -> list[str]: def _read_block(self, block_id: str, num_tabs: int = 0) -> str: """Read a block.""" + assert self._notion_access_token is not None, "Notion access token is required" result_lines_arr = [] start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id) @@ -242,6 +246,7 @@ def _read_block(self, block_id: str, num_tabs: int = 0) -> str: def _read_table_rows(self, block_id: str) -> str: """Read table rows.""" + assert self._notion_access_token is not None, "Notion access token is required" done = False result_lines_arr = [] start_cursor = None @@ -296,7 +301,7 @@ def _read_table_rows(self, block_id: str) -> str: result_lines = "\n".join(result_lines_arr) return result_lines - def update_last_edited_time(self, document_model: DocumentModel): + def update_last_edited_time(self, document_model: Optional[DocumentModel]): if not document_model: return @@ -309,6 +314,7 @@ def update_last_edited_time(self, document_model: DocumentModel): db.session.commit() def get_notion_last_edited_time(self) -> str: + assert self._notion_access_token is not None, "Notion access token is required" obj_id = self._notion_obj_id page_type = self._notion_page_type if page_type == "database": @@ -330,7 +336,7 @@ def get_notion_last_edited_time(self) -> str: ) data = res.json() - return data["last_edited_time"] + return cast(str, data["last_edited_time"]) @classmethod def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: @@ -349,4 +355,4 @@ def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: f"and notion workspace {notion_workspace_id}" ) - return data_source_binding.access_token + return cast(str, data_source_binding.access_token) diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 57cb9610ba267e..89a7061c26accc 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -1,7 +1,7 @@ """Abstract interface for document loader implementations.""" from collections.abc import Iterator -from typing import Optional +from typing import Optional, cast from core.rag.extractor.blob.blob import Blob from core.rag.extractor.extractor_base import BaseExtractor @@ -27,7 +27,7 @@ def extract(self) -> list[Document]: plaintext_file_exists = False if self._file_cache_key: try: - text = storage.load(self._file_cache_key).decode("utf-8") + text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8") plaintext_file_exists = True return [Document(page_content=text)] except FileNotFoundError: @@ -53,7 +53,7 @@ def load( def parse(self, blob: Blob) -> Iterator[Document]: """Lazily parse the blob.""" - import pypdfium2 + import pypdfium2 # type: ignore with blob.as_bytes_io() as file_path: pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True) diff --git a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index bd669bbad36873..9647dedfff8516 100644 --- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -1,7 +1,7 @@ import base64 import logging -from bs4 import BeautifulSoup +from bs4 import BeautifulSoup # type: ignore from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document diff --git a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py index 35220b558afab9..80c29157aaf529 100644 --- a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py @@ -30,6 +30,9 @@ def extract(self) -> list[Document]: if self._api_url: from unstructured.partition.api import partition_via_api + if self._api_key is None: + raise ValueError("api_key is required") + elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) else: from unstructured.partition.epub import partition_epub diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index 0fdcd58b2e569b..e504d4bc23014c 100644 --- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -27,9 +27,11 @@ def extract(self) -> list[Document]: elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) else: raise NotImplementedError("Unstructured API Url is not configured") - text_by_page = {} + text_by_page: dict[int, str] = {} for element in elements: page = element.metadata.page_number + if page is None: + continue text = element.text if page in text_by_page: text_by_page[page] += "\n" + text diff --git a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py index ab41290fbc4537..cefe72b29052a1 100644 --- a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py @@ -29,14 +29,15 @@ def extract(self) -> list[Document]: from unstructured.partition.pptx import partition_pptx elements = partition_pptx(filename=self._file_path) - text_by_page = {} + text_by_page: dict[int, str] = {} for element in elements: page = element.metadata.page_number text = element.text - if page in text_by_page: - text_by_page[page] += "\n" + text - else: - text_by_page[page] = text + if page is not None: + if page in text_by_page: + text_by_page[page] += "\n" + text + else: + text_by_page[page] = text combined_texts = list(text_by_page.values()) documents = [] diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 0c38a9c0762130..c3161bc812cb73 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -89,6 +89,8 @@ def _extract_images_from_docx(self, doc, image_folder): response = ssrf_proxy.get(url) if response.status_code == 200: image_ext = mimetypes.guess_extension(response.headers["Content-Type"]) + if image_ext is None: + continue file_uuid = str(uuid.uuid4()) file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext mime_type, _ = mimetypes.guess_type(file_key) @@ -97,6 +99,8 @@ def _extract_images_from_docx(self, doc, image_folder): continue else: image_ext = rel.target_ref.split(".")[-1] + if image_ext is None: + continue # user uuid as file name file_uuid = str(uuid.uuid4()) file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext @@ -226,6 +230,8 @@ def parse_docx(self, docx_path, image_folder): if x_child is None: continue if x.tag.endswith("instrText"): + if x.text is None: + continue for i in url_pattern.findall(x.text): hyperlinks_url = str(i) except Exception as e: diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index be857bd12215fd..7e5efdc66ed533 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -49,6 +49,7 @@ def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optiona """ Get the NodeParser object according to the processing rule. """ + character_splitter: TextSplitter if processing_rule["mode"] == "custom": # The user-defined segmentation rule rules = processing_rule["rules"] diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index 9b855ece2c3512..c5ba6295f32f84 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -9,7 +9,7 @@ class IndexProcessorFactory: """IndexProcessorInit.""" - def __init__(self, index_type: str): + def __init__(self, index_type: str | None): self._index_type = index_type def init_index_processor(self) -> BaseIndexProcessor: diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index a631f953ce2191..c66fa54d503e9f 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -27,12 +27,13 @@ def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: def transform(self, documents: list[Document], **kwargs) -> list[Document]: # Split the text documents into nodes. splitter = self._get_splitter( - processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + processing_rule=kwargs.get("process_rule", {}), + embedding_model_instance=kwargs.get("embedding_model_instance"), ) all_documents = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule", {})) document.page_content = document_text # parse document to nodes document_nodes = splitter.split_documents([document]) @@ -41,8 +42,9 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata["doc_id"] = doc_id - document_node.metadata["doc_hash"] = hash + if document_node.metadata is not None: + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = remove_leading_symbols(document_node.page_content).strip() if len(page_content) > 0: diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 320f0157a10049..20fd16e8f39b65 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -32,15 +32,16 @@ def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: def transform(self, documents: list[Document], **kwargs) -> list[Document]: splitter = self._get_splitter( - processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + processing_rule=kwargs.get("process_rule") or {}, + embedding_model_instance=kwargs.get("embedding_model_instance"), ) # Split the text documents into nodes. - all_documents = [] - all_qa_documents = [] + all_documents: list[Document] = [] + all_qa_documents: list[Document] = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule") or {}) document.page_content = document_text # parse document to nodes @@ -50,8 +51,9 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata["doc_id"] = doc_id - document_node.metadata["doc_hash"] = hash + if document_node.metadata is not None: + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content document_node.page_content = remove_leading_symbols(page_content) @@ -64,7 +66,7 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: document_format_thread = threading.Thread( target=self._format_qa_document, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "tenant_id": kwargs.get("tenant_id"), "document_node": doc, "all_qa_documents": all_qa_documents, @@ -148,11 +150,12 @@ def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, a qa_documents = [] for result in document_qa_list: qa_document = Document(page_content=result["question"], metadata=document_node.metadata.copy()) - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result["question"]) - qa_document.metadata["answer"] = result["answer"] - qa_document.metadata["doc_id"] = doc_id - qa_document.metadata["doc_hash"] = hash + if qa_document.metadata is not None: + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(result["question"]) + qa_document.metadata["answer"] = result["answer"] + qa_document.metadata["doc_id"] = doc_id + qa_document.metadata["doc_hash"] = hash qa_documents.append(qa_document) format_documents.extend(qa_documents) except Exception as e: diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 6ae432a526b169..ac7a3f8bb857e4 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -30,7 +30,11 @@ def run( doc_ids = set() unique_documents = [] for document in documents: - if document.provider == "dify" and document.metadata["doc_id"] not in doc_ids: + if ( + document.provider == "dify" + and document.metadata is not None + and document.metadata["doc_id"] not in doc_ids + ): doc_ids.add(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) @@ -54,7 +58,8 @@ def run( metadata=documents[result.index].metadata, provider=documents[result.index].provider, ) - rerank_document.metadata["score"] = result.score - rerank_documents.append(rerank_document) + if rerank_document.metadata is not None: + rerank_document.metadata["score"] = result.score + rerank_documents.append(rerank_document) return rerank_documents diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 4719be012f99cc..cbc96037bf2cc0 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -39,7 +39,7 @@ def run( unique_documents = [] doc_ids = set() for document in documents: - if document.metadata["doc_id"] not in doc_ids: + if document.metadata is not None and document.metadata["doc_id"] not in doc_ids: doc_ids.add(document.metadata["doc_id"]) unique_documents.append(document) @@ -56,10 +56,11 @@ def run( ) if score_threshold and score < score_threshold: continue - document.metadata["score"] = score - rerank_documents.append(document) + if document.metadata is not None: + document.metadata["score"] = score + rerank_documents.append(document) - rerank_documents.sort(key=lambda x: x.metadata["score"], reverse=True) + rerank_documents.sort(key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return rerank_documents[:top_n] if top_n else rerank_documents def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]: @@ -76,8 +77,9 @@ def _calculate_keyword_score(self, query: str, documents: list[Document]) -> lis for document in documents: # get the document keywords document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata["keywords"] = document_keywords - documents_keywords.append(document_keywords) + if document.metadata is not None: + document.metadata["keywords"] = document_keywords + documents_keywords.append(document_keywords) # Counter query keywords(TF) query_keyword_counts = Counter(query_keywords) @@ -162,7 +164,7 @@ def _calculate_cosine( query_vector = cache_embedding.embed_query(query) for document in documents: # calculate cosine similarity - if "score" in document.metadata: + if document.metadata and "score" in document.metadata: query_vector_scores.append(document.metadata["score"]) else: # transform to NumPy diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 7a5bf39fa63f48..a265f36671b04b 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1,7 +1,7 @@ import math import threading from collections import Counter -from typing import Optional, cast +from typing import Any, Optional, cast from flask import Flask, current_app @@ -34,7 +34,7 @@ from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService -default_retrieval_model = { +default_retrieval_model: dict[str, Any] = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -140,12 +140,12 @@ def retrieve( user_from, available_datasets, query, - retrieve_config.top_k, - retrieve_config.score_threshold, - retrieve_config.rerank_mode, + retrieve_config.top_k or 0, + retrieve_config.score_threshold or 0, + retrieve_config.rerank_mode or "reranking_model", retrieve_config.reranking_model, retrieve_config.weights, - retrieve_config.reranking_enabled, + retrieve_config.reranking_enabled or True, message_id, ) @@ -300,10 +300,11 @@ def single_retrieve( metadata=external_document.get("metadata"), provider="external", ) - document.metadata["score"] = external_document.get("score") - document.metadata["title"] = external_document.get("title") - document.metadata["dataset_id"] = dataset_id - document.metadata["dataset_name"] = dataset.name + if document.metadata is not None: + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset_id + document.metadata["dataset_name"] = dataset.name results.append(document) else: retrieval_model_config = dataset.retrieval_model or default_retrieval_model @@ -325,7 +326,7 @@ def single_retrieve( score_threshold = 0.0 score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") if score_threshold_enabled: - score_threshold = retrieval_model_config.get("score_threshold") + score_threshold = retrieval_model_config.get("score_threshold", 0.0) with measure_time() as timer: results = RetrievalService.retrieve( @@ -358,14 +359,14 @@ def multiple_retrieve( score_threshold: float, reranking_mode: str, reranking_model: Optional[dict] = None, - weights: Optional[dict] = None, + weights: Optional[dict[str, Any]] = None, reranking_enable: bool = True, message_id: Optional[str] = None, ): if not available_datasets: return [] threads = [] - all_documents = [] + all_documents: list[Document] = [] dataset_ids = [dataset.id for dataset in available_datasets] index_type_check = all( item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets @@ -392,15 +393,18 @@ def multiple_retrieve( "The configured knowledge base list have different embedding model, please set reranking model." ) if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE: - weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider - weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model + if weights is not None: + weights["vector_setting"]["embedding_provider_name"] = available_datasets[ + 0 + ].embedding_model_provider + weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model for dataset in available_datasets: index_type = dataset.indexing_technique retrieval_thread = threading.Thread( target=self._retriever, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset.id, "query": query, "top_k": top_k, @@ -439,21 +443,22 @@ def _on_retrieval_end( """Handle retrieval end.""" dify_documents = [document for document in documents if document.provider == "dify"] for document in dify_documents: - query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata["doc_id"] - ) + if document.metadata is not None: + query = db.session.query(DocumentSegment).filter( + DocumentSegment.index_node_id == document.metadata["doc_id"] + ) - # if 'dataset_id' in document.metadata: - if "dataset_id" in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + # if 'dataset_id' in document.metadata: + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) - # add hit count to document segment - query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + # add hit count to document segment + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) - db.session.commit() + db.session.commit() # get tracing instance - trace_manager: TraceQueueManager = ( + trace_manager: Optional[TraceQueueManager] = ( self.application_generate_entity.trace_manager if self.application_generate_entity else None ) if trace_manager: @@ -504,10 +509,11 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, metadata=external_document.get("metadata"), provider="external", ) - document.metadata["score"] = external_document.get("score") - document.metadata["title"] = external_document.get("title") - document.metadata["dataset_id"] = dataset_id - document.metadata["dataset_name"] = dataset.name + if document.metadata is not None: + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset_id + document.metadata["dataset_name"] = dataset.name all_documents.append(document) else: # get retrieval model , if the model is not setting , using default @@ -607,19 +613,20 @@ def to_dataset_retriever_tool( tools.append(tool) elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: - tool = DatasetMultiRetrieverTool.from_dataset( - dataset_ids=[dataset.id for dataset in available_datasets], - tenant_id=tenant_id, - top_k=retrieve_config.top_k or 2, - score_threshold=retrieve_config.score_threshold, - hit_callbacks=[hit_callback], - return_resource=return_resource, - retriever_from=invoke_from.to_source(), - reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"), - reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"), - ) + if retrieve_config.reranking_model is not None: + tool = DatasetMultiRetrieverTool.from_dataset( + dataset_ids=[dataset.id for dataset in available_datasets], + tenant_id=tenant_id, + top_k=retrieve_config.top_k or 2, + score_threshold=retrieve_config.score_threshold, + hit_callbacks=[hit_callback], + return_resource=return_resource, + retriever_from=invoke_from.to_source(), + reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"), + reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"), + ) - tools.append(tool) + tools.append(tool) return tools @@ -635,10 +642,11 @@ def calculate_keyword_score(self, query: str, documents: list[Document], top_k: query_keywords = keyword_table_handler.extract_keywords(query, None) documents_keywords = [] for document in documents: - # get the document keywords - document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata["keywords"] = document_keywords - documents_keywords.append(document_keywords) + if document.metadata is not None: + # get the document keywords + document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) + document.metadata["keywords"] = document_keywords + documents_keywords.append(document_keywords) # Counter query keywords(TF) query_keyword_counts = Counter(query_keywords) @@ -696,8 +704,9 @@ def cosine_similarity(vec1, vec2): for document, score in zip(documents, similarities): # format document - document.metadata["score"] = score - documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) + if document.metadata is not None: + document.metadata["score"] = score + documents = sorted(documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) return documents[:top_k] if top_k else documents def calculate_vector_score( @@ -705,10 +714,12 @@ def calculate_vector_score( ) -> list[Document]: filter_documents = [] for document in all_documents: - if score_threshold is None or document.metadata["score"] >= score_threshold: + if score_threshold is None or (document.metadata and document.metadata.get("score", 0) >= score_threshold): filter_documents.append(document) if not filter_documents: return [] - filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True) + filter_documents = sorted( + filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True + ) return filter_documents[:top_k] if top_k else filter_documents diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index 06147fe7b56544..b008d0df9c2f0e 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -1,7 +1,8 @@ -from typing import Union +from typing import Union, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance +from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage @@ -27,11 +28,14 @@ def invoke( SystemPromptMessage(content="You are a helpful AI assistant."), UserPromptMessage(content=query), ] - result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - tools=dataset_tools, - stream=False, - model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, + result = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + tools=dataset_tools, + stream=False, + model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, + ), ) if result.message.tool_calls: # get retrieval model config diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 68fab0c127a253..05e8d043dfe741 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -1,9 +1,9 @@ from collections.abc import Generator, Sequence -from typing import Union +from typing import Union, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate @@ -92,6 +92,7 @@ def _react_invoke( suffix: str = SUFFIX, format_instructions: str = FORMAT_INSTRUCTIONS, ) -> Union[str, None]: + prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate] if model_config.mode == "chat": prompt = self.create_chat_prompt( query=query, @@ -149,12 +150,15 @@ def _invoke_llm( :param stop: stop :return: """ - invoke_result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=completion_param, - stop=stop, - stream=True, - user=user_id, + invoke_result = cast( + Generator[LLMResult, None, None], + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=completion_param, + stop=stop, + stream=True, + user=user_id, + ), ) # handle invoke result @@ -172,7 +176,7 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage :return: """ model = None - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] full_text = "" usage = None for result in invoke_result: diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 53032b34d570c7..3376bd7f75dd96 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -26,8 +26,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): def from_encoder( cls: type[TS], embedding_model_instance: Optional[ModelInstance], - allowed_special: Union[Literal[all], Set[str]] = set(), - disallowed_special: Union[Literal[all], Collection[str]] = "all", + allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037 + disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037 **kwargs: Any, ): def _token_encoder(text: str) -> int: diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index 7dd62f8de18a15..4bfa541fd454ad 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -92,7 +92,7 @@ def split_documents(self, documents: Iterable[Document]) -> list[Document]: texts, metadatas = [], [] for doc in documents: texts.append(doc.page_content) - metadatas.append(doc.metadata) + metadatas.append(doc.metadata or {}) return self.create_documents(texts, metadatas=metadatas) def _join_docs(self, docs: list[str], separator: str) -> Optional[str]: @@ -143,7 +143,7 @@ def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter: """Text splitter that uses HuggingFace tokenizer to count length.""" try: - from transformers import PreTrainedTokenizerBase + from transformers import PreTrainedTokenizerBase # type: ignore if not isinstance(tokenizer, PreTrainedTokenizerBase): raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase") diff --git a/api/core/tools/README_JP.md b/api/core/tools/README_JA.md similarity index 100% rename from api/core/tools/README_JP.md rename to api/core/tools/README_JA.md diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index ddb1481276df67..975c374cae8356 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -14,7 +14,7 @@ class UserTool(BaseModel): label: I18nObject # label description: I18nObject parameters: Optional[list[ToolParameter]] = None - labels: list[str] = None + labels: list[str] | None = None UserToolProviderTypeLiteral = Optional[Literal["builtin", "api", "workflow"]] diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index 0c15b2a3711f11..7c365dc69d3b39 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -18,7 +18,7 @@ class ApiToolBundle(BaseModel): # summary summary: Optional[str] = None # operation_id - operation_id: str = None + operation_id: str | None = None # parameters parameters: Optional[list[ToolParameter]] = None # author diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 4fc383f91baeba..c87a90c03a6f7e 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -243,19 +243,22 @@ def get_simple_instance( :param options: the options of the parameter """ # convert options to ToolParameterOption + # FIXME fix the type error if options: options = [ - ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) # type: ignore + for option in options # type: ignore ] return cls( name=name, label=I18nObject(en_US="", zh_Hans=""), human_description=I18nObject(en_US="", zh_Hans=""), + placeholder=None, type=type, form=cls.ToolParameterForm.LLM, llm_description=llm_description, required=required, - options=options, + options=options, # type: ignore ) @@ -331,7 +334,7 @@ def to_dict(self) -> dict: "default": self.default, "options": self.options, "help": self.help.to_dict() if self.help else None, - "label": self.label.to_dict(), + "label": self.label.to_dict() if self.label else None, "url": self.url, "placeholder": self.placeholder.to_dict() if self.placeholder else None, } @@ -374,7 +377,10 @@ def __init__(self, **data: Any): pool[index] = ToolRuntimeImageVariable(**variable) super().__init__(**data) - def dict(self) -> dict: + def dict(self) -> dict: # type: ignore + """ + FIXME: just ignore the type check for now + """ return { "conversation_id": self.conversation_id, "user_id": self.user_id, diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py index d99314e33a3204..f451edbf2ee969 100644 --- a/api/core/tools/provider/api_tool_provider.py +++ b/api/core/tools/provider/api_tool_provider.py @@ -1,9 +1,14 @@ +from typing import Optional + from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, ToolCredentialsOption, + ToolDescription, + ToolIdentity, ToolProviderCredentials, + ToolProviderIdentity, ToolProviderType, ) from core.tools.provider.tool_provider import ToolProviderController @@ -64,21 +69,18 @@ def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "Ap pass else: raise ValueError(f"invalid auth type {auth_type}") - - user_name = db_provider.user.name if db_provider.user_id else "" - + user_name = db_provider.user.name if db_provider.user_id and db_provider.user is not None else "" return ApiToolProviderController( - **{ - "identity": { - "author": user_name, - "name": db_provider.name, - "label": {"en_US": db_provider.name, "zh_Hans": db_provider.name}, - "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description}, - "icon": db_provider.icon, - }, - "credentials_schema": credentials_schema, - "provider_id": db_provider.id or "", - } + identity=ToolProviderIdentity( + author=user_name, + name=db_provider.name, + label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name), + description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description), + icon=db_provider.icon, + ), + credentials_schema=credentials_schema, + provider_id=db_provider.id or "", + tools=None, ) @property @@ -93,24 +95,22 @@ def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool: :return: the tool """ return ApiTool( - **{ - "api_bundle": tool_bundle, - "identity": { - "author": tool_bundle.author, - "name": tool_bundle.operation_id, - "label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id}, - "icon": self.identity.icon, - "provider": self.provider_id, - }, - "description": { - "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""}, - "llm": tool_bundle.summary or "", - }, - "parameters": tool_bundle.parameters or [], - } + api_bundle=tool_bundle, + identity=ToolIdentity( + author=tool_bundle.author, + name=tool_bundle.operation_id or "", + label=I18nObject(en_US=tool_bundle.operation_id, zh_Hans=tool_bundle.operation_id), + icon=self.identity.icon if self.identity else None, + provider=self.provider_id, + ), + description=ToolDescription( + human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""), + llm=tool_bundle.summary or "", + ), + parameters=tool_bundle.parameters or [], ) - def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]: + def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[Tool]: """ load bundled tools @@ -121,7 +121,7 @@ def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]: return self.tools - def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]: + def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]: """ fetch tools from database @@ -131,6 +131,8 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]: """ if self.tools is not None: return self.tools + if self.identity is None: + return None tools: list[Tool] = [] @@ -151,7 +153,7 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]: self.tools = tools return tools - def get_tool(self, tool_name: str) -> ApiTool: + def get_tool(self, tool_name: str) -> Tool: """ get tool by name @@ -161,7 +163,9 @@ def get_tool(self, tool_name: str) -> ApiTool: if self.tools is None: self.get_tools() - for tool in self.tools: + for tool in self.tools or []: + if tool.identity is None: + continue if tool.identity.name == tool_name: return tool diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py index 09f328cd1fe65f..fc29920acd40dc 100644 --- a/api/core/tools/provider/app_tool_provider.py +++ b/api/core/tools/provider/app_tool_provider.py @@ -1,9 +1,10 @@ import logging -from typing import Any +from typing import Any, Optional from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType from core.tools.provider.tool_provider import ToolProviderController +from core.tools.tool.api_tool import ApiTool from core.tools.tool.tool import Tool from extensions.ext_database import db from models.model import App, AppModelConfig @@ -20,10 +21,10 @@ def provider_type(self) -> ToolProviderType: def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None: pass - def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None: + def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: pass - def get_tools(self, user_id: str) -> list[Tool]: + def get_tools(self, user_id: str = "", tenant_id: str = "") -> list[Tool]: db_tools: list[PublishedAppTool] = ( db.session.query(PublishedAppTool) .filter( @@ -38,7 +39,7 @@ def get_tools(self, user_id: str) -> list[Tool]: tools: list[Tool] = [] for db_tool in db_tools: - tool = { + tool: dict[str, Any] = { "identity": { "author": db_tool.author, "name": db_tool.tool_name, @@ -52,7 +53,7 @@ def get_tools(self, user_id: str) -> list[Tool]: "parameters": [], } # get app from db - app: App = db_tool.app + app: Optional[App] = db_tool.app if not app: logger.error(f"app {db_tool.app_id} not found") @@ -62,7 +63,7 @@ def get_tools(self, user_id: str) -> list[Tool]: user_input_form_list = app_model_config.user_input_form_list for input_form in user_input_form_list: # get type - form_type = input_form.keys()[0] + form_type = list(input_form.keys())[0] default = input_form[form_type]["default"] required = input_form[form_type]["required"] label = input_form[form_type]["label"] @@ -79,6 +80,7 @@ def get_tools(self, user_id: str) -> list[Tool]: type=ToolParameter.ToolParameterType.STRING, required=required, default=default, + placeholder=I18nObject(en_US="", zh_Hans=""), ) ) elif form_type == "select": @@ -92,6 +94,7 @@ def get_tools(self, user_id: str) -> list[Tool]: type=ToolParameter.ToolParameterType.SELECT, required=required, default=default, + placeholder=I18nObject(en_US="", zh_Hans=""), options=[ ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options @@ -99,5 +102,5 @@ def get_tools(self, user_id: str) -> list[Tool]: ) ) - tools.append(Tool(**tool)) + tools.append(ApiTool(**tool)) return tools diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index 5c10f72fdaed01..99a062f8c366aa 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -5,7 +5,7 @@ class BuiltinToolProviderSort: - _position = {} + _position: dict[str, int] = {} @classmethod def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py index 38123f125ae974..cf10f5d2556edd 100644 --- a/api/core/tools/provider/builtin/aippt/tools/aippt.py +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py @@ -4,7 +4,7 @@ from json import loads as json_loads from threading import Lock from time import sleep, time -from typing import Any +from typing import Any, Union from httpx import get, post from requests import get as requests_get @@ -21,23 +21,25 @@ class AIPPTGenerateToolAdapter: """ _api_base_url = URL("https://co.aippt.cn/api") - _api_token_cache = {} - _style_cache = {} + _api_token_cache: dict[str, dict[str, Union[str, float]]] = {} + _style_cache: dict[str, dict[str, Union[list[dict[str, Any]], float]]] = {} - _api_token_cache_lock = Lock() - _style_cache_lock = Lock() + _api_token_cache_lock: Lock = Lock() + _style_cache_lock: Lock = Lock() - _task = {} + _task: dict[str, Any] = {} _task_type_map = { "auto": 1, "markdown": 7, } - _tool: BuiltinTool + _tool: BuiltinTool | None - def __init__(self, tool: BuiltinTool = None): + def __init__(self, tool: BuiltinTool | None = None): self._tool = tool - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invokes the AIPPT generate tool with the given user ID and tool parameters. @@ -68,8 +70,8 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe ) # get suit - color: str = tool_parameters.get("color") - style: str = tool_parameters.get("style") + color: str = tool_parameters.get("color", "") + style: str = tool_parameters.get("style", "") if color == "__default__": color_id = "" @@ -226,7 +228,7 @@ def _generate_content(self, task_id: str, model: str, user_id: str) -> str: return "" - def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]: + def _generate_ppt(self, task_id: str, suit_id: int, user_id: str) -> tuple[str, str]: """ Generate a ppt @@ -362,7 +364,9 @@ def _calculate_sign(access_key: str, secret_key: str, timestamp: int) -> str: ).decode("utf-8") @classmethod - def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]: + def _get_styles( + cls, credentials: dict[str, str], user_id: str + ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: """ Get styles """ @@ -415,7 +419,7 @@ def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[di return colors, styles - def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]: + def get_styles(self, user_id: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: """ Get styles @@ -507,7 +511,9 @@ class AIPPTGenerateTool(BuiltinTool): def __init__(self, **kwargs: Any): super().__init__(**kwargs) - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters) def get_runtime_parameters(self) -> list[ToolParameter]: diff --git a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py index 2d65ba2d6f4389..8bd16050ecf0a6 100644 --- a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py +++ b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py @@ -1,7 +1,7 @@ import logging from typing import Any, Optional -import arxiv +import arxiv # type: ignore from pydantic import BaseModel, Field from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/provider/builtin/audio/tools/tts.py b/api/core/tools/provider/builtin/audio/tools/tts.py index f83a64d041faab..8a33ac405bd4c3 100644 --- a/api/core/tools/provider/builtin/audio/tools/tts.py +++ b/api/core/tools/provider/builtin/audio/tools/tts.py @@ -11,19 +11,21 @@ class TTSTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: - provider, model = tool_parameters.get("model").split("#") - voice = tool_parameters.get(f"voice#{provider}#{model}") + provider, model = tool_parameters.get("model", "").split("#") + voice = tool_parameters.get(f"voice#{provider}#{model}", "") model_manager = ModelManager() + if not self.runtime: + raise ValueError("Runtime is required") model_instance = model_manager.get_model_instance( - tenant_id=self.runtime.tenant_id, + tenant_id=self.runtime.tenant_id or "", provider=provider, model_type=ModelType.TTS, model=model, ) tts = model_instance.invoke_tts( - content_text=tool_parameters.get("text"), + content_text=tool_parameters.get("text", ""), user=user_id, - tenant_id=self.runtime.tenant_id, + tenant_id=self.runtime.tenant_id or "", voice=voice, ) buffer = io.BytesIO() @@ -41,8 +43,11 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInv ] def get_available_models(self) -> list[tuple[str, str, list[Any]]]: + if not self.runtime: + raise ValueError("Runtime is required") model_provider_service = ModelProviderService() - models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts") + tid: str = self.runtime.tenant_id or "" + models = model_provider_service.get_models_by_model_type(tenant_id=tid, model_type="tts") items = [] for provider_model in models: provider = provider_model.provider @@ -62,6 +67,8 @@ def get_runtime_parameters(self) -> list[ToolParameter]: ToolParameter( name=f"voice#{provider}#{model}", label=I18nObject(en_US=f"Voice of {model}({provider})"), + human_description=I18nObject(en_US=f"Select a voice for {model} model"), + placeholder=I18nObject(en_US="Select a voice"), type=ToolParameter.ToolParameterType.SELECT, form=ToolParameter.ToolParameterForm.FORM, options=[ @@ -83,6 +90,7 @@ def get_runtime_parameters(self) -> list[ToolParameter]: type=ToolParameter.ToolParameterType.SELECT, form=ToolParameter.ToolParameterForm.FORM, required=True, + placeholder=I18nObject(en_US="Select a model", zh_Hans="选择模型"), options=options, ), ) diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py index a04f5c0fe9f1af..b224ff5258c879 100644 --- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py @@ -2,8 +2,8 @@ import logging from typing import Any, Union -import boto3 -from botocore.exceptions import BotoCoreError +import boto3 # type: ignore +from botocore.exceptions import BotoCoreError # type: ignore from pydantic import BaseModel, Field from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py new file mode 100644 index 00000000000000..050b468b740c27 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py @@ -0,0 +1,115 @@ +import json +import operator +from typing import Any, Optional, Union + +import boto3 + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class BedrockRetrieveTool(BuiltinTool): + bedrock_client: Any = None + knowledge_base_id: str = None + topk: int = None + + def _bedrock_retrieve( + self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None + ): + try: + retrieval_query = {"text": query_input} + + retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}} + + # 如果有元数据过滤条件,则添加到检索配置中 + if metadata_filter: + retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter + + response = self.bedrock_client.retrieve( + knowledgeBaseId=knowledge_base_id, + retrievalQuery=retrieval_query, + retrievalConfiguration=retrieval_configuration, + ) + + results = [] + for result in response.get("retrievalResults", []): + results.append( + { + "content": result.get("content", {}).get("text", ""), + "score": result.get("score", 0.0), + "metadata": result.get("metadata", {}), + } + ) + + return results + except Exception as e: + raise Exception(f"Error retrieving from knowledge base: {str(e)}") + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + line = 0 + try: + if not self.bedrock_client: + aws_region = tool_parameters.get("aws_region") + if aws_region: + self.bedrock_client = boto3.client("bedrock-agent-runtime", region_name=aws_region) + else: + self.bedrock_client = boto3.client("bedrock-agent-runtime") + + line = 1 + if not self.knowledge_base_id: + self.knowledge_base_id = tool_parameters.get("knowledge_base_id") + if not self.knowledge_base_id: + return self.create_text_message("Please provide knowledge_base_id") + + line = 2 + if not self.topk: + self.topk = tool_parameters.get("topk", 5) + + line = 3 + query = tool_parameters.get("query", "") + if not query: + return self.create_text_message("Please input query") + + # 获取元数据过滤条件(如果存在) + metadata_filter_str = tool_parameters.get("metadata_filter") + metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None + + line = 4 + retrieved_docs = self._bedrock_retrieve( + query_input=query, + knowledge_base_id=self.knowledge_base_id, + num_results=self.topk, + metadata_filter=metadata_filter, # 将元数据过滤条件传递给检索方法 + ) + + line = 5 + # Sort results by score in descending order + sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True) + + line = 6 + return [self.create_json_message(res) for res in sorted_docs] + + except Exception as e: + return self.create_text_message(f"Exception {str(e)}, line : {line}") + + def validate_parameters(self, parameters: dict[str, Any]) -> None: + """ + Validate the parameters + """ + if not parameters.get("knowledge_base_id"): + raise ValueError("knowledge_base_id is required") + + if not parameters.get("query"): + raise ValueError("query is required") + + # 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供) + metadata_filter_str = parameters.get("metadata_filter") + if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict): + raise ValueError("metadata_filter must be a valid JSON object") diff --git a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml new file mode 100644 index 00000000000000..9e51d52def4037 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml @@ -0,0 +1,87 @@ +identity: + name: bedrock_retrieve + author: AWS + label: + en_US: Bedrock Retrieve + zh_Hans: Bedrock检索 + pt_BR: Bedrock Retrieve + icon: icon.svg + +description: + human: + en_US: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool + zh_Hans: Amazon Bedrock知识库检索工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署说明 + pt_BR: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. + llm: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool + +parameters: + - name: knowledge_base_id + type: string + required: true + label: + en_US: Bedrock Knowledge Base ID + zh_Hans: Bedrock知识库ID + pt_BR: Bedrock Knowledge Base ID + human_description: + en_US: ID of the Bedrock Knowledge Base to retrieve from + zh_Hans: 用于检索的Bedrock知识库ID + pt_BR: ID of the Bedrock Knowledge Base to retrieve from + llm_description: ID of the Bedrock Knowledge Base to retrieve from + form: form + + - name: query + type: string + required: true + label: + en_US: Query string + zh_Hans: 查询语句 + pt_BR: Query string + human_description: + en_US: The search query to retrieve relevant information + zh_Hans: 用于检索相关信息的查询语句 + pt_BR: The search query to retrieve relevant information + llm_description: The search query to retrieve relevant information + form: llm + + - name: topk + type: number + required: false + form: form + label: + en_US: Limit for results count + zh_Hans: 返回结果数量限制 + pt_BR: Limit for results count + human_description: + en_US: Maximum number of results to return + zh_Hans: 最大返回结果数量 + pt_BR: Maximum number of results to return + min: 1 + max: 10 + default: 5 + + - name: aws_region + type: string + required: false + label: + en_US: AWS Region + zh_Hans: AWS 区域 + pt_BR: AWS Region + human_description: + en_US: AWS region where the Bedrock Knowledge Base is located + zh_Hans: Bedrock知识库所在的AWS区域 + pt_BR: AWS region where the Bedrock Knowledge Base is located + llm_description: AWS region where the Bedrock Knowledge Base is located + form: form + + - name: metadata_filter + type: string + required: false + label: + en_US: Metadata Filter + zh_Hans: 元数据过滤器 + pt_BR: Metadata Filter + human_description: + en_US: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: "aaa", "value": 10}})' + zh_Hans: '元数据的JSON格式过滤条件(例如,{{"greaterThan": {"key: "aaa", "value": 10}})' + pt_BR: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: "aaa", "value": 10}})' + form: form diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py index 989608122185c8..b6d16d2759c30e 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -import boto3 +import boto3 # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py index f43f3b6fe05694..01bc596346c231 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py @@ -2,7 +2,7 @@ import logging from typing import Any, Union -import boto3 +import boto3 # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/aws/tools/nova_canvas.py b/api/core/tools/provider/builtin/aws/tools/nova_canvas.py new file mode 100644 index 00000000000000..954dbe35a4a784 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/nova_canvas.py @@ -0,0 +1,357 @@ +import base64 +import json +import logging +import re +from datetime import datetime +from typing import Any, Union +from urllib.parse import urlparse + +import boto3 + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.tool.builtin_tool import BuiltinTool + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class NovaCanvasTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + Invoke AWS Bedrock Nova Canvas model for image generation + """ + # Get common parameters + prompt = tool_parameters.get("prompt", "") + image_output_s3uri = tool_parameters.get("image_output_s3uri", "").strip() + if not prompt: + return self.create_text_message("Please provide a text prompt for image generation.") + if not image_output_s3uri or urlparse(image_output_s3uri).scheme != "s3": + return self.create_text_message("Please provide an valid S3 URI for image output.") + + task_type = tool_parameters.get("task_type", "TEXT_IMAGE") + aws_region = tool_parameters.get("aws_region", "us-east-1") + + # Get common image generation config parameters + width = tool_parameters.get("width", 1024) + height = tool_parameters.get("height", 1024) + cfg_scale = tool_parameters.get("cfg_scale", 8.0) + negative_prompt = tool_parameters.get("negative_prompt", "") + seed = tool_parameters.get("seed", 0) + quality = tool_parameters.get("quality", "standard") + + # Handle S3 image if provided + image_input_s3uri = tool_parameters.get("image_input_s3uri", "") + if task_type != "TEXT_IMAGE": + if not image_input_s3uri or urlparse(image_input_s3uri).scheme != "s3": + return self.create_text_message("Please provide a valid S3 URI for image to image generation.") + + # Parse S3 URI + parsed_uri = urlparse(image_input_s3uri) + bucket = parsed_uri.netloc + key = parsed_uri.path.lstrip("/") + + # Initialize S3 client and download image + s3_client = boto3.client("s3") + response = s3_client.get_object(Bucket=bucket, Key=key) + image_data = response["Body"].read() + + # Base64 encode the image + input_image = base64.b64encode(image_data).decode("utf-8") + + try: + # Initialize Bedrock client + bedrock = boto3.client(service_name="bedrock-runtime", region_name=aws_region) + + # Base image generation config + image_generation_config = { + "width": width, + "height": height, + "cfgScale": cfg_scale, + "seed": seed, + "numberOfImages": 1, + "quality": quality, + } + + # Prepare request body based on task type + body = {"imageGenerationConfig": image_generation_config} + + if task_type == "TEXT_IMAGE": + body["taskType"] = "TEXT_IMAGE" + body["textToImageParams"] = {"text": prompt} + if negative_prompt: + body["textToImageParams"]["negativeText"] = negative_prompt + + elif task_type == "COLOR_GUIDED_GENERATION": + colors = tool_parameters.get("colors", "#ff8080-#ffb280-#ffe680-#ffe680") + if not self._validate_color_string(colors): + return self.create_text_message("Please provide valid colors in hexadecimal format.") + + body["taskType"] = "COLOR_GUIDED_GENERATION" + body["colorGuidedGenerationParams"] = { + "colors": colors.split("-"), + "referenceImage": input_image, + "text": prompt, + } + if negative_prompt: + body["colorGuidedGenerationParams"]["negativeText"] = negative_prompt + + elif task_type == "IMAGE_VARIATION": + similarity_strength = tool_parameters.get("similarity_strength", 0.5) + + body["taskType"] = "IMAGE_VARIATION" + body["imageVariationParams"] = { + "images": [input_image], + "similarityStrength": similarity_strength, + "text": prompt, + } + if negative_prompt: + body["imageVariationParams"]["negativeText"] = negative_prompt + + elif task_type == "INPAINTING": + mask_prompt = tool_parameters.get("mask_prompt") + if not mask_prompt: + return self.create_text_message("Please provide a mask prompt for image inpainting.") + + body["taskType"] = "INPAINTING" + body["inPaintingParams"] = {"image": input_image, "maskPrompt": mask_prompt, "text": prompt} + if negative_prompt: + body["inPaintingParams"]["negativeText"] = negative_prompt + + elif task_type == "OUTPAINTING": + mask_prompt = tool_parameters.get("mask_prompt") + if not mask_prompt: + return self.create_text_message("Please provide a mask prompt for image outpainting.") + outpainting_mode = tool_parameters.get("outpainting_mode", "DEFAULT") + + body["taskType"] = "OUTPAINTING" + body["outPaintingParams"] = { + "image": input_image, + "maskPrompt": mask_prompt, + "outPaintingMode": outpainting_mode, + "text": prompt, + } + if negative_prompt: + body["outPaintingParams"]["negativeText"] = negative_prompt + + elif task_type == "BACKGROUND_REMOVAL": + body["taskType"] = "BACKGROUND_REMOVAL" + body["backgroundRemovalParams"] = {"image": input_image} + + else: + return self.create_text_message(f"Unsupported task type: {task_type}") + + # Call Nova Canvas model + response = bedrock.invoke_model( + body=json.dumps(body), + modelId="amazon.nova-canvas-v1:0", + accept="application/json", + contentType="application/json", + ) + + # Process response + response_body = json.loads(response.get("body").read()) + if response_body.get("error"): + raise Exception(f"Error in model response: {response_body.get('error')}") + base64_image = response_body.get("images")[0] + + # Upload to S3 if image_output_s3uri is provided + try: + # Parse S3 URI for output + parsed_uri = urlparse(image_output_s3uri) + output_bucket = parsed_uri.netloc + output_base_path = parsed_uri.path.lstrip("/") + # Generate filename with timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_key = f"{output_base_path}/canvas-output-{timestamp}.png" + + # Initialize S3 client if not already done + s3_client = boto3.client("s3", region_name=aws_region) + + # Decode base64 image and upload to S3 + image_data = base64.b64decode(base64_image) + s3_client.put_object(Bucket=output_bucket, Key=output_key, Body=image_data, ContentType="image/png") + logger.info(f"Image uploaded to s3://{output_bucket}/{output_key}") + except Exception as e: + logger.exception("Failed to upload image to S3") + # Return image + return [ + self.create_text_message(f"Image is available at: s3://{output_bucket}/{output_key}"), + self.create_blob_message( + blob=base64.b64decode(base64_image), + meta={"mime_type": "image/png"}, + save_as=self.VariableKey.IMAGE.value, + ), + ] + + except Exception as e: + return self.create_text_message(f"Failed to generate image: {str(e)}") + + def _validate_color_string(self, color_string) -> bool: + color_pattern = r"^#[0-9a-fA-F]{6}(?:-#[0-9a-fA-F]{6})*$" + + if re.match(color_pattern, color_string): + return True + return False + + def get_runtime_parameters(self) -> list[ToolParameter]: + parameters = [ + ToolParameter( + name="prompt", + label=I18nObject(en_US="Prompt", zh_Hans="提示词"), + type=ToolParameter.ToolParameterType.STRING, + required=True, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Text description of the image you want to generate or modify", + zh_Hans="您想要生成或修改的图像的文本描述", + ), + llm_description="Describe the image you want to generate or how you want to modify the input image", + ), + ToolParameter( + name="image_input_s3uri", + label=I18nObject(en_US="Input image s3 uri", zh_Hans="输入图片的s3 uri"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject(en_US="Image to be modified", zh_Hans="想要修改的图片"), + ), + ToolParameter( + name="image_output_s3uri", + label=I18nObject(en_US="Output Image S3 URI", zh_Hans="输出图片的S3 URI目录"), + type=ToolParameter.ToolParameterType.STRING, + required=True, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="S3 URI where the generated image should be uploaded", zh_Hans="生成的图像应该上传到的S3 URI" + ), + ), + ToolParameter( + name="width", + label=I18nObject(en_US="Width", zh_Hans="宽度"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=1024, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Width of the generated image", zh_Hans="生成图像的宽度"), + ), + ToolParameter( + name="height", + label=I18nObject(en_US="Height", zh_Hans="高度"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=1024, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Height of the generated image", zh_Hans="生成图像的高度"), + ), + ToolParameter( + name="cfg_scale", + label=I18nObject(en_US="CFG Scale", zh_Hans="CFG比例"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=8.0, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="How strongly the image should conform to the prompt", zh_Hans="图像应该多大程度上符合提示词" + ), + ), + ToolParameter( + name="negative_prompt", + label=I18nObject(en_US="Negative Prompt", zh_Hans="负面提示词"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="", + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Things you don't want in the generated image", zh_Hans="您不想在生成的图像中出现的内容" + ), + ), + ToolParameter( + name="seed", + label=I18nObject(en_US="Seed", zh_Hans="种子值"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=0, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Random seed for image generation", zh_Hans="图像生成的随机种子"), + ), + ToolParameter( + name="aws_region", + label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="us-east-1", + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"), + ), + ToolParameter( + name="task_type", + label=I18nObject(en_US="Task Type", zh_Hans="任务类型"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="TEXT_IMAGE", + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject(en_US="Type of image generation task", zh_Hans="图像生成任务的类型"), + ), + ToolParameter( + name="quality", + label=I18nObject(en_US="Quality", zh_Hans="质量"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="standard", + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="Quality of the generated image (standard or premium)", zh_Hans="生成图像的质量(标准或高级)" + ), + ), + ToolParameter( + name="colors", + label=I18nObject(en_US="Colors", zh_Hans="颜色"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="List of colors for color-guided generation, example: #ff8080-#ffb280-#ffe680-#ffe680", + zh_Hans="颜色引导生成的颜色列表, 例子: #ff8080-#ffb280-#ffe680-#ffe680", + ), + ), + ToolParameter( + name="similarity_strength", + label=I18nObject(en_US="Similarity Strength", zh_Hans="相似度强度"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=0.5, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="How similar the generated image should be to the input image (0.0 to 1.0)", + zh_Hans="生成的图像应该与输入图像的相似程度(0.0到1.0)", + ), + ), + ToolParameter( + name="mask_prompt", + label=I18nObject(en_US="Mask Prompt", zh_Hans="蒙版提示词"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Text description to generate mask for inpainting/outpainting", + zh_Hans="用于生成内补绘制/外补绘制蒙版的文本描述", + ), + ), + ToolParameter( + name="outpainting_mode", + label=I18nObject(en_US="Outpainting Mode", zh_Hans="外补绘制模式"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="DEFAULT", + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="Mode for outpainting (DEFAULT or other supported modes)", + zh_Hans="外补绘制的模式(DEFAULT或其他支持的模式)", + ), + ), + ] + + return parameters diff --git a/api/core/tools/provider/builtin/aws/tools/nova_canvas.yaml b/api/core/tools/provider/builtin/aws/tools/nova_canvas.yaml new file mode 100644 index 00000000000000..a72fd9c8efcce1 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/nova_canvas.yaml @@ -0,0 +1,175 @@ +identity: + name: nova_canvas + author: AWS + label: + en_US: AWS Bedrock Nova Canvas + zh_Hans: AWS Bedrock Nova Canvas + icon: icon.svg +description: + human: + en_US: A tool for generating and modifying images using AWS Bedrock's Nova Canvas model. Supports text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html + zh_Hans: 使用 AWS Bedrock 的 Nova Canvas 模型生成和修改图像的工具。支持文生图、颜色引导生成、图像变体、内补绘制、外补绘制和背景移除功能, 输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html。 + llm: Generate or modify images using AWS Bedrock's Nova Canvas model with multiple task types including text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal. +parameters: + - name: task_type + type: string + required: false + default: TEXT_IMAGE + label: + en_US: Task Type + zh_Hans: 任务类型 + human_description: + en_US: Type of image generation task (TEXT_IMAGE, COLOR_GUIDED_GENERATION, IMAGE_VARIATION, INPAINTING, OUTPAINTING, BACKGROUND_REMOVAL) + zh_Hans: 图像生成任务的类型(文生图、颜色引导生成、图像变体、内补绘制、外补绘制、背景移除) + form: llm + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + human_description: + en_US: Text description of the image you want to generate or modify + zh_Hans: 您想要生成或修改的图像的文本描述 + llm_description: Describe the image you want to generate or how you want to modify the input image + form: llm + - name: image_input_s3uri + type: string + required: false + label: + en_US: Input image s3 uri + zh_Hans: 输入图片的s3 uri + human_description: + en_US: The input image to modify (required for all modes except TEXT_IMAGE) + zh_Hans: 要修改的输入图像(除文生图外的所有模式都需要) + llm_description: The input image you want to modify. Required for all modes except TEXT_IMAGE. + form: llm + - name: image_output_s3uri + type: string + required: true + label: + en_US: Output S3 URI + zh_Hans: 输出S3 URI + human_description: + en_US: The S3 URI where the generated image will be saved. If provided, the image will be uploaded with name format canvas-output-{timestamp}.png + zh_Hans: 生成的图像将保存到的S3 URI。如果提供,图像将以canvas-output-{timestamp}.png的格式上传 + llm_description: Optional S3 URI where the generated image will be uploaded. The image will be saved with a timestamp-based filename. + form: form + - name: negative_prompt + type: string + required: false + label: + en_US: Negative Prompt + zh_Hans: 负面提示词 + human_description: + en_US: Things you don't want in the generated image + zh_Hans: 您不想在生成的图像中出现的内容 + form: llm + - name: width + type: number + required: false + label: + en_US: Width + zh_Hans: 宽度 + human_description: + en_US: Width of the generated image + zh_Hans: 生成图像的宽度 + form: form + default: 1024 + - name: height + type: number + required: false + label: + en_US: Height + zh_Hans: 高度 + human_description: + en_US: Height of the generated image + zh_Hans: 生成图像的高度 + form: form + default: 1024 + - name: cfg_scale + type: number + required: false + label: + en_US: CFG Scale + zh_Hans: CFG比例 + human_description: + en_US: How strongly the image should conform to the prompt + zh_Hans: 图像应该多大程度上符合提示词 + form: form + default: 8.0 + - name: seed + type: number + required: false + label: + en_US: Seed + zh_Hans: 种子值 + human_description: + en_US: Random seed for image generation + zh_Hans: 图像生成的随机种子 + form: form + default: 0 + - name: aws_region + type: string + required: false + default: us-east-1 + label: + en_US: AWS Region + zh_Hans: AWS 区域 + human_description: + en_US: AWS region for Bedrock service + zh_Hans: Bedrock 服务的 AWS 区域 + form: form + - name: quality + type: string + required: false + default: standard + label: + en_US: Quality + zh_Hans: 质量 + human_description: + en_US: Quality of the generated image (standard or premium) + zh_Hans: 生成图像的质量(标准或高级) + form: form + - name: colors + type: string + required: false + label: + en_US: Colors + zh_Hans: 颜色 + human_description: + en_US: List of colors for color-guided generation + zh_Hans: 颜色引导生成的颜色列表 + form: form + - name: similarity_strength + type: number + required: false + default: 0.5 + label: + en_US: Similarity Strength + zh_Hans: 相似度强度 + human_description: + en_US: How similar the generated image should be to the input image (0.0 to 1.0) + zh_Hans: 生成的图像应该与输入图像的相似程度(0.0到1.0) + form: form + - name: mask_prompt + type: string + required: false + label: + en_US: Mask Prompt + zh_Hans: 蒙版提示词 + human_description: + en_US: Text description to generate mask for inpainting/outpainting + zh_Hans: 用于生成内补绘制/外补绘制蒙版的文本描述 + form: llm + - name: outpainting_mode + type: string + required: false + default: DEFAULT + label: + en_US: Outpainting Mode + zh_Hans: 外补绘制模式 + human_description: + en_US: Mode for outpainting (DEFAULT or other supported modes) + zh_Hans: 外补绘制的模式(DEFAULT或其他支持的模式) + form: form diff --git a/api/core/tools/provider/builtin/aws/tools/nova_reel.py b/api/core/tools/provider/builtin/aws/tools/nova_reel.py new file mode 100644 index 00000000000000..bfd3d302b22d48 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/nova_reel.py @@ -0,0 +1,371 @@ +import base64 +import logging +import time +from io import BytesIO +from typing import Any, Optional, Union +from urllib.parse import urlparse + +import boto3 +from botocore.exceptions import ClientError +from PIL import Image + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.tool.builtin_tool import BuiltinTool + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +NOVA_REEL_DEFAULT_REGION = "us-east-1" +NOVA_REEL_DEFAULT_DIMENSION = "1280x720" +NOVA_REEL_DEFAULT_FPS = 24 +NOVA_REEL_DEFAULT_DURATION = 6 +NOVA_REEL_MODEL_ID = "amazon.nova-reel-v1:0" +NOVA_REEL_STATUS_CHECK_INTERVAL = 5 + +# Image requirements +NOVA_REEL_REQUIRED_IMAGE_WIDTH = 1280 +NOVA_REEL_REQUIRED_IMAGE_HEIGHT = 720 +NOVA_REEL_REQUIRED_IMAGE_MODE = "RGB" + + +class NovaReelTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + Invoke AWS Bedrock Nova Reel model for video generation. + + Args: + user_id: The ID of the user making the request + tool_parameters: Dictionary containing the tool parameters + + Returns: + ToolInvokeMessage containing either the video content or status information + """ + try: + # Validate and extract parameters + params = self._validate_and_extract_parameters(tool_parameters) + if isinstance(params, ToolInvokeMessage): + return params + + # Initialize AWS clients + bedrock, s3_client = self._initialize_aws_clients(params["aws_region"]) + + # Prepare model input + model_input = self._prepare_model_input(params, s3_client) + if isinstance(model_input, ToolInvokeMessage): + return model_input + + # Start video generation + invocation = self._start_video_generation(bedrock, model_input, params["video_output_s3uri"]) + invocation_arn = invocation["invocationArn"] + + # Handle async/sync mode + return self._handle_generation_mode(bedrock, s3_client, invocation_arn, params["async_mode"]) + + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "Unknown") + error_message = e.response.get("Error", {}).get("Message", str(e)) + logger.exception(f"AWS API error: {error_code} - {error_message}") + return self.create_text_message(f"AWS service error: {error_code} - {error_message}") + except Exception as e: + logger.error(f"Unexpected error in video generation: {str(e)}", exc_info=True) + return self.create_text_message(f"Failed to generate video: {str(e)}") + + def _validate_and_extract_parameters( + self, tool_parameters: dict[str, Any] + ) -> Union[dict[str, Any], ToolInvokeMessage]: + """Validate and extract parameters from the input dictionary.""" + prompt = tool_parameters.get("prompt", "") + video_output_s3uri = tool_parameters.get("video_output_s3uri", "").strip() + + # Validate required parameters + if not prompt: + return self.create_text_message("Please provide a text prompt for video generation.") + if not video_output_s3uri: + return self.create_text_message("Please provide an S3 URI for video output.") + + # Validate S3 URI format + if not video_output_s3uri.startswith("s3://"): + return self.create_text_message("Invalid S3 URI format. Must start with 's3://'") + + # Ensure S3 URI ends with '/' + video_output_s3uri = video_output_s3uri if video_output_s3uri.endswith("/") else video_output_s3uri + "/" + + return { + "prompt": prompt, + "video_output_s3uri": video_output_s3uri, + "image_input_s3uri": tool_parameters.get("image_input_s3uri", "").strip(), + "aws_region": tool_parameters.get("aws_region", NOVA_REEL_DEFAULT_REGION), + "dimension": tool_parameters.get("dimension", NOVA_REEL_DEFAULT_DIMENSION), + "seed": int(tool_parameters.get("seed", 0)), + "fps": int(tool_parameters.get("fps", NOVA_REEL_DEFAULT_FPS)), + "duration": int(tool_parameters.get("duration", NOVA_REEL_DEFAULT_DURATION)), + "async_mode": bool(tool_parameters.get("async", True)), + } + + def _initialize_aws_clients(self, region: str) -> tuple[Any, Any]: + """Initialize AWS Bedrock and S3 clients.""" + bedrock = boto3.client(service_name="bedrock-runtime", region_name=region) + s3_client = boto3.client("s3", region_name=region) + return bedrock, s3_client + + def _prepare_model_input(self, params: dict[str, Any], s3_client: Any) -> Union[dict[str, Any], ToolInvokeMessage]: + """Prepare the input for the Nova Reel model.""" + model_input = { + "taskType": "TEXT_VIDEO", + "textToVideoParams": {"text": params["prompt"]}, + "videoGenerationConfig": { + "durationSeconds": params["duration"], + "fps": params["fps"], + "dimension": params["dimension"], + "seed": params["seed"], + }, + } + + # Add image if provided + if params["image_input_s3uri"]: + try: + image_data = self._get_image_from_s3(s3_client, params["image_input_s3uri"]) + if not image_data: + return self.create_text_message("Failed to retrieve image from S3") + + # Process and validate image + processed_image = self._process_and_validate_image(image_data) + if isinstance(processed_image, ToolInvokeMessage): + return processed_image + + # Convert processed image to base64 + img_buffer = BytesIO() + processed_image.save(img_buffer, format="PNG") + img_buffer.seek(0) + input_image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8") + + model_input["textToVideoParams"]["images"] = [ + {"format": "png", "source": {"bytes": input_image_base64}} + ] + except Exception as e: + logger.error(f"Error processing input image: {str(e)}", exc_info=True) + return self.create_text_message(f"Failed to process input image: {str(e)}") + + return model_input + + def _process_and_validate_image(self, image_data: bytes) -> Union[Image.Image, ToolInvokeMessage]: + """ + Process and validate the input image according to Nova Reel requirements. + + Requirements: + - Must be 1280x720 pixels + - Must be RGB format (8 bits per channel) + - If PNG, alpha channel must not have transparent/translucent pixels + """ + try: + # Open image + img = Image.open(BytesIO(image_data)) + + # Convert RGBA to RGB if needed, ensuring no transparency + if img.mode == "RGBA": + # Check for transparency + if img.getchannel("A").getextrema()[0] < 255: + return self.create_text_message( + "PNG image contains transparent or translucent pixels, which is not supported. " + "Please provide an image without transparency." + ) + # Convert to RGB + img = img.convert("RGB") + elif img.mode != "RGB": + # Convert any other mode to RGB + img = img.convert("RGB") + + # Validate/adjust dimensions + if img.size != (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT): + logger.warning( + f"Image dimensions {img.size} do not match required dimensions " + f"({NOVA_REEL_REQUIRED_IMAGE_WIDTH}x{NOVA_REEL_REQUIRED_IMAGE_HEIGHT}). Resizing..." + ) + img = img.resize( + (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT), Image.Resampling.LANCZOS + ) + + # Validate bit depth + if img.mode != NOVA_REEL_REQUIRED_IMAGE_MODE: + return self.create_text_message( + f"Image must be in {NOVA_REEL_REQUIRED_IMAGE_MODE} mode with 8 bits per channel" + ) + + return img + + except Exception as e: + logger.error(f"Error processing image: {str(e)}", exc_info=True) + return self.create_text_message( + "Failed to process image. Please ensure the image is a valid JPEG or PNG file." + ) + + def _get_image_from_s3(self, s3_client: Any, s3_uri: str) -> Optional[bytes]: + """Download and return image data from S3.""" + parsed_uri = urlparse(s3_uri) + bucket = parsed_uri.netloc + key = parsed_uri.path.lstrip("/") + + response = s3_client.get_object(Bucket=bucket, Key=key) + return response["Body"].read() + + def _start_video_generation(self, bedrock: Any, model_input: dict[str, Any], output_s3uri: str) -> dict[str, Any]: + """Start the async video generation process.""" + return bedrock.start_async_invoke( + modelId=NOVA_REEL_MODEL_ID, + modelInput=model_input, + outputDataConfig={"s3OutputDataConfig": {"s3Uri": output_s3uri}}, + ) + + def _handle_generation_mode( + self, bedrock: Any, s3_client: Any, invocation_arn: str, async_mode: bool + ) -> ToolInvokeMessage: + """Handle async or sync video generation mode.""" + invocation_response = bedrock.get_async_invoke(invocationArn=invocation_arn) + video_path = invocation_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"] + video_uri = f"{video_path}/output.mp4" + + if async_mode: + return self.create_text_message( + f"Video generation started.\nInvocation ARN: {invocation_arn}\n" + f"Video will be available at: {video_uri}" + ) + + return self._wait_for_completion(bedrock, s3_client, invocation_arn) + + def _wait_for_completion(self, bedrock: Any, s3_client: Any, invocation_arn: str) -> ToolInvokeMessage: + """Wait for video generation completion and handle the result.""" + while True: + status_response = bedrock.get_async_invoke(invocationArn=invocation_arn) + status = status_response["status"] + video_path = status_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"] + + if status == "Completed": + return self._handle_completed_video(s3_client, video_path) + elif status == "Failed": + failure_message = status_response.get("failureMessage", "Unknown error") + return self.create_text_message(f"Video generation failed.\nError: {failure_message}") + elif status == "InProgress": + time.sleep(NOVA_REEL_STATUS_CHECK_INTERVAL) + else: + return self.create_text_message(f"Unexpected status: {status}") + + def _handle_completed_video(self, s3_client: Any, video_path: str) -> ToolInvokeMessage: + """Handle completed video generation and return the result.""" + parsed_uri = urlparse(video_path) + bucket = parsed_uri.netloc + key = parsed_uri.path.lstrip("/") + "/output.mp4" + + try: + response = s3_client.get_object(Bucket=bucket, Key=key) + video_content = response["Body"].read() + return [ + self.create_text_message(f"Video is available at: {video_path}/output.mp4"), + self.create_blob_message(blob=video_content, meta={"mime_type": "video/mp4"}, save_as="output.mp4"), + ] + except Exception as e: + logger.error(f"Error downloading video: {str(e)}", exc_info=True) + return self.create_text_message( + f"Video generation completed but failed to download video: {str(e)}\n" + f"Video is available at: s3://{bucket}/{key}" + ) + + def get_runtime_parameters(self) -> list[ToolParameter]: + """Define the tool's runtime parameters.""" + parameters = [ + ToolParameter( + name="prompt", + label=I18nObject(en_US="Prompt", zh_Hans="提示词"), + type=ToolParameter.ToolParameterType.STRING, + required=True, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Text description of the video you want to generate", zh_Hans="您想要生成的视频的文本描述" + ), + llm_description="Describe the video you want to generate", + ), + ToolParameter( + name="video_output_s3uri", + label=I18nObject(en_US="Output S3 URI", zh_Hans="输出S3 URI"), + type=ToolParameter.ToolParameterType.STRING, + required=True, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="S3 URI where the generated video will be stored", zh_Hans="生成的视频将存储的S3 URI" + ), + ), + ToolParameter( + name="dimension", + label=I18nObject(en_US="Dimension", zh_Hans="尺寸"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default=NOVA_REEL_DEFAULT_DIMENSION, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Video dimensions (width x height)", zh_Hans="视频尺寸(宽 x 高)"), + ), + ToolParameter( + name="duration", + label=I18nObject(en_US="Duration", zh_Hans="时长"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=NOVA_REEL_DEFAULT_DURATION, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Video duration in seconds", zh_Hans="视频时长(秒)"), + ), + ToolParameter( + name="seed", + label=I18nObject(en_US="Seed", zh_Hans="种子值"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=0, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Random seed for video generation", zh_Hans="视频生成的随机种子"), + ), + ToolParameter( + name="fps", + label=I18nObject(en_US="FPS", zh_Hans="帧率"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=NOVA_REEL_DEFAULT_FPS, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="Frames per second for the generated video", zh_Hans="生成视频的每秒帧数" + ), + ), + ToolParameter( + name="async", + label=I18nObject(en_US="Async Mode", zh_Hans="异步模式"), + type=ToolParameter.ToolParameterType.BOOLEAN, + required=False, + default=True, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Whether to run in async mode (return immediately) or sync mode (wait for completion)", + zh_Hans="是否以异步模式运行(立即返回)或同步模式(等待完成)", + ), + ), + ToolParameter( + name="aws_region", + label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default=NOVA_REEL_DEFAULT_REGION, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"), + ), + ToolParameter( + name="image_input_s3uri", + label=I18nObject(en_US="Input Image S3 URI", zh_Hans="输入图像S3 URI"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame", + zh_Hans="用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI", + ), + ), + ] + + return parameters diff --git a/api/core/tools/provider/builtin/aws/tools/nova_reel.yaml b/api/core/tools/provider/builtin/aws/tools/nova_reel.yaml new file mode 100644 index 00000000000000..16df5ba5c9d1e3 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/nova_reel.yaml @@ -0,0 +1,124 @@ +identity: + name: nova_reel + author: AWS + label: + en_US: AWS Bedrock Nova Reel + zh_Hans: AWS Bedrock Nova Reel + icon: icon.svg +description: + human: + en_US: A tool for generating videos using AWS Bedrock's Nova Reel model. Supports text-to-video generation and image-to-video generation with customizable parameters like duration, FPS, and dimensions. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html + zh_Hans: 使用 AWS Bedrock 的 Nova Reel 模型生成视频的工具。支持文本生成视频和图像生成视频功能,可自定义持续时间、帧率和尺寸等参数。输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html + llm: Generate videos using AWS Bedrock's Nova Reel model with support for both text-to-video and image-to-video generation, allowing customization of video properties like duration, frame rate, and resolution. + +parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + human_description: + en_US: Text description of the video you want to generate + zh_Hans: 您想要生成的视频的文本描述 + llm_description: Describe the video you want to generate + form: llm + + - name: video_output_s3uri + type: string + required: true + label: + en_US: Output S3 URI + zh_Hans: 输出S3 URI + human_description: + en_US: S3 URI where the generated video will be stored + zh_Hans: 生成的视频将存储的S3 URI + form: form + + - name: dimension + type: string + required: false + default: 1280x720 + label: + en_US: Dimension + zh_Hans: 尺寸 + human_description: + en_US: Video dimensions (width x height) + zh_Hans: 视频尺寸(宽 x 高) + form: form + + - name: duration + type: number + required: false + default: 6 + label: + en_US: Duration + zh_Hans: 时长 + human_description: + en_US: Video duration in seconds + zh_Hans: 视频时长(秒) + form: form + + - name: seed + type: number + required: false + default: 0 + label: + en_US: Seed + zh_Hans: 种子值 + human_description: + en_US: Random seed for video generation + zh_Hans: 视频生成的随机种子 + form: form + + - name: fps + type: number + required: false + default: 24 + label: + en_US: FPS + zh_Hans: 帧率 + human_description: + en_US: Frames per second for the generated video + zh_Hans: 生成视频的每秒帧数 + form: form + + - name: async + type: boolean + required: false + default: true + label: + en_US: Async Mode + zh_Hans: 异步模式 + human_description: + en_US: Whether to run in async mode (return immediately) or sync mode (wait for completion) + zh_Hans: 是否以异步模式运行(立即返回)或同步模式(等待完成) + form: llm + + - name: aws_region + type: string + required: false + default: us-east-1 + label: + en_US: AWS Region + zh_Hans: AWS 区域 + human_description: + en_US: AWS region for Bedrock service + zh_Hans: Bedrock 服务的 AWS 区域 + form: form + + - name: image_input_s3uri + type: string + required: false + label: + en_US: Input Image S3 URI + zh_Hans: 输入图像S3 URI + human_description: + en_US: S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame + zh_Hans: 用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI + form: llm + +development: + dependencies: + - boto3 + - pillow diff --git a/api/core/tools/provider/builtin/aws/tools/s3_operator.py b/api/core/tools/provider/builtin/aws/tools/s3_operator.py new file mode 100644 index 00000000000000..e4026b07a87310 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/s3_operator.py @@ -0,0 +1,80 @@ +from typing import Any, Union +from urllib.parse import urlparse + +import boto3 + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class S3Operator(BuiltinTool): + s3_client: Any = None + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + try: + # Initialize S3 client if not already done + if not self.s3_client: + aws_region = tool_parameters.get("aws_region") + if aws_region: + self.s3_client = boto3.client("s3", region_name=aws_region) + else: + self.s3_client = boto3.client("s3") + + # Parse S3 URI + s3_uri = tool_parameters.get("s3_uri") + if not s3_uri: + return self.create_text_message("s3_uri parameter is required") + + parsed_uri = urlparse(s3_uri) + if parsed_uri.scheme != "s3": + return self.create_text_message("Invalid S3 URI format. Must start with 's3://'") + + bucket = parsed_uri.netloc + # Remove leading slash from key + key = parsed_uri.path.lstrip("/") + + operation_type = tool_parameters.get("operation_type", "read") + generate_presign_url = tool_parameters.get("generate_presign_url", False) + presign_expiry = int(tool_parameters.get("presign_expiry", 3600)) # default 1 hour + + if operation_type == "write": + text_content = tool_parameters.get("text_content") + if not text_content: + return self.create_text_message("text_content parameter is required for write operation") + + # Write content to S3 + self.s3_client.put_object(Bucket=bucket, Key=key, Body=text_content.encode("utf-8")) + result = f"s3://{bucket}/{key}" + + # Generate presigned URL for the written object if requested + if generate_presign_url: + result = self.s3_client.generate_presigned_url( + "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry + ) + + else: # read operation + # Get object from S3 + response = self.s3_client.get_object(Bucket=bucket, Key=key) + result = response["Body"].read().decode("utf-8") + + # Generate presigned URL if requested + if generate_presign_url: + result = self.s3_client.generate_presigned_url( + "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry + ) + + return self.create_text_message(text=result) + + except self.s3_client.exceptions.NoSuchBucket: + return self.create_text_message(f"Bucket '{bucket}' does not exist") + except self.s3_client.exceptions.NoSuchKey: + return self.create_text_message(f"Object '{key}' does not exist in bucket '{bucket}'") + except Exception as e: + return self.create_text_message(f"Exception: {str(e)}") diff --git a/api/core/tools/provider/builtin/aws/tools/s3_operator.yaml b/api/core/tools/provider/builtin/aws/tools/s3_operator.yaml new file mode 100644 index 00000000000000..642fc2966e9b6d --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/s3_operator.yaml @@ -0,0 +1,98 @@ +identity: + name: s3_operator + author: AWS + label: + en_US: AWS S3 Operator + zh_Hans: AWS S3 读写器 + pt_BR: AWS S3 Operator + icon: icon.svg +description: + human: + en_US: AWS S3 Writer and Reader + zh_Hans: 读写S3 bucket中的文件 + pt_BR: AWS S3 Writer and Reader + llm: AWS S3 Writer and Reader +parameters: + - name: text_content + type: string + required: false + label: + en_US: The text to write + zh_Hans: 待写入的文本 + pt_BR: The text to write + human_description: + en_US: The text to write + zh_Hans: 待写入的文本 + pt_BR: The text to write + llm_description: The text to write + form: llm + - name: s3_uri + type: string + required: true + label: + en_US: s3 uri + zh_Hans: s3 uri + pt_BR: s3 uri + human_description: + en_US: s3 uri + zh_Hans: s3 uri + pt_BR: s3 uri + llm_description: s3 uri + form: llm + - name: aws_region + type: string + required: true + label: + en_US: region of bucket + zh_Hans: bucket 所在的region + pt_BR: region of bucket + human_description: + en_US: region of bucket + zh_Hans: bucket 所在的region + pt_BR: region of bucket + llm_description: region of bucket + form: form + - name: operation_type + type: select + required: true + label: + en_US: operation type + zh_Hans: 操作类型 + pt_BR: operation type + human_description: + en_US: operation type + zh_Hans: 操作类型 + pt_BR: operation type + default: read + options: + - value: read + label: + en_US: read + zh_Hans: 读 + - value: write + label: + en_US: write + zh_Hans: 写 + form: form + - name: generate_presign_url + type: boolean + required: false + label: + en_US: Generate presigned URL + zh_Hans: 生成预签名URL + human_description: + en_US: Whether to generate a presigned URL for the S3 object + zh_Hans: 是否生成S3对象的预签名URL + default: false + form: form + - name: presign_expiry + type: number + required: false + label: + en_US: Presigned URL expiration time + zh_Hans: 预签名URL有效期 + human_description: + en_US: Expiration time in seconds for the presigned URL + zh_Hans: 预签名URL的有效期(秒) + default: 3600 + form: form diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py index bffcd058b509bf..715b1ddeddcae5 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py @@ -2,7 +2,7 @@ import operator from typing import Any, Union -import boto3 +import boto3 # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool @@ -10,8 +10,8 @@ class SageMakerReRankTool(BuiltinTool): sagemaker_client: Any = None - sagemaker_endpoint: str = None - topk: int = None + sagemaker_endpoint: str | None = None + topk: int | None = None def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str): inputs = [query_input] * len(docs) diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py index 1fafe09b4d96bf..55cff89798a4eb 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Any, Optional, Union -import boto3 +import boto3 # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool @@ -17,7 +17,7 @@ class TTSModelType(Enum): class SageMakerTTSTool(BuiltinTool): sagemaker_client: Any = None - sagemaker_endpoint: str = None + sagemaker_endpoint: str | None = None s3_client: Any = None comprehend_client: Any = None diff --git a/api/core/tools/provider/builtin/cogview/tools/cogvideo.py b/api/core/tools/provider/builtin/cogview/tools/cogvideo.py index 7f69e833cb9046..a60062ca66abbf 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogvideo.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogvideo.py @@ -1,6 +1,6 @@ from typing import Any, Union -from zhipuai import ZhipuAI +from zhipuai import ZhipuAI # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py b/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py index a521f1c28a41b6..3e24b74d2598a7 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py @@ -1,7 +1,7 @@ from typing import Any, Union import httpx -from zhipuai import ZhipuAI +from zhipuai import ZhipuAI # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.py b/api/core/tools/provider/builtin/cogview/tools/cogview3.py index 12b4173fa40270..9aa781709a726c 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogview3.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.py @@ -1,7 +1,7 @@ import random from typing import Any, Union -from zhipuai import ZhipuAI +from zhipuai import ZhipuAI # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/comfyui/comfyui.py b/api/core/tools/provider/builtin/comfyui/comfyui.py index bab690af8292b7..a8127dd23f1553 100644 --- a/api/core/tools/provider/builtin/comfyui/comfyui.py +++ b/api/core/tools/provider/builtin/comfyui/comfyui.py @@ -11,7 +11,10 @@ class ComfyUIProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: ws = websocket.WebSocket() base_url = URL(credentials.get("base_url")) - ws_address = f"ws://{base_url.authority}/ws?clientId=test123" + ws_protocol = "ws" + if base_url.scheme == "https": + ws_protocol = "wss" + ws_address = f"{ws_protocol}://{base_url.authority}/ws?clientId=test123" try: ws.connect(ws_address) diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py index bed9cd1882fa29..f994cdbf66e78b 100644 --- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py @@ -40,7 +40,10 @@ def queue_prompt(self, client_id: str, prompt: dict) -> str: def open_websocket_connection(self) -> tuple[WebSocket, str]: client_id = str(uuid.uuid4()) ws = WebSocket() - ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}" + ws_protocol = "ws" + if self.base_url.scheme == "https": + ws_protocol = "wss" + ws_address = f"{ws_protocol}://{self.base_url.authority}/ws?clientId={client_id}" ws.connect(ws_address) return ws, client_id diff --git a/api/core/tools/provider/builtin/feishu_base/tools/search_records.py b/api/core/tools/provider/builtin/feishu_base/tools/search_records.py index c959496735e747..d58b42b82029ce 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/search_records.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/search_records.py @@ -7,18 +7,22 @@ class SearchRecordsTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - app_token = tool_parameters.get("app_token") - table_id = tool_parameters.get("table_id") - table_name = tool_parameters.get("table_name") - view_id = tool_parameters.get("view_id") - field_names = tool_parameters.get("field_names") - sort = tool_parameters.get("sort") - filters = tool_parameters.get("filter") - page_token = tool_parameters.get("page_token") + app_token = tool_parameters.get("app_token", "") + table_id = tool_parameters.get("table_id", "") + table_name = tool_parameters.get("table_name", "") + view_id = tool_parameters.get("view_id", "") + field_names = tool_parameters.get("field_names", "") + sort = tool_parameters.get("sort", "") + filters = tool_parameters.get("filter", "") + page_token = tool_parameters.get("page_token", "") automatic_fields = tool_parameters.get("automatic_fields", False) user_id_type = tool_parameters.get("user_id_type", "open_id") page_size = tool_parameters.get("page_size", 20) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_records.py b/api/core/tools/provider/builtin/feishu_base/tools/update_records.py index a7b036387500b0..31cf8e18d85b8d 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/update_records.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/update_records.py @@ -7,14 +7,18 @@ class UpdateRecordsTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - app_token = tool_parameters.get("app_token") - table_id = tool_parameters.get("table_id") - table_name = tool_parameters.get("table_name") - records = tool_parameters.get("records") + app_token = tool_parameters.get("app_token", "") + table_id = tool_parameters.get("table_id", "") + table_name = tool_parameters.get("table_name", "") + records = tool_parameters.get("records", "") user_id_type = tool_parameters.get("user_id_type", "open_id") res = client.update_records(app_token, table_id, table_name, records, user_id_type) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py b/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py index 8f83aea5abbe3d..80287feca176e1 100644 --- a/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py @@ -7,12 +7,16 @@ class AddEventAttendeesTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - event_id = tool_parameters.get("event_id") - attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email") + event_id = tool_parameters.get("event_id", "") + attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email", "") need_notification = tool_parameters.get("need_notification", True) res = client.add_event_attendees(event_id, attendee_phone_or_email, need_notification) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py b/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py index 144889692f9055..02e9b445219ac8 100644 --- a/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py @@ -7,11 +7,15 @@ class DeleteEventTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - event_id = tool_parameters.get("event_id") + event_id = tool_parameters.get("event_id", "") need_notification = tool_parameters.get("need_notification", True) res = client.delete_event(event_id, need_notification) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py b/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py index a2cd5a8b17d0af..4dafe4b3baf0cd 100644 --- a/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py @@ -7,8 +7,12 @@ class GetPrimaryCalendarTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) user_id_type = tool_parameters.get("user_id_type", "open_id") diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py b/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py index 8815b4c9c871cd..2e8ca968b3cc42 100644 --- a/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py @@ -7,14 +7,18 @@ class ListEventsTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - start_time = tool_parameters.get("start_time") - end_time = tool_parameters.get("end_time") - page_token = tool_parameters.get("page_token") - page_size = tool_parameters.get("page_size") + start_time = tool_parameters.get("start_time", "") + end_time = tool_parameters.get("end_time", "") + page_token = tool_parameters.get("page_token", "") + page_size = tool_parameters.get("page_size", 50) res = client.list_events(start_time, end_time, page_token, page_size) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py b/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py index 85bcb1d3f63847..b20eb6c31828e4 100644 --- a/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py @@ -7,16 +7,20 @@ class UpdateEventTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - event_id = tool_parameters.get("event_id") - summary = tool_parameters.get("summary") - description = tool_parameters.get("description") + event_id = tool_parameters.get("event_id", "") + summary = tool_parameters.get("summary", "") + description = tool_parameters.get("description", "") need_notification = tool_parameters.get("need_notification", True) - start_time = tool_parameters.get("start_time") - end_time = tool_parameters.get("end_time") + start_time = tool_parameters.get("start_time", "") + end_time = tool_parameters.get("end_time", "") auto_record = tool_parameters.get("auto_record", False) res = client.update_event(event_id, summary, description, need_notification, start_time, end_time, auto_record) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py index 090a0828e89bbf..1533f594172878 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py @@ -7,13 +7,17 @@ class CreateDocumentTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - title = tool_parameters.get("title") - content = tool_parameters.get("content") - folder_token = tool_parameters.get("folder_token") + title = tool_parameters.get("title", "") + content = tool_parameters.get("content", "") + folder_token = tool_parameters.get("folder_token", "") res = client.create_document(title, content, folder_token) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py index dd57c6870d0ba9..8ea68a2ed87855 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py @@ -7,11 +7,15 @@ class ListDocumentBlockTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - document_id = tool_parameters.get("document_id") + document_id = tool_parameters.get("document_id", "") page_token = tool_parameters.get("page_token", "") user_id_type = tool_parameters.get("user_id_type", "open_id") page_size = tool_parameters.get("page_size", 500) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.py b/api/core/tools/provider/builtin/jina/tools/jina_reader.py index 0dd55c65291783..756b7272248146 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.py @@ -43,6 +43,13 @@ def _invoke( if wait_for_selector is not None and wait_for_selector != "": headers["X-Wait-For-Selector"] = wait_for_selector + remove_selector = tool_parameters.get("remove_selector") + if remove_selector is not None and remove_selector != "": + headers["X-Remove-Selector"] = remove_selector + + if tool_parameters.get("retain_images", False): + headers["X-Retain-Images"] = "true" + if tool_parameters.get("image_caption", False): headers["X-With-Generated-Alt"] = "true" @@ -59,6 +66,12 @@ def _invoke( if tool_parameters.get("no_cache", False): headers["X-No-Cache"] = "true" + if tool_parameters.get("with_iframe", False): + headers["X-With-Iframe"] = "true" + + if tool_parameters.get("with_shadow_dom", False): + headers["X-With-Shadow-Dom"] = "true" + max_retries = tool_parameters.get("max_retries", 3) response = ssrf_proxy.get( str(URL(self._jina_reader_endpoint + url)), diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml b/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml index 589bc3433d9478..012a8c7688cb57 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml @@ -67,6 +67,33 @@ parameters: pt_BR: css selector para aguardar elementos específicos llm_description: css selector of the target element to wait for form: form + - name: remove_selector + type: string + required: false + label: + en_US: Excluded Selector + zh_Hans: 排除选择器 + pt_BR: Seletor Excluído + human_description: + en_US: css selector for remove for specific elements + zh_Hans: css 选择器用于排除特定元素 + pt_BR: seletor CSS para remover elementos específicos + llm_description: css selector of the target element to remove for + form: form + - name: retain_images + type: boolean + required: false + default: false + label: + en_US: Remove All Images + zh_Hans: 删除所有图片 + pt_BR: Remover todas as imagens + human_description: + en_US: Removes all images from the response. + zh_Hans: 从响应中删除所有图片。 + pt_BR: Remove todas as imagens da resposta. + llm_description: Remove all images + form: form - name: image_caption type: boolean required: false @@ -136,6 +163,34 @@ parameters: pt_BR: Ignorar o cache llm_description: bypass the cache form: form + - name: with_iframe + type: boolean + required: false + default: false + label: + en_US: Enable iframe extraction + zh_Hans: 启用 iframe 提取 + pt_BR: Habilitar extração de iframe + human_description: + en_US: Extract and process content of all embedded iframes in the DOM tree. + zh_Hans: 提取并处理 DOM 树中所有嵌入 iframe 的内容。 + pt_BR: Extrair e processar o conteúdo de todos os iframes incorporados na árvore DOM. + llm_description: Extract content from embedded iframes + form: form + - name: with_shadow_dom + type: boolean + required: false + default: false + label: + en_US: Enable Shadow DOM extraction + zh_Hans: 启用 Shadow DOM 提取 + pt_BR: Habilitar extração de Shadow DOM + human_description: + en_US: Traverse all Shadow DOM roots in the document and extract content. + zh_Hans: 遍历文档中所有 Shadow DOM 根并提取内容。 + pt_BR: Percorra todas as raízes do Shadow DOM no documento e extraia o conteúdo. + llm_description: Extract content from Shadow DOM roots + form: form - name: summary type: boolean required: false diff --git a/api/core/tools/provider/builtin/json_process/tools/delete.py b/api/core/tools/provider/builtin/json_process/tools/delete.py index fcab3d71a93cf9..06f6cacd5d6126 100644 --- a/api/core/tools/provider/builtin/json_process/tools/delete.py +++ b/api/core/tools/provider/builtin/json_process/tools/delete.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -from jsonpath_ng import parse +from jsonpath_ng import parse # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/json_process/tools/insert.py b/api/core/tools/provider/builtin/json_process/tools/insert.py index 793c74e5f9df51..e825329a6d8f61 100644 --- a/api/core/tools/provider/builtin/json_process/tools/insert.py +++ b/api/core/tools/provider/builtin/json_process/tools/insert.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -from jsonpath_ng import parse +from jsonpath_ng import parse # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/json_process/tools/parse.py b/api/core/tools/provider/builtin/json_process/tools/parse.py index f91432ee77f488..193017ba9a7c53 100644 --- a/api/core/tools/provider/builtin/json_process/tools/parse.py +++ b/api/core/tools/provider/builtin/json_process/tools/parse.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -from jsonpath_ng import parse +from jsonpath_ng import parse # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/json_process/tools/replace.py b/api/core/tools/provider/builtin/json_process/tools/replace.py index 383825c2d0b259..feca0d8a7c2783 100644 --- a/api/core/tools/provider/builtin/json_process/tools/replace.py +++ b/api/core/tools/provider/builtin/json_process/tools/replace.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -from jsonpath_ng import parse +from jsonpath_ng import parse # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/maths/tools/eval_expression.py b/api/core/tools/provider/builtin/maths/tools/eval_expression.py index 0c5b5e41cbe1e1..d3a497d1cd5c54 100644 --- a/api/core/tools/provider/builtin/maths/tools/eval_expression.py +++ b/api/core/tools/provider/builtin/maths/tools/eval_expression.py @@ -1,7 +1,7 @@ import logging from typing import Any, Union -import numexpr as ne +import numexpr as ne # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py index db4adfd4ad4629..6473c509e1f4c2 100644 --- a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py +++ b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py @@ -1,4 +1,4 @@ -from novita_client import ( +from novita_client import ( # type: ignore Txt2ImgV3Embedding, Txt2ImgV3HiresFix, Txt2ImgV3LoRA, diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py index 0b4f2edff3607f..097b234bd50640 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py @@ -2,7 +2,7 @@ from copy import deepcopy from typing import Any, Union -from novita_client import ( +from novita_client import ( # type: ignore NovitaClient, ) diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py index 9c61eab9f95784..297a27abba667a 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py @@ -2,7 +2,7 @@ from copy import deepcopy from typing import Any, Union -from novita_client import ( +from novita_client import ( # type: ignore NovitaClient, ) diff --git a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py index 165e93956eff38..704e0015d961a3 100644 --- a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py +++ b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py @@ -13,7 +13,7 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") - from pydub import AudioSegment + from pydub import AudioSegment # type: ignore class PodcastAudioGeneratorTool(BuiltinTool): diff --git a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py index d8ca20bde6ffc9..4a47c4211f4fd4 100644 --- a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py +++ b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py @@ -2,10 +2,10 @@ import logging from typing import Any, Union -from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q -from qrcode.image.base import BaseImage -from qrcode.image.pure import PyPNGImage -from qrcode.main import QRCode +from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q # type: ignore +from qrcode.image.base import BaseImage # type: ignore +from qrcode.image.pure import PyPNGImage # type: ignore +from qrcode.main import QRCode # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/transcript/tools/transcript.py b/api/core/tools/provider/builtin/transcript/tools/transcript.py index 27f700efbd6936..ac7565d9eef5b8 100644 --- a/api/core/tools/provider/builtin/transcript/tools/transcript.py +++ b/api/core/tools/provider/builtin/transcript/tools/transcript.py @@ -1,7 +1,7 @@ from typing import Any, Union from urllib.parse import parse_qs, urlparse -from youtube_transcript_api import YouTubeTranscriptApi +from youtube_transcript_api import YouTubeTranscriptApi # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/twilio/tools/send_message.py b/api/core/tools/provider/builtin/twilio/tools/send_message.py index 5ee839baa56f02..98a108f4ec7e93 100644 --- a/api/core/tools/provider/builtin/twilio/tools/send_message.py +++ b/api/core/tools/provider/builtin/twilio/tools/send_message.py @@ -37,7 +37,7 @@ class TwilioAPIWrapper(BaseModel): def set_validator(cls, values: dict) -> dict: """Validate that api key and python package exists in environment.""" try: - from twilio.rest import Client + from twilio.rest import Client # type: ignore except ImportError: raise ImportError("Could not import twilio python package. Please install it with `pip install twilio`.") account_sid = values.get("account_sid") diff --git a/api/core/tools/provider/builtin/twilio/twilio.py b/api/core/tools/provider/builtin/twilio/twilio.py index b1d100aad93dba..649e03d185121c 100644 --- a/api/core/tools/provider/builtin/twilio/twilio.py +++ b/api/core/tools/provider/builtin/twilio/twilio.py @@ -1,7 +1,7 @@ from typing import Any -from twilio.base.exceptions import TwilioRestException -from twilio.rest import Client +from twilio.base.exceptions import TwilioRestException # type: ignore +from twilio.rest import Client # type: ignore from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.py b/api/core/tools/provider/builtin/vanna/tools/vanna.py index 1c7cb39c92b40b..a6afd2dddfc63a 100644 --- a/api/core/tools/provider/builtin/vanna/tools/vanna.py +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.py @@ -1,6 +1,6 @@ from typing import Any, Union -from vanna.remote import VannaDefault +from vanna.remote import VannaDefault # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.errors import ToolProviderCredentialValidationError @@ -14,6 +14,9 @@ def _invoke( """ invoke tools """ + # Ensure runtime and credentials + if not self.runtime or not self.runtime.credentials: + raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing") api_key = self.runtime.credentials.get("api_key", None) if not api_key: raise ToolProviderCredentialValidationError("Please input api key") diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/vectorizer.py index 211ec78f4d6a58..9d7613f8eaf170 100644 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.py @@ -1,32 +1,8 @@ from typing import Any -from core.file import FileTransferMethod, FileType -from core.tools.errors import ToolProviderCredentialValidationError -from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController -from factories import file_factory class VectorizerProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: - mapping = { - "transfer_method": FileTransferMethod.TOOL_FILE, - "type": FileType.IMAGE, - "id": "test_id", - "url": "https://cloud.dify.ai/logo/logo-site.png", - } - test_img = file_factory.build_from_mapping( - mapping=mapping, - tenant_id="__test_123", - ) - try: - VectorizerTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id="", - tool_parameters={"mode": "test", "image": test_img}, - ) - except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) + return diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py index cb88e9519a4346..edb96e722f7f33 100644 --- a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py +++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py @@ -1,6 +1,6 @@ from typing import Any, Optional, Union -import wikipedia +import wikipedia # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/yahoo/tools/analytics.py b/api/core/tools/provider/builtin/yahoo/tools/analytics.py index f044fbe5404b0a..95a65ba22fc8af 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/analytics.py +++ b/api/core/tools/provider/builtin/yahoo/tools/analytics.py @@ -3,7 +3,7 @@ import pandas as pd from requests.exceptions import HTTPError, ReadTimeout -from yfinance import download +from yfinance import download # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/yahoo/tools/news.py b/api/core/tools/provider/builtin/yahoo/tools/news.py index ff820430f9f366..c9ae0c4ca7fcc6 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/news.py +++ b/api/core/tools/provider/builtin/yahoo/tools/news.py @@ -1,6 +1,6 @@ from typing import Any, Union -import yfinance +import yfinance # type: ignore from requests.exceptions import HTTPError, ReadTimeout from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/provider/builtin/yahoo/tools/ticker.py b/api/core/tools/provider/builtin/yahoo/tools/ticker.py index dfc7e460473c33..74d0d25addf04b 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/ticker.py +++ b/api/core/tools/provider/builtin/yahoo/tools/ticker.py @@ -1,7 +1,7 @@ from typing import Any, Union from requests.exceptions import HTTPError, ReadTimeout -from yfinance import Ticker +from yfinance import Ticker # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/youtube/tools/videos.py b/api/core/tools/provider/builtin/youtube/tools/videos.py index 95dec2eac9a752..a24fe89679b29b 100644 --- a/api/core/tools/provider/builtin/youtube/tools/videos.py +++ b/api/core/tools/provider/builtin/youtube/tools/videos.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Any, Union -from googleapiclient.discovery import build +from googleapiclient.discovery import build # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 955a0add3b4513..61de75ac5e2ccd 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -1,6 +1,6 @@ from abc import abstractmethod from os import listdir, path -from typing import Any +from typing import Any, Optional from core.helper.module_import_helper import load_single_subclass_from_source from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType @@ -50,6 +50,8 @@ def _get_builtin_tools(self) -> list[Tool]: """ if self.tools: return self.tools + if not self.identity: + return [] provider = self.identity.name tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools") @@ -86,7 +88,7 @@ def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: return self.credentials_schema.copy() - def get_tools(self) -> list[Tool]: + def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]: """ returns a list of tools that the provider can provide @@ -94,11 +96,14 @@ def get_tools(self) -> list[Tool]: """ return self._get_builtin_tools() - def get_tool(self, tool_name: str) -> Tool: + def get_tool(self, tool_name: str) -> Optional[Tool]: """ returns the tool that the provider can provide """ - return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) + tools = self.get_tools() + if tools is None: + raise ValueError("tools not found") + return next((t for t in tools if t.identity and t.identity.name == tool_name), None) def get_parameters(self, tool_name: str) -> list[ToolParameter]: """ @@ -107,10 +112,13 @@ def get_parameters(self, tool_name: str) -> list[ToolParameter]: :param tool_name: the name of the tool, defined in `get_tools` :return: list of parameters """ - tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) + tools = self.get_tools() + if tools is None: + raise ToolNotFoundError(f"tool {tool_name} not found") + tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None) if tool is None: raise ToolNotFoundError(f"tool {tool_name} not found") - return tool.parameters + return tool.parameters or [] @property def need_credentials(self) -> bool: @@ -144,6 +152,8 @@ def _get_tool_labels(self) -> list[ToolLabelEnum]: """ returns the labels of the provider """ + if self.identity is None: + return [] return self.identity.tags or [] def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: @@ -159,56 +169,56 @@ def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dic for parameter in tool_parameters_schema: tool_parameters_need_to_validate[parameter.name] = parameter - for parameter in tool_parameters: - if parameter not in tool_parameters_need_to_validate: - raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}") + for parameter_name in tool_parameters: + if parameter_name not in tool_parameters_need_to_validate: + raise ToolParameterValidationError(f"parameter {parameter_name} not found in tool {tool_name}") # check type - parameter_schema = tool_parameters_need_to_validate[parameter] + parameter_schema = tool_parameters_need_to_validate[parameter_name] if parameter_schema.type == ToolParameter.ToolParameterType.STRING: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f"parameter {parameter} should be string") + if not isinstance(tool_parameters[parameter_name], str): + raise ToolParameterValidationError(f"parameter {parameter_name} should be string") elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: - if not isinstance(tool_parameters[parameter], int | float): - raise ToolParameterValidationError(f"parameter {parameter} should be number") + if not isinstance(tool_parameters[parameter_name], int | float): + raise ToolParameterValidationError(f"parameter {parameter_name} should be number") - if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: + if parameter_schema.min is not None and tool_parameters[parameter_name] < parameter_schema.min: raise ToolParameterValidationError( - f"parameter {parameter} should be greater than {parameter_schema.min}" + f"parameter {parameter_name} should be greater than {parameter_schema.min}" ) - if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: + if parameter_schema.max is not None and tool_parameters[parameter_name] > parameter_schema.max: raise ToolParameterValidationError( - f"parameter {parameter} should be less than {parameter_schema.max}" + f"parameter {parameter_name} should be less than {parameter_schema.max}" ) elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: - if not isinstance(tool_parameters[parameter], bool): - raise ToolParameterValidationError(f"parameter {parameter} should be boolean") + if not isinstance(tool_parameters[parameter_name], bool): + raise ToolParameterValidationError(f"parameter {parameter_name} should be boolean") elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f"parameter {parameter} should be string") + if not isinstance(tool_parameters[parameter_name], str): + raise ToolParameterValidationError(f"parameter {parameter_name} should be string") options = parameter_schema.options if not isinstance(options, list): - raise ToolParameterValidationError(f"parameter {parameter} options should be list") + raise ToolParameterValidationError(f"parameter {parameter_name} options should be list") - if tool_parameters[parameter] not in [x.value for x in options]: - raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") + if tool_parameters[parameter_name] not in [x.value for x in options]: + raise ToolParameterValidationError(f"parameter {parameter_name} should be one of {options}") - tool_parameters_need_to_validate.pop(parameter) + tool_parameters_need_to_validate.pop(parameter_name) - for parameter in tool_parameters_need_to_validate: - parameter_schema = tool_parameters_need_to_validate[parameter] + for parameter_name in tool_parameters_need_to_validate: + parameter_schema = tool_parameters_need_to_validate[parameter_name] if parameter_schema.required: - raise ToolParameterValidationError(f"parameter {parameter} is required") + raise ToolParameterValidationError(f"parameter {parameter_name} is required") # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: default_value = parameter_schema.type.cast_value(parameter_schema.default) - tool_parameters[parameter] = default_value + tool_parameters[parameter_name] = default_value def validate_credentials(self, credentials: dict[str, Any]) -> None: """ diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index bc05a11562b717..e35207e4f06404 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -24,10 +24,12 @@ def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: :return: the credentials schema """ + if self.credentials_schema is None: + return {} return self.credentials_schema.copy() @abstractmethod - def get_tools(self) -> list[Tool]: + def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]: """ returns a list of tools that the provider can provide @@ -36,7 +38,7 @@ def get_tools(self) -> list[Tool]: pass @abstractmethod - def get_tool(self, tool_name: str) -> Tool: + def get_tool(self, tool_name: str) -> Optional[Tool]: """ returns a tool that the provider can provide @@ -51,10 +53,13 @@ def get_parameters(self, tool_name: str) -> list[ToolParameter]: :param tool_name: the name of the tool, defined in `get_tools` :return: list of parameters """ - tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) + tools = self.get_tools() + if tools is None: + raise ToolNotFoundError(f"tool {tool_name} not found") + tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None) if tool is None: raise ToolNotFoundError(f"tool {tool_name} not found") - return tool.parameters + return tool.parameters or [] @property def provider_type(self) -> ToolProviderType: @@ -78,55 +83,55 @@ def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dic for parameter in tool_parameters_schema: tool_parameters_need_to_validate[parameter.name] = parameter - for parameter in tool_parameters: - if parameter not in tool_parameters_need_to_validate: - raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}") + for tool_parameter in tool_parameters: + if tool_parameter not in tool_parameters_need_to_validate: + raise ToolParameterValidationError(f"parameter {tool_parameter} not found in tool {tool_name}") # check type - parameter_schema = tool_parameters_need_to_validate[parameter] + parameter_schema = tool_parameters_need_to_validate[tool_parameter] if parameter_schema.type == ToolParameter.ToolParameterType.STRING: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f"parameter {parameter} should be string") + if not isinstance(tool_parameters[tool_parameter], str): + raise ToolParameterValidationError(f"parameter {tool_parameter} should be string") elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: - if not isinstance(tool_parameters[parameter], int | float): - raise ToolParameterValidationError(f"parameter {parameter} should be number") + if not isinstance(tool_parameters[tool_parameter], int | float): + raise ToolParameterValidationError(f"parameter {tool_parameter} should be number") - if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: + if parameter_schema.min is not None and tool_parameters[tool_parameter] < parameter_schema.min: raise ToolParameterValidationError( - f"parameter {parameter} should be greater than {parameter_schema.min}" + f"parameter {tool_parameter} should be greater than {parameter_schema.min}" ) - if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: + if parameter_schema.max is not None and tool_parameters[tool_parameter] > parameter_schema.max: raise ToolParameterValidationError( - f"parameter {parameter} should be less than {parameter_schema.max}" + f"parameter {tool_parameter} should be less than {parameter_schema.max}" ) elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: - if not isinstance(tool_parameters[parameter], bool): - raise ToolParameterValidationError(f"parameter {parameter} should be boolean") + if not isinstance(tool_parameters[tool_parameter], bool): + raise ToolParameterValidationError(f"parameter {tool_parameter} should be boolean") elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f"parameter {parameter} should be string") + if not isinstance(tool_parameters[tool_parameter], str): + raise ToolParameterValidationError(f"parameter {tool_parameter} should be string") options = parameter_schema.options if not isinstance(options, list): - raise ToolParameterValidationError(f"parameter {parameter} options should be list") + raise ToolParameterValidationError(f"parameter {tool_parameter} options should be list") - if tool_parameters[parameter] not in [x.value for x in options]: - raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") + if tool_parameters[tool_parameter] not in [x.value for x in options]: + raise ToolParameterValidationError(f"parameter {tool_parameter} should be one of {options}") - tool_parameters_need_to_validate.pop(parameter) + tool_parameters_need_to_validate.pop(tool_parameter) - for parameter in tool_parameters_need_to_validate: - parameter_schema = tool_parameters_need_to_validate[parameter] + for tool_parameter_validate in tool_parameters_need_to_validate: + parameter_schema = tool_parameters_need_to_validate[tool_parameter_validate] if parameter_schema.required: - raise ToolParameterValidationError(f"parameter {parameter} is required") + raise ToolParameterValidationError(f"parameter {tool_parameter_validate} is required") # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: - tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default) + tool_parameters[tool_parameter_validate] = parameter_schema.type.cast_value(parameter_schema.default) def validate_credentials_format(self, credentials: dict[str, Any]) -> None: """ @@ -144,6 +149,8 @@ def validate_credentials_format(self, credentials: dict[str, Any]) -> None: for credential_name in credentials: if credential_name not in credentials_need_to_validate: + if self.identity is None: + raise ValueError("identity is not set") raise ToolProviderCredentialValidationError( f"credential {credential_name} not found in provider {self.identity.name}" ) diff --git a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py index 5656dd09ab8c94..17fe2e20cf282e 100644 --- a/api/core/tools/provider/workflow_tool_provider.py +++ b/api/core/tools/provider/workflow_tool_provider.py @@ -11,6 +11,7 @@ ToolProviderType, ) from core.tools.provider.tool_provider import ToolProviderController +from core.tools.tool.tool import Tool from core.tools.tool.workflow_tool import WorkflowTool from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from extensions.ext_database import db @@ -116,6 +117,7 @@ def fetch_workflow_variable(variable_name: str): llm_description=parameter.description, required=variable.required, options=options, + placeholder=I18nObject(en_US="", zh_Hans=""), ) ) elif features.file_upload: @@ -128,6 +130,7 @@ def fetch_workflow_variable(variable_name: str): llm_description=parameter.description, required=False, form=parameter.form, + placeholder=I18nObject(en_US="", zh_Hans=""), ) ) else: @@ -157,7 +160,7 @@ def fetch_workflow_variable(variable_name: str): label=db_provider.label, ) - def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: + def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]: """ fetch tools from database @@ -168,7 +171,7 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: if self.tools is not None: return self.tools - db_providers: WorkflowToolProvider = ( + db_providers: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter( WorkflowToolProvider.tenant_id == tenant_id, @@ -179,12 +182,14 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: if not db_providers: return [] + if not db_providers.app: + raise ValueError("app not found") self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)] return self.tools - def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: + def get_tool(self, tool_name: str) -> Optional[Tool]: """ get tool by name @@ -195,6 +200,8 @@ def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: return None for tool in self.tools: + if tool.identity is None: + continue if tool.identity.name == tool_name: return tool diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index 0b4c5bd2c6fd73..9a00450290a660 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -32,11 +32,13 @@ def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": :param meta: the meta data of a tool call processing, tenant_id is required :return: the new tool """ + if self.api_bundle is None: + raise ValueError("api_bundle is required") return self.__class__( identity=self.identity.model_copy() if self.identity else None, parameters=self.parameters.copy() if self.parameters else None, description=self.description.model_copy() if self.description else None, - api_bundle=self.api_bundle.model_copy() if self.api_bundle else None, + api_bundle=self.api_bundle.model_copy(), runtime=Tool.Runtime(**runtime), ) @@ -61,6 +63,8 @@ def tool_provider_type(self) -> ToolProviderType: def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]: headers = {} + if self.runtime is None: + raise ValueError("runtime is required") credentials = self.runtime.credentials or {} if "auth_type" not in credentials: @@ -88,7 +92,7 @@ def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]: headers[api_key_header] = credentials["api_key_value"] - needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required] + needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required] for parameter in needed_parameters: if parameter.required and parameter.name not in parameters: raise ToolParameterValidationError(f"Missing required parameter {parameter.name}") @@ -137,7 +141,8 @@ def do_http_request( params = {} path_params = {} - body = {} + # FIXME: body should be a dict[str, Any] but it changed a lot in this function + body: Any = {} cookies = {} files = [] @@ -198,7 +203,7 @@ def do_http_request( body = body if method in {"get", "head", "post", "put", "delete", "patch"}: - response = getattr(ssrf_proxy, method)( + response: httpx.Response = getattr(ssrf_proxy, method)( url, params=params, headers=headers, @@ -210,7 +215,7 @@ def do_http_request( ) return response else: - raise ValueError(f"Invalid http method {self.method}") + raise ValueError(f"Invalid http method {method}") def _convert_body_property_any_of( self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10 @@ -270,9 +275,6 @@ def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> A elif property["type"] == "object" or property["type"] == "array": if isinstance(value, str): try: - # an array str like '[1,2]' also can convert to list [1,2] through json.loads - # json not support single quote, but we can support it - value = value.replace("'", '"') return json.loads(value) except ValueError: return value @@ -291,6 +293,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe """ invoke http request """ + response: httpx.Response | str = "" # assemble request headers = self.assembling_request(tool_parameters) diff --git a/api/core/tools/tool/builtin_tool.py b/api/core/tools/tool/builtin_tool.py index e2a81ed0a36edd..adda4297f38e8a 100644 --- a/api/core/tools/tool/builtin_tool.py +++ b/api/core/tools/tool/builtin_tool.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, cast from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage @@ -32,9 +32,12 @@ def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: :return: the model result """ # invoke model + if self.runtime is None or self.identity is None: + raise ValueError("runtime and identity are required") + return ModelInvocationUtils.invoke( user_id=user_id, - tenant_id=self.runtime.tenant_id, + tenant_id=self.runtime.tenant_id or "", tool_type="builtin", tool_name=self.identity.name, prompt_messages=prompt_messages, @@ -50,8 +53,11 @@ def get_max_tokens(self) -> int: :param model_config: the model config :return: the max tokens """ + if self.runtime is None: + raise ValueError("runtime is required") + return ModelInvocationUtils.get_max_llm_context_tokens( - tenant_id=self.runtime.tenant_id, + tenant_id=self.runtime.tenant_id or "", ) def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int: @@ -61,7 +67,12 @@ def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int: :param prompt_messages: the prompt messages :return: the tokens """ - return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages) + if self.runtime is None: + raise ValueError("runtime is required") + + return ModelInvocationUtils.calculate_tokens( + tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages + ) def summary(self, user_id: str, content: str) -> str: max_tokens = self.get_max_tokens() @@ -81,7 +92,7 @@ def summarize(content: str) -> str: stop=[], ) - return summary.message.content + return cast(str, summary.message.content) lines = content.split("\n") new_lines = [] @@ -102,16 +113,16 @@ def summarize(content: str) -> str: # merge lines into messages with max tokens messages: list[str] = [] - for i in new_lines: + for j in new_lines: if len(messages) == 0: - messages.append(i) + messages.append(j) else: - if len(messages[-1]) + len(i) < max_tokens * 0.5: - messages[-1] += i - if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: - messages.append(i) + if len(messages[-1]) + len(j) < max_tokens * 0.5: + messages[-1] += j + if get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7: + messages.append(j) else: - messages[-1] += i + messages[-1] += j summaries = [] for i in range(len(messages)): diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index ab7b40a2536db8..a4afea4b9df429 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,4 +1,5 @@ import threading +from typing import Any from flask import Flask, current_app from pydantic import BaseModel, Field @@ -7,13 +8,14 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -default_retrieval_model = { +default_retrieval_model: dict[str, Any] = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -44,12 +46,12 @@ def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): def _run(self, query: str) -> str: threads = [] - all_documents = [] + all_documents: list[RagDocument] = [] for dataset_id in self.dataset_ids: retrieval_thread = threading.Thread( target=self._retriever, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset_id, "query": query, "all_documents": all_documents, @@ -77,11 +79,11 @@ def _run(self, query: str) -> str: document_score_list = {} for item in all_documents: - if item.metadata.get("score"): + if item.metadata and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] - index_node_ids = [document.metadata["doc_id"] for document in all_documents] + index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(self.dataset_ids), DocumentSegment.completed_at.isnot(None), @@ -139,6 +141,7 @@ def _run(self, query: str) -> str: hit_callback.return_retriever_resource_info(context_list) return str("\n".join(document_context_list)) + return "" def _retriever( self, diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py index dad8c773579099..a4d2de3b1c8ef3 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py @@ -1,7 +1,7 @@ from abc import abstractmethod from typing import Any, Optional -from msal_extensions.persistence import ABC +from msal_extensions.persistence import ABC # type: ignore from pydantic import BaseModel, ConfigDict from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index 987f94a35046e9..b382016473055d 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel, Field from core.rag.datasource.retrieval_service import RetrievalService @@ -69,25 +71,27 @@ def _run(self, query: str) -> str: metadata=external_document.get("metadata"), provider="external", ) - document.metadata["score"] = external_document.get("score") - document.metadata["title"] = external_document.get("title") - document.metadata["dataset_id"] = dataset.id - document.metadata["dataset_name"] = dataset.name - results.append(document) + if document.metadata is not None: + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset.id + document.metadata["dataset_name"] = dataset.name + results.append(document) # deal with external documents context_list = [] for position, item in enumerate(results, start=1): - source = { - "position": position, - "dataset_id": item.metadata.get("dataset_id"), - "dataset_name": item.metadata.get("dataset_name"), - "document_name": item.metadata.get("title"), - "data_source_type": "external", - "retriever_from": self.retriever_from, - "score": item.metadata.get("score"), - "title": item.metadata.get("title"), - "content": item.page_content, - } + if item.metadata is not None: + source = { + "position": position, + "dataset_id": item.metadata.get("dataset_id"), + "dataset_name": item.metadata.get("dataset_name"), + "document_name": item.metadata.get("title"), + "data_source_type": "external", + "retriever_from": self.retriever_from, + "score": item.metadata.get("score"), + "title": item.metadata.get("title"), + "content": item.page_content, + } context_list.append(source) for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(context_list) @@ -95,7 +99,7 @@ def _run(self, query: str) -> str: return str("\n".join([item.page_content for item in results])) else: # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model or default_retrieval_model + retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query documents = RetrievalService.retrieve( @@ -113,11 +117,11 @@ def _run(self, query: str) -> str: score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] else 0.0, - reranking_model=retrieval_model.get("reranking_model", None) + reranking_model=retrieval_model.get("reranking_model") if retrieval_model["reranking_enable"] else None, reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", - weights=retrieval_model.get("weights", None), + weights=retrieval_model.get("weights"), ) else: documents = [] @@ -127,7 +131,7 @@ def _run(self, query: str) -> str: document_score_list = {} if dataset.indexing_technique != "economy": for item in documents: - if item.metadata.get("score"): + if item.metadata is not None and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] index_node_ids = [document.metadata["doc_id"] for document in documents] @@ -155,20 +159,21 @@ def _run(self, query: str) -> str: context_list = [] resource_number = 1 for segment in sorted_segments: - context = {} - document = Document.query.filter( + document_segment = Document.query.filter( Document.id == segment.document_id, Document.enabled == True, Document.archived == False, ).first() - if dataset and document: + if not document_segment: + continue + if dataset and document_segment: source = { "position": resource_number, "dataset_id": dataset.id, "dataset_name": dataset.name, - "document_id": document.id, - "document_name": document.name, - "data_source_type": document.data_source_type, + "document_id": document_segment.id, + "document_name": document_segment.name, + "data_source_type": document_segment.data_source_type, "segment_id": segment.id, "retriever_from": self.retriever_from, "score": document_score_list.get(segment.index_node_id, None), diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index 3c9295c493c470..2d7e193e152645 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -23,7 +23,7 @@ class DatasetRetrieverTool(Tool): def get_dataset_tools( tenant_id: str, dataset_ids: list[str], - retrieve_config: DatasetRetrieveConfigEntity, + retrieve_config: Optional[DatasetRetrieveConfigEntity], return_resource: bool, invoke_from: InvokeFrom, hit_callback: DatasetIndexToolCallbackHandler, @@ -51,6 +51,8 @@ def get_dataset_tools( invoke_from=invoke_from, hit_callback=hit_callback, ) + if retrieval_tools is None: + return [] # restore retrieve strategy retrieve_config.retrieve_strategy = original_retriever_mode @@ -83,6 +85,7 @@ def get_runtime_parameters(self) -> list[ToolParameter]: llm_description="Query for the dataset to be used to retrieve the dataset.", required=True, default="", + placeholder=I18nObject(en_US="", zh_Hans=""), ), ] @@ -102,7 +105,9 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe return self.create_text_message(text=result) - def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: + def validate_credentials( + self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False + ) -> str | None: """ validate the credentials for dataset retriever tool """ diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 8d4045038171a6..55f94d7619635b 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -91,7 +91,7 @@ def tool_provider_type(self) -> ToolProviderType: :return: the tool provider type """ - def load_variables(self, variables: ToolRuntimeVariablePool): + def load_variables(self, variables: ToolRuntimeVariablePool | None) -> None: """ load variables from database @@ -105,6 +105,8 @@ def set_image_variable(self, variable_name: str, image_key: str) -> None: """ if not self.variables: return + if self.identity is None: + return self.variables.set_file(self.identity.name, variable_name, image_key) @@ -114,6 +116,8 @@ def set_text_variable(self, variable_name: str, text: str) -> None: """ if not self.variables: return + if self.identity is None: + return self.variables.set_text(self.identity.name, variable_name, text) @@ -200,7 +204,11 @@ def list_default_image_variables(self) -> list[ToolRuntimeVariable]: def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolInvokeMessage]: # update tool_parameters # TODO: Fix type error. + if self.runtime is None: + return [] if self.runtime.runtime_parameters: + # Convert Mapping to dict before updating + tool_parameters = dict(tool_parameters) tool_parameters.update(self.runtime.runtime_parameters) # try parse tool parameters into the correct type @@ -221,7 +229,7 @@ def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> Transform tool parameters type """ # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials - result = deepcopy(tool_parameters) + result: dict[str, Any] = deepcopy(dict(tool_parameters)) for parameter in self.parameters or []: if parameter.name in tool_parameters: result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name]) @@ -234,12 +242,15 @@ def _invoke( ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: pass - def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: + def validate_credentials( + self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False + ) -> str | None: """ validate the credentials :param credentials: the credentials :param parameters: the parameters + :param format_only: only return the formatted """ pass diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index 33b4ad021a5e7f..edff4a2d07cca2 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -68,20 +68,20 @@ def _invoke( if data.get("error"): raise Exception(data.get("error")) - result = [] + r = [] outputs = data.get("outputs") if outputs == None: outputs = {} else: - outputs, files = self._extract_files(outputs) - for file in files: - result.append(self.create_file_message(file)) + outputs, extracted_files = self._extract_files(outputs) + for f in extracted_files: + r.append(self.create_file_message(f)) - result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) - result.append(self.create_json_message(outputs)) + r.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) + r.append(self.create_json_message(outputs)) - return result + return r def _get_user(self, user_id: str) -> Union[EndUser, Account]: """ diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index f92b43608ed935..425a892527daa4 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -3,7 +3,7 @@ from copy import deepcopy from datetime import UTC, datetime from mimetypes import guess_type -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast from yarl import URL @@ -46,7 +46,7 @@ def agent_invoke( invoke_from: InvokeFrom, agent_tool_callback: DifyAgentCallbackHandler, trace_manager: Optional[TraceQueueManager] = None, - ) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]: + ) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]: """ Agent invokes the tool with the given arguments. """ @@ -69,6 +69,8 @@ def agent_invoke( raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") # invoke the tool + if tool.identity is None: + raise ValueError("tool identity is not set") try: # hit the callback handler agent_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) @@ -163,6 +165,8 @@ def _invoke(tool: Tool, tool_parameters: dict, user_id: str) -> tuple[ToolInvoke """ Invoke the tool with the given arguments. """ + if tool.identity is None: + raise ValueError("tool identity is not set") started_at = datetime.now(UTC) meta = ToolInvokeMeta( time_cost=0.0, @@ -171,7 +175,7 @@ def _invoke(tool: Tool, tool_parameters: dict, user_id: str) -> tuple[ToolInvoke "tool_name": tool.identity.name, "tool_provider": tool.identity.provider, "tool_provider_type": tool.tool_provider_type().value, - "tool_parameters": deepcopy(tool.runtime.runtime_parameters), + "tool_parameters": deepcopy(tool.runtime.runtime_parameters) if tool.runtime else {}, "tool_icon": tool.identity.icon, }, ) @@ -194,9 +198,9 @@ def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str result = "" for response in tool_response: if response.type == ToolInvokeMessage.MessageType.TEXT: - result += response.message + result += str(response.message) if response.message is not None else "" elif response.type == ToolInvokeMessage.MessageType.LINK: - result += f"result link: {response.message}. please tell user to check it." + result += f"result link: {response.message!r}. please tell user to check it." elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: result += ( "image has been created and sent to user already, you do not need to create it," @@ -205,7 +209,7 @@ def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str elif response.type == ToolInvokeMessage.MessageType.JSON: result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}." else: - result += f"tool response: {response.message}." + result += f"tool response: {response.message!r}." return result @@ -223,7 +227,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis mimetype = response.meta.get("mime_type") else: try: - url = URL(response.message) + url = URL(cast(str, response.message)) extension = url.suffix guess_type_result, _ = guess_type(f"a{extension}") if guess_type_result: @@ -237,7 +241,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis result.append( ToolInvokeMessageBinary( mimetype=response.meta.get("mime_type", "image/jpeg"), - url=response.message, + url=cast(str, response.message), save_as=response.save_as, ) ) @@ -245,7 +249,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis result.append( ToolInvokeMessageBinary( mimetype=response.meta.get("mime_type", "octet/stream"), - url=response.message, + url=cast(str, response.message), save_as=response.save_as, ) ) @@ -257,7 +261,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis mimetype=response.meta.get("mime_type", "octet/stream") if response.meta else "octet/stream", - url=response.message, + url=cast(str, response.message), save_as=response.save_as, ) ) diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 5052f0897a2958..2aaca6d82e36b1 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -8,9 +8,10 @@ from typing import Optional, Union from uuid import uuid4 -from httpx import get +import httpx from configs import dify_config +from core.helper import ssrf_proxy from extensions.ext_database import db from extensions.ext_storage import storage from models.model import MessageFile @@ -94,12 +95,11 @@ def create_file_by_url( ) -> ToolFile: # try to download image try: - response = get(file_url) + response = ssrf_proxy.get(file_url) response.raise_for_status() blob = response.content - except Exception as e: - logger.exception(f"Failed to download file from {file_url}") - raise + except httpx.TimeoutException as e: + raise ValueError(f"timeout when downloading file from {file_url}") mimetype = guess_type(file_url)[0] or "octet/stream" extension = guess_extension(mimetype) or ".bin" diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 2a5a2944ef8471..e53985951b0627 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -84,13 +84,17 @@ def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[ if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): raise ValueError("Unsupported tool type") - provider_ids = [controller.provider_id for controller in tool_providers] + provider_ids = [ + controller.provider_id + for controller in tool_providers + if isinstance(controller, (ApiToolProviderController, WorkflowToolProviderController)) + ] labels: list[ToolLabelBinding] = ( db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() ) - tool_labels = {label.tool_id: [] for label in labels} + tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} for label in labels: tool_labels[label.tool_id].append(label.label_name) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index ac333162b6bb1c..5b2173a4d0ad69 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -4,7 +4,7 @@ from collections.abc import Generator from os import listdir, path from threading import Lock, Thread -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast from configs import dify_config from core.agent.entities import AgentToolEntity @@ -15,15 +15,18 @@ from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter -from core.tools.errors import ToolProviderNotFoundError +from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.builtin._positions import BuiltinToolProviderSort from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.provider.tool_provider import ToolProviderController +from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController from core.tools.tool.api_tool import ApiTool from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager +from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -33,9 +36,9 @@ class ToolManager: _builtin_provider_lock = Lock() - _builtin_providers = {} + _builtin_providers: dict[str, BuiltinToolProviderController] = {} _builtin_providers_loaded = False - _builtin_tools_labels = {} + _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} @classmethod def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController: @@ -55,7 +58,7 @@ def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController: return cls._builtin_providers[provider] @classmethod - def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool: + def get_builtin_tool(cls, provider: str, tool_name: str) -> Union[BuiltinTool, Tool]: """ get the builtin tool @@ -66,13 +69,15 @@ def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool: """ provider_controller = cls.get_builtin_provider(provider) tool = provider_controller.get_tool(tool_name) + if tool is None: + raise ToolNotFoundError(f"tool {tool_name} not found") return tool @classmethod def get_tool( cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: Optional[str] = None - ) -> Union[BuiltinTool, ApiTool]: + ) -> Union[BuiltinTool, ApiTool, Tool]: """ get the tool @@ -103,7 +108,7 @@ def get_tool_runtime( tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, - ) -> Union[BuiltinTool, ApiTool]: + ) -> Union[BuiltinTool, ApiTool, Tool]: """ get the tool runtime @@ -113,6 +118,7 @@ def get_tool_runtime( :return: the tool """ + controller: Union[BuiltinToolProviderController, ApiToolProviderController, WorkflowToolProviderController] if provider_type == "builtin": builtin_tool = cls.get_builtin_tool(provider_id, tool_name) @@ -129,7 +135,7 @@ def get_tool_runtime( ) # get credentials - builtin_provider: BuiltinToolProvider = ( + builtin_provider: Optional[BuiltinToolProvider] = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, @@ -177,7 +183,7 @@ def get_tool_runtime( } ) elif provider_type == "workflow": - workflow_provider = ( + workflow_provider: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) .first() @@ -187,8 +193,13 @@ def get_tool_runtime( raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) + controller_tools: Optional[list[Tool]] = controller.get_tools( + user_id="", tenant_id=workflow_provider.tenant_id + ) + if controller_tools is None or len(controller_tools) == 0: + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( + return controller_tools[0].fork_tool_runtime( runtime={ "tenant_id": tenant_id, "credentials": {}, @@ -215,7 +226,7 @@ def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict if parameter_rule.type == ToolParameter.ToolParameterType.SELECT: # check if tool_parameter_config in options - options = [x.value for x in parameter_rule.options] + options = [x.value for x in parameter_rule.options or []] if parameter_value is not None and parameter_value not in options: raise ValueError( f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}" @@ -267,6 +278,8 @@ def get_agent_tool_runtime( identity_id=f"AGENT.{app_id}", ) runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None: + raise ValueError("runtime not found or runtime parameters not found") tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity @@ -312,6 +325,9 @@ def get_workflow_tool_runtime( if runtime_parameters: runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None: + raise ValueError("runtime not found or runtime parameters not found") + tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity @@ -326,6 +342,8 @@ def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]: """ # get provider provider_controller = cls.get_builtin_provider(provider) + if provider_controller.identity is None: + raise ToolProviderNotFoundError(f"builtin provider {provider} not found") absolute_path = path.join( path.dirname(path.realpath(__file__)), @@ -381,11 +399,15 @@ def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, Non ), parent_type=BuiltinToolProviderController, ) - provider: BuiltinToolProviderController = provider_class() - cls._builtin_providers[provider.identity.name] = provider - for tool in provider.get_tools(): + provider_controller: BuiltinToolProviderController = provider_class() + if provider_controller.identity is None: + continue + cls._builtin_providers[provider_controller.identity.name] = provider_controller + for tool in provider_controller.get_tools() or []: + if tool.identity is None: + continue cls._builtin_tools_labels[tool.identity.name] = tool.identity.label - yield provider + yield provider_controller except Exception as e: logger.exception(f"load builtin provider {provider}") @@ -449,9 +471,11 @@ def user_list_providers( # append builtin providers for provider in builtin_providers: # handle include, exclude + if provider.identity is None: + continue if is_filtered( - include_set=dify_config.POSITION_TOOL_INCLUDES_SET, - exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET), + exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET), data=provider, name_func=lambda x: x.identity.name, ): @@ -472,7 +496,7 @@ def user_list_providers( db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() ) - api_provider_controllers = [ + api_provider_controllers: list[dict[str, Any]] = [ {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} for provider in db_api_providers ] @@ -495,7 +519,7 @@ def user_list_providers( db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() ) - workflow_provider_controllers = [] + workflow_provider_controllers: list[WorkflowToolProviderController] = [] for provider in workflow_providers: try: workflow_provider_controllers.append( @@ -505,7 +529,9 @@ def user_list_providers( # app has been deleted pass - labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers) + labels = ToolLabelManager.get_tools_labels( + [cast(ToolProviderController, controller) for controller in workflow_provider_controllers] + ) for provider_controller in workflow_provider_controllers: user_provider = ToolTransformService.workflow_provider_to_user_provider( @@ -527,7 +553,7 @@ def get_api_provider_controller( :return: the provider controller, the credentials """ - provider: ApiToolProvider = ( + provider: Optional[ApiToolProvider] = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.id == provider_id, @@ -556,7 +582,7 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: get tool provider """ provider_name = provider - provider: ApiToolProvider = ( + provider_tool: Optional[ApiToolProvider] = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id, @@ -565,17 +591,18 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: .first() ) - if provider is None: + if provider_tool is None: raise ValueError(f"you have not added provider {provider_name}") try: - credentials = json.loads(provider.credentials_str) or {} + credentials = json.loads(provider_tool.credentials_str) or {} except: credentials = {} # package tool provider controller controller = ApiToolProviderController.from_db( - provider, ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE + provider_tool, + ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ) # init tool configuration tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) @@ -584,25 +611,28 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) try: - icon = json.loads(provider.icon) + icon = json.loads(provider_tool.icon) except: icon = {"background": "#252525", "content": "\ud83d\ude01"} # add tool labels labels = ToolLabelManager.get_tool_labels(controller) - return jsonable_encoder( - { - "schema_type": provider.schema_type, - "schema": provider.schema, - "tools": provider.tools, - "icon": icon, - "description": provider.description, - "credentials": masked_credentials, - "privacy_policy": provider.privacy_policy, - "custom_disclaimer": provider.custom_disclaimer, - "labels": labels, - } + return cast( + dict, + jsonable_encoder( + { + "schema_type": provider_tool.schema_type, + "schema": provider_tool.schema, + "tools": provider_tool.tools, + "icon": icon, + "description": provider_tool.description, + "credentials": masked_credentials, + "privacy_policy": provider_tool.privacy_policy, + "custom_disclaimer": provider_tool.custom_disclaimer, + "labels": labels, + } + ), ) @classmethod @@ -617,6 +647,7 @@ def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> """ provider_type = provider_type provider_id = provider_id + provider: Optional[Union[BuiltinToolProvider, ApiToolProvider, WorkflowToolProvider]] = None if provider_type == "builtin": return ( dify_config.CONSOLE_API_URL @@ -626,16 +657,21 @@ def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> ) elif provider_type == "api": try: - provider: ApiToolProvider = ( + provider = ( db.session.query(ApiToolProvider) .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) .first() ) - return json.loads(provider.icon) + if provider is None: + raise ToolProviderNotFoundError(f"api provider {provider_id} not found") + icon = json.loads(provider.icon) + if isinstance(icon, (str, dict)): + return icon + return {"background": "#252525", "content": "\ud83d\ude01"} except: return {"background": "#252525", "content": "\ud83d\ude01"} elif provider_type == "workflow": - provider: WorkflowToolProvider = ( + provider = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) .first() @@ -643,7 +679,13 @@ def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> if provider is None: raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - return json.loads(provider.icon) + try: + icon = json.loads(provider.icon) + if isinstance(icon, (str, dict)): + return icon + return {"background": "#252525", "content": "\ud83d\ude01"} + except: + return {"background": "#252525", "content": "\ud83d\ude01"} else: raise ValueError(f"provider type {provider_type} not found") diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 8b5e27f5382ee7..d7720928644701 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -72,9 +72,13 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str return a deep copy of credentials with decrypted values """ + identity_id = "" + if self.provider_controller.identity: + identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}" + cache = ToolProviderCredentialsCache( tenant_id=self.tenant_id, - identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}", + identity_id=identity_id, cache_type=ToolProviderCredentialsCacheType.PROVIDER, ) cached_credentials = cache.get() @@ -95,9 +99,13 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str return credentials def delete_tool_credentials_cache(self): + identity_id = "" + if self.provider_controller.identity: + identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}" + cache = ToolProviderCredentialsCache( tenant_id=self.tenant_id, - identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}", + identity_id=identity_id, cache_type=ToolProviderCredentialsCacheType.PROVIDER, ) cache.delete() @@ -199,6 +207,9 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: return a deep copy of parameters with decrypted values """ + if self.tool_runtime is None or self.tool_runtime.identity is None: + raise ValueError("tool_runtime is required") + cache = ToolParameterCache( tenant_id=self.tenant_id, provider=f"{self.provider_type}.{self.provider_name}", @@ -232,6 +243,9 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: return parameters def delete_tool_parameters_cache(self): + if self.tool_runtime is None or self.tool_runtime.identity is None: + raise ValueError("tool_runtime is required") + cache = ToolParameterCache( tenant_id=self.tenant_id, provider=f"{self.provider_type}.{self.provider_name}", diff --git a/api/core/tools/utils/feishu_api_utils.py b/api/core/tools/utils/feishu_api_utils.py index ea28037df03720..ecf60045aa8dc5 100644 --- a/api/core/tools/utils/feishu_api_utils.py +++ b/api/core/tools/utils/feishu_api_utils.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import Any, Optional, cast import httpx @@ -101,7 +101,7 @@ def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict: """ url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token" payload = {"app_id": app_id, "app_secret": app_secret} - res = self._send_request(url, require_token=False, payload=payload) + res: dict = self._send_request(url, require_token=False, payload=payload) return res def create_document(self, title: str, content: str, folder_token: str) -> dict: @@ -126,15 +126,16 @@ def create_document(self, title: str, content: str, folder_token: str) -> dict: "content": content, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def write_document(self, document_id: str, content: str, position: str = "end") -> dict: url = f"{self.API_BASE_URL}/document/write_document" payload = {"document_id": document_id, "content": content, "position": position} - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) return res def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str: @@ -155,9 +156,9 @@ def get_document_content(self, document_id: str, mode: str = "markdown", lang: s "lang": lang, } url = f"{self.API_BASE_URL}/document/get_document_content" - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data").get("content") + return cast(str, res.get("data", {}).get("content")) return "" def list_document_blocks( @@ -173,9 +174,10 @@ def list_document_blocks( "page_token": page_token, } url = f"{self.API_BASE_URL}/document/list_document_blocks" - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict: @@ -191,9 +193,10 @@ def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, "msg_type": msg_type, "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict: @@ -203,7 +206,7 @@ def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dic "msg_type": msg_type, "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), } - res = self._send_request(url, require_token=False, payload=payload) + res: dict = self._send_request(url, require_token=False, payload=payload) return res def get_chat_messages( @@ -227,9 +230,10 @@ def get_chat_messages( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_thread_messages( @@ -245,9 +249,10 @@ def get_thread_messages( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict: @@ -260,9 +265,10 @@ def create_task(self, summary: str, start_time: str, end_time: str, completed_ti "completed_at": completed_time, "description": description, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_task( @@ -278,9 +284,10 @@ def update_task( "completed_time": completed_time, "description": description, } - res = self._send_request(url, method="PATCH", payload=payload) + res: dict = self._send_request(url, method="PATCH", payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_task(self, task_guid: str) -> dict: @@ -289,7 +296,7 @@ def delete_task(self, task_guid: str) -> dict: payload = { "task_guid": task_guid, } - res = self._send_request(url, method="DELETE", payload=payload) + res: dict = self._send_request(url, method="DELETE", payload=payload) return res def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict: @@ -300,7 +307,7 @@ def add_members(self, task_guid: str, member_phone_or_email: str, member_role: s "member_phone_or_email": member_phone_or_email, "member_role": member_role, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) return res def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict: @@ -312,9 +319,10 @@ def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: @@ -322,9 +330,10 @@ def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: params = { "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_event( @@ -347,9 +356,10 @@ def create_event( "auto_record": auto_record, "attendee_ability": attendee_ability, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_event( @@ -363,7 +373,7 @@ def update_event( auto_record: bool, ) -> dict: url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}" - payload = {} + payload: dict[str, Any] = {} if summary: payload["summary"] = summary if description: @@ -376,7 +386,7 @@ def update_event( payload["need_notification"] = need_notification if auto_record: payload["auto_record"] = auto_record - res = self._send_request(url, method="PATCH", payload=payload) + res: dict = self._send_request(url, method="PATCH", payload=payload) return res def delete_event(self, event_id: str, need_notification: bool = True) -> dict: @@ -384,7 +394,7 @@ def delete_event(self, event_id: str, need_notification: bool = True) -> dict: params = { "need_notification": need_notification, } - res = self._send_request(url, method="DELETE", params=params) + res: dict = self._send_request(url, method="DELETE", params=params) return res def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict: @@ -395,9 +405,10 @@ def list_events(self, start_time: str, end_time: str, page_token: str, page_size "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def search_events( @@ -418,9 +429,10 @@ def search_events( "user_id_type": user_id_type, "page_size": page_size, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict: @@ -431,9 +443,10 @@ def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_ "attendee_phone_or_email": attendee_phone_or_email, "need_notification": need_notification, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_spreadsheet( @@ -447,9 +460,10 @@ def create_spreadsheet( "title": title, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_spreadsheet( @@ -463,9 +477,10 @@ def get_spreadsheet( "spreadsheet_token": spreadsheet_token, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def list_spreadsheet_sheets( @@ -477,9 +492,10 @@ def list_spreadsheet_sheets( params = { "spreadsheet_token": spreadsheet_token, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_rows( @@ -499,9 +515,10 @@ def add_rows( "length": length, "values": values, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_cols( @@ -521,9 +538,10 @@ def add_cols( "length": length, "values": values, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_rows( @@ -545,9 +563,10 @@ def read_rows( "num_rows": num_rows, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_cols( @@ -569,9 +588,10 @@ def read_cols( "num_cols": num_cols, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_table( @@ -593,9 +613,10 @@ def read_table( "query": query, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_base( @@ -609,9 +630,10 @@ def create_base( "name": name, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_records( @@ -633,9 +655,10 @@ def add_records( payload = { "records": convert_add_records(records), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_records( @@ -657,9 +680,10 @@ def update_records( payload = { "records": convert_update_records(records), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_records( @@ -686,9 +710,10 @@ def delete_records( payload = { "records": record_id_list, } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def search_record( @@ -740,7 +765,7 @@ def search_record( except json.JSONDecodeError: raise ValueError("The input string is not valid JSON") - payload = {} + payload: dict[str, Any] = {} if view_id: payload["view_id"] = view_id @@ -752,10 +777,11 @@ def search_record( payload["filter"] = filter_dict if automatic_fields: payload["automatic_fields"] = automatic_fields - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_base_info( @@ -767,9 +793,10 @@ def get_base_info( params = { "app_token": app_token, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_table( @@ -797,9 +824,10 @@ def create_table( } if default_view_name: payload["default_view_name"] = default_view_name - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_tables( @@ -834,9 +862,10 @@ def delete_tables( "table_names": table_name_list, } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def list_tables( @@ -852,9 +881,10 @@ def list_tables( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_records( @@ -882,7 +912,8 @@ def read_records( "record_ids": record_id_list, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params, payload=payload) + res: dict = self._send_request(url, method="GET", params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res diff --git a/api/core/tools/utils/lark_api_utils.py b/api/core/tools/utils/lark_api_utils.py index 30cb0cb141d9a6..de394a39bf5a00 100644 --- a/api/core/tools/utils/lark_api_utils.py +++ b/api/core/tools/utils/lark_api_utils.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import Any, Optional, cast import httpx @@ -62,12 +62,10 @@ def convert_update_records(self, json_str): def tenant_access_token(self) -> str: feishu_tenant_access_token = f"tools:{self.app_id}:feishu_tenant_access_token" if redis_client.exists(feishu_tenant_access_token): - return redis_client.get(feishu_tenant_access_token).decode() - res = self.get_tenant_access_token(self.app_id, self.app_secret) + return str(redis_client.get(feishu_tenant_access_token).decode()) + res: dict[str, str] = self.get_tenant_access_token(self.app_id, self.app_secret) redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token")) - if "tenant_access_token" in res: - return res.get("tenant_access_token") - return "" + return res.get("tenant_access_token", "") def _send_request( self, @@ -91,7 +89,7 @@ def _send_request( def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict: url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token" payload = {"app_id": app_id, "app_secret": app_secret} - res = self._send_request(url, require_token=False, payload=payload) + res: dict = self._send_request(url, require_token=False, payload=payload) return res def create_document(self, title: str, content: str, folder_token: str) -> dict: @@ -101,15 +99,16 @@ def create_document(self, title: str, content: str, folder_token: str) -> dict: "content": content, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def write_document(self, document_id: str, content: str, position: str = "end") -> dict: url = f"{self.API_BASE_URL}/document/write_document" payload = {"document_id": document_id, "content": content, "position": position} - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) return res def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str | dict: @@ -119,9 +118,9 @@ def get_document_content(self, document_id: str, mode: str = "markdown", lang: s "lang": lang, } url = f"{self.API_BASE_URL}/document/get_document_content" - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data").get("content") + return cast(dict, res.get("data", {}).get("content")) return "" def list_document_blocks( @@ -134,9 +133,10 @@ def list_document_blocks( "page_token": page_token, } url = f"{self.API_BASE_URL}/document/list_document_blocks" - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict: @@ -149,9 +149,10 @@ def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, "msg_type": msg_type, "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict: @@ -161,7 +162,7 @@ def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dic "msg_type": msg_type, "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), } - res = self._send_request(url, require_token=False, payload=payload) + res: dict = self._send_request(url, require_token=False, payload=payload) return res def get_chat_messages( @@ -182,9 +183,10 @@ def get_chat_messages( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_thread_messages( @@ -197,9 +199,10 @@ def get_thread_messages( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict: @@ -211,9 +214,10 @@ def create_task(self, summary: str, start_time: str, end_time: str, completed_ti "completed_at": completed_time, "description": description, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_task( @@ -228,9 +232,10 @@ def update_task( "completed_time": completed_time, "description": description, } - res = self._send_request(url, method="PATCH", payload=payload) + res: dict = self._send_request(url, method="PATCH", payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_task(self, task_guid: str) -> dict: @@ -238,9 +243,10 @@ def delete_task(self, task_guid: str) -> dict: payload = { "task_guid": task_guid, } - res = self._send_request(url, method="DELETE", payload=payload) + res: dict = self._send_request(url, method="DELETE", payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict: @@ -250,9 +256,10 @@ def add_members(self, task_guid: str, member_phone_or_email: str, member_role: s "member_phone_or_email": member_phone_or_email, "member_role": member_role, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict: @@ -263,9 +270,10 @@ def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: @@ -273,9 +281,10 @@ def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: params = { "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_event( @@ -298,9 +307,10 @@ def create_event( "auto_record": auto_record, "attendee_ability": attendee_ability, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_event( @@ -314,7 +324,7 @@ def update_event( auto_record: bool, ) -> dict: url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}" - payload = {} + payload: dict[str, Any] = {} if summary: payload["summary"] = summary if description: @@ -327,7 +337,7 @@ def update_event( payload["need_notification"] = need_notification if auto_record: payload["auto_record"] = auto_record - res = self._send_request(url, method="PATCH", payload=payload) + res: dict = self._send_request(url, method="PATCH", payload=payload) return res def delete_event(self, event_id: str, need_notification: bool = True) -> dict: @@ -335,7 +345,7 @@ def delete_event(self, event_id: str, need_notification: bool = True) -> dict: params = { "need_notification": need_notification, } - res = self._send_request(url, method="DELETE", params=params) + res: dict = self._send_request(url, method="DELETE", params=params) return res def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict: @@ -346,9 +356,10 @@ def list_events(self, start_time: str, end_time: str, page_token: str, page_size "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def search_events( @@ -369,9 +380,10 @@ def search_events( "user_id_type": user_id_type, "page_size": page_size, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict: @@ -381,9 +393,10 @@ def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_ "attendee_phone_or_email": attendee_phone_or_email, "need_notification": need_notification, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_spreadsheet( @@ -396,9 +409,10 @@ def create_spreadsheet( "title": title, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_spreadsheet( @@ -411,9 +425,10 @@ def get_spreadsheet( "spreadsheet_token": spreadsheet_token, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def list_spreadsheet_sheets( @@ -424,9 +439,10 @@ def list_spreadsheet_sheets( params = { "spreadsheet_token": spreadsheet_token, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_rows( @@ -445,9 +461,10 @@ def add_rows( "length": length, "values": values, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_cols( @@ -466,9 +483,10 @@ def add_cols( "length": length, "values": values, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_rows( @@ -489,9 +507,10 @@ def read_rows( "num_rows": num_rows, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_cols( @@ -512,9 +531,10 @@ def read_cols( "num_cols": num_cols, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_table( @@ -535,9 +555,10 @@ def read_table( "query": query, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_base( @@ -550,9 +571,10 @@ def create_base( "name": name, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_records( @@ -573,9 +595,10 @@ def add_records( payload = { "records": self.convert_add_records(records), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_records( @@ -596,9 +619,10 @@ def update_records( payload = { "records": self.convert_update_records(records), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_records( @@ -624,9 +648,10 @@ def delete_records( payload = { "records": record_id_list, } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def search_record( @@ -678,7 +703,7 @@ def search_record( except json.JSONDecodeError: raise ValueError("The input string is not valid JSON") - payload = {} + payload: dict[str, Any] = {} if view_id: payload["view_id"] = view_id @@ -690,9 +715,10 @@ def search_record( payload["filter"] = filter_dict if automatic_fields: payload["automatic_fields"] = automatic_fields - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_base_info( @@ -703,9 +729,10 @@ def get_base_info( params = { "app_token": app_token, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_table( @@ -732,9 +759,10 @@ def create_table( } if default_view_name: payload["default_view_name"] = default_view_name - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_tables( @@ -767,9 +795,10 @@ def delete_tables( "table_ids": table_id_list, "table_names": table_name_list, } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def list_tables( @@ -784,9 +813,10 @@ def list_tables( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_records( @@ -814,7 +844,8 @@ def read_records( "record_ids": record_id_list, "user_id_type": user_id_type, } - res = self._send_request(url, method="POST", params=params, payload=payload) + res: dict = self._send_request(url, method="POST", params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index e30c903a4b1146..3509f1e6e59f77 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -90,12 +90,12 @@ def transform_tool_invoke_messages( ) elif message.type == ToolInvokeMessage.MessageType.FILE: assert message.meta is not None - file = message.meta.get("file") - if isinstance(file, File): - if file.transfer_method == FileTransferMethod.TOOL_FILE: - assert file.related_id is not None - url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) - if file.type == FileType.IMAGE: + file_mata = message.meta.get("file") + if isinstance(file_mata, File): + if file_mata.transfer_method == FileTransferMethod.TOOL_FILE: + assert file_mata.related_id is not None + url = cls.get_tool_file_url(tool_file_id=file_mata.related_id, extension=file_mata.extension) + if file_mata.type == FileType.IMAGE: result.append( ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 4e226810d6ac90..3689dcc9e5ebfd 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -5,7 +5,7 @@ """ import json -from typing import cast +from typing import Optional, cast from core.model_manager import ModelManager from core.model_runtime.entities.llm_entities import LLMResult @@ -51,7 +51,7 @@ def get_max_llm_context_tokens( if not schema: raise InvokeModelError("No model schema found") - max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) + max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) if max_tokens is None: return 2048 @@ -133,14 +133,17 @@ def invoke( db.session.commit() try: - response: LLMResult = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=[], - stop=[], - stream=False, - user=user_id, - callbacks=[], + response: LLMResult = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=[], + stop=[], + stream=False, + user=user_id, + callbacks=[], + ), ) except InvokeRateLimitError as e: raise InvokeModelError(f"Invoke rate limit error: {e}") diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index ae44b1b99d447a..f1dc1123b9935f 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -6,7 +6,7 @@ from typing import Optional from requests import get -from yaml import YAMLError, safe_load +from yaml import YAMLError, safe_load # type: ignore from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle @@ -64,6 +64,9 @@ def parse_openapi_to_tool_bundle( default=parameter["schema"]["default"] if "schema" in parameter and "default" in parameter["schema"] else None, + placeholder=I18nObject( + en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") + ), ) # check if there is a type @@ -108,6 +111,9 @@ def parse_openapi_to_tool_bundle( form=ToolParameter.ToolParameterForm.LLM, llm_description=property.get("description", ""), default=property.get("default", None), + placeholder=I18nObject( + en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") + ), ) # check if there is a type @@ -158,9 +164,9 @@ def parse_openapi_to_tool_bundle( return bundles @staticmethod - def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType: + def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]: parameter = parameter or {} - typ = None + typ: Optional[str] = None if parameter.get("format") == "binary": return ToolParameter.ToolParameterType.FILE @@ -175,6 +181,8 @@ def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType return ToolParameter.ToolParameterType.BOOLEAN elif typ == "string": return ToolParameter.ToolParameterType.STRING + else: + return None @staticmethod def parse_openapi_yaml_to_tool_bundle( @@ -236,7 +244,8 @@ def parse_swagger_to_openapi(swagger: dict, extra_info: Optional[dict], warning: if ("summary" not in operation or len(operation["summary"]) == 0) and ( "description" not in operation or len(operation["description"]) == 0 ): - warning["missing_summary"] = f"No summary or description found in operation {method} {path}." + if warning is not None: + warning["missing_summary"] = f"No summary or description found in operation {method} {path}." openapi["paths"][path][method] = { "operationId": operation["operationId"], diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 3aae31e93a1304..d42fd99fce5e80 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -9,13 +9,13 @@ import unicodedata from contextlib import contextmanager from pathlib import Path -from typing import Optional +from typing import Any, Literal, Optional, cast from urllib.parse import unquote import chardet -import cloudscraper -from bs4 import BeautifulSoup, CData, Comment, NavigableString -from regex import regex +import cloudscraper # type: ignore +from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore +from regex import regex # type: ignore from core.helper import ssrf_proxy from core.rag.extractor import extract_processor @@ -68,7 +68,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str: return "Unsupported content-type [{}] of URL.".format(main_content_type) if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: - return ExtractProcessor.load_from_url(url, return_text=True) + return cast(str, ExtractProcessor.load_from_url(url, return_text=True)) response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) elif response.status_code == 403: @@ -125,7 +125,7 @@ def extract_using_readabilipy(html): os.unlink(article_json_path) os.unlink(html_path) - article_json = { + article_json: dict[str, Any] = { "title": None, "byline": None, "date": None, @@ -300,7 +300,7 @@ def strip_control_characters(text): def normalize_unicode(text): """Normalize unicode such that things that are visually equivalent map to the same unicode string where possible.""" - normal_form = "NFKC" + normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC" text = unicodedata.normalize(normal_form, text) return text @@ -332,6 +332,7 @@ def add_content_digest(element): def content_digest(element): + digest: Any if is_text(element): # Hash trimmed_string = element.string.strip() diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index d92bfb9b90a9aa..08a112cfdb2b91 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -7,7 +7,7 @@ class WorkflowToolConfigurationUtils: @classmethod - def check_parameter_configurations(cls, configurations: Mapping[str, Any]): + def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]): for configuration in configurations: WorkflowToolParameterConfiguration.model_validate(configuration) @@ -27,7 +27,7 @@ def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[Vari @classmethod def check_is_synced( cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration] - ) -> None: + ) -> bool: """ check is synced diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index 42c7f85bc6daeb..ee7ca11e056625 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Any -import yaml +import yaml # type: ignore from yaml import YAMLError logger = logging.getLogger(__name__) diff --git a/api/core/variables/__init__.py b/api/core/variables/__init__.py index 2b1a58f93aa3c6..7a1cbf99407ea8 100644 --- a/api/core/variables/__init__.py +++ b/api/core/variables/__init__.py @@ -21,6 +21,7 @@ ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, + ArrayVariable, FileVariable, FloatVariable, IntegerVariable, @@ -43,6 +44,7 @@ "ArraySegment", "ArrayStringSegment", "ArrayStringVariable", + "ArrayVariable", "FileSegment", "FileVariable", "FloatSegment", diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index c902303eef54d4..c32815b24d02ed 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from typing import cast from uuid import uuid4 from pydantic import Field @@ -10,6 +11,7 @@ ArrayFileSegment, ArrayNumberSegment, ArrayObjectSegment, + ArraySegment, ArrayStringSegment, FileSegment, FloatSegment, @@ -52,19 +54,23 @@ class ObjectVariable(ObjectSegment, Variable): pass -class ArrayAnyVariable(ArrayAnySegment, Variable): +class ArrayVariable(ArraySegment, Variable): pass -class ArrayStringVariable(ArrayStringSegment, Variable): +class ArrayAnyVariable(ArrayAnySegment, ArrayVariable): pass -class ArrayNumberVariable(ArrayNumberSegment, Variable): +class ArrayStringVariable(ArrayStringSegment, ArrayVariable): pass -class ArrayObjectVariable(ArrayObjectSegment, Variable): +class ArrayNumberVariable(ArrayNumberSegment, ArrayVariable): + pass + + +class ArrayObjectVariable(ArrayObjectSegment, ArrayVariable): pass @@ -73,7 +79,7 @@ class SecretVariable(StringVariable): @property def log(self) -> str: - return encrypter.obfuscated_token(self.value) + return cast(str, encrypter.obfuscated_token(self.value)) class NoneVariable(NoneSegment, Variable): @@ -85,5 +91,5 @@ class FileVariable(FileSegment, Variable): pass -class ArrayFileVariable(ArrayFileSegment, Variable): +class ArrayFileVariable(ArrayFileSegment, ArrayVariable): pass diff --git a/api/core/workflow/callbacks/workflow_logging_callback.py b/api/core/workflow/callbacks/workflow_logging_callback.py index 17913de7b0d2ce..b9c6b35ad3476a 100644 --- a/api/core/workflow/callbacks/workflow_logging_callback.py +++ b/api/core/workflow/callbacks/workflow_logging_callback.py @@ -4,6 +4,7 @@ from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, GraphRunFailedEvent, + GraphRunPartialSucceededEvent, GraphRunStartedEvent, GraphRunSucceededEvent, IterationRunFailedEvent, @@ -32,13 +33,15 @@ class WorkflowLoggingCallback(WorkflowCallback): def __init__(self) -> None: - self.current_node_id = None + self.current_node_id: Optional[str] = None def on_event(self, event: GraphEngineEvent) -> None: if isinstance(event, GraphRunStartedEvent): self.print_text("\n[GraphRunStartedEvent]", color="pink") elif isinstance(event, GraphRunSucceededEvent): self.print_text("\n[GraphRunSucceededEvent]", color="green") + elif isinstance(event, GraphRunPartialSucceededEvent): + self.print_text("\n[GraphRunPartialSucceededEvent]", color="pink") elif isinstance(event, GraphRunFailedEvent): self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red") elif isinstance(event, NodeRunStartedEvent): diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index e174d3baa0c736..ae5f117bf9b121 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -25,6 +25,7 @@ class NodeRunMetadataKey(StrEnum): PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" PARALLEL_MODE_RUN_ID = "parallel_mode_run_id" ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs + ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field class NodeRunResult(BaseModel): @@ -35,11 +36,15 @@ class NodeRunResult(BaseModel): status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING inputs: Optional[Mapping[str, Any]] = None # node inputs - process_data: Optional[dict[str, Any]] = None # process data + process_data: Optional[Mapping[str, Any]] = None # process data outputs: Optional[Mapping[str, Any]] = None # node outputs - metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata + metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata llm_usage: Optional[LLMUsage] = None # llm usage edge_source_handle: Optional[str] = None # source handle id of node with multiple branches error: Optional[str] = None # error message if status is failed + error_type: Optional[str] = None # error type if status is failed + + # single step node run retry + retry_index: int = 0 diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py index bc3a15bd004ace..b8470aecbd83a2 100644 --- a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -5,7 +5,7 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler): - def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState): """ Check if the condition can be executed diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 3736e632c3f1eb..d591b68e7e72be 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from datetime import datetime from typing import Any, Optional @@ -32,6 +33,12 @@ class GraphRunSucceededEvent(BaseGraphEvent): class GraphRunFailedEvent(BaseGraphEvent): error: str = Field(..., description="failed reason") + exceptions_count: int = Field(description="exception count", default=0) + + +class GraphRunPartialSucceededEvent(BaseGraphEvent): + exceptions_count: int = Field(..., description="exception count") + outputs: Optional[dict[str, Any]] = None ########################################### @@ -82,10 +89,20 @@ class NodeRunFailedEvent(BaseNodeEvent): error: str = Field(..., description="error") +class NodeRunExceptionEvent(BaseNodeEvent): + error: str = Field(..., description="error") + + class NodeInIterationFailedEvent(BaseNodeEvent): error: str = Field(..., description="error") +class NodeRunRetryEvent(NodeRunStartedEvent): + error: str = Field(..., description="error") + retry_index: int = Field(..., description="which retry attempt is about to be performed") + start_at: datetime = Field(..., description="retry start time") + + ########################################### # Parallel Branch Events ########################################### @@ -140,8 +157,8 @@ class BaseIterationEvent(GraphEngineEvent): class IterationRunStartedEvent(BaseIterationEvent): start_at: datetime = Field(..., description="start at") - inputs: Optional[dict[str, Any]] = None - metadata: Optional[dict[str, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + metadata: Optional[Mapping[str, Any]] = None predecessor_node_id: Optional[str] = None @@ -153,18 +170,18 @@ class IterationRunNextEvent(BaseIterationEvent): class IterationRunSucceededEvent(BaseIterationEvent): start_at: datetime = Field(..., description="start at") - inputs: Optional[dict[str, Any]] = None - outputs: Optional[dict[str, Any]] = None - metadata: Optional[dict[str, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + metadata: Optional[Mapping[str, Any]] = None steps: int = 0 iteration_duration_map: Optional[dict[str, float]] = None class IterationRunFailedEvent(BaseIterationEvent): start_at: datetime = Field(..., description="start at") - inputs: Optional[dict[str, Any]] = None - outputs: Optional[dict[str, Any]] = None - metadata: Optional[dict[str, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + metadata: Optional[Mapping[str, Any]] = None steps: int = 0 error: str = Field(..., description="failed reason") diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index d87c039409d62e..b3bcc3b2ccc309 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -1,9 +1,11 @@ import uuid +from collections import defaultdict from collections.abc import Mapping from typing import Any, Optional, cast from pydantic import BaseModel, Field +from configs import dify_config from core.workflow.graph_engine.entities.run_condition import RunCondition from core.workflow.nodes import NodeType from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter @@ -64,13 +66,21 @@ def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = Non edge_configs = graph_config.get("edges") if edge_configs is None: edge_configs = [] + # node configs + node_configs = graph_config.get("nodes") + if not node_configs: + raise ValueError("Graph must have at least one node") edge_configs = cast(list, edge_configs) + node_configs = cast(list, node_configs) # reorganize edges mapping edge_mapping: dict[str, list[GraphEdge]] = {} reverse_edge_mapping: dict[str, list[GraphEdge]] = {} target_edge_ids = set() + fail_branch_source_node_id = [ + node["id"] for node in node_configs if node["data"].get("error_strategy") == "fail-branch" + ] for edge_config in edge_configs: source_node_id = edge_config.get("source") if not source_node_id: @@ -90,8 +100,16 @@ def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = Non # parse run condition run_condition = None - if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source": - run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle")) + if edge_config.get("sourceHandle"): + if ( + edge_config.get("source") in fail_branch_source_node_id + and edge_config.get("sourceHandle") != "fail-branch" + ): + run_condition = RunCondition(type="branch_identify", branch_identify="success-branch") + elif edge_config.get("sourceHandle") != "source": + run_condition = RunCondition( + type="branch_identify", branch_identify=edge_config.get("sourceHandle") + ) graph_edge = GraphEdge( source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition @@ -100,13 +118,6 @@ def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = Non edge_mapping[source_node_id].append(graph_edge) reverse_edge_mapping[target_node_id].append(graph_edge) - # node configs - node_configs = graph_config.get("nodes") - if not node_configs: - raise ValueError("Graph must have at least one node") - - node_configs = cast(list, node_configs) - # fetch nodes that have no predecessor node root_node_configs = [] all_node_id_config_mapping: dict[str, dict] = {} @@ -161,7 +172,9 @@ def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = Non for parallel in parallel_mapping.values(): if parallel.parent_parallel_id: cls._check_exceed_parallel_limit( - parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id + parallel_mapping=parallel_mapping, + level_limit=dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT, + parent_parallel_id=parallel.parent_parallel_id, ) # init answer stream generate routes @@ -298,26 +311,17 @@ def _recursively_add_parallels( parallel = None if len(target_node_edges) > 1: # fetch all node ids in current parallels - parallel_branch_node_ids = {} - condition_edge_mappings = {} + parallel_branch_node_ids = defaultdict(list) + condition_edge_mappings = defaultdict(list) for graph_edge in target_node_edges: if graph_edge.run_condition is None: - if "default" not in parallel_branch_node_ids: - parallel_branch_node_ids["default"] = [] - parallel_branch_node_ids["default"].append(graph_edge.target_node_id) else: condition_hash = graph_edge.run_condition.hash - if condition_hash not in condition_edge_mappings: - condition_edge_mappings[condition_hash] = [] - condition_edge_mappings[condition_hash].append(graph_edge) for condition_hash, graph_edges in condition_edge_mappings.items(): if len(graph_edges) > 1: - if condition_hash not in parallel_branch_node_ids: - parallel_branch_node_ids[condition_hash] = [] - for graph_edge in graph_edges: parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id) @@ -406,7 +410,7 @@ def _recursively_add_parallels( if condition_edge_mappings: for condition_hash, graph_edges in condition_edge_mappings.items(): for graph_edge in graph_edges: - current_parallel: GraphParallel | None = cls._get_current_parallel( + current_parallel = cls._get_current_parallel( parallel_mapping=parallel_mapping, graph_edge=graph_edge, parallel=condition_parallels.get(condition_hash), diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index baeec9bf0160d7..7683dcc9dcd3c0 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -15,6 +15,7 @@ class Status(Enum): SUCCESS = "success" FAILED = "failed" PAUSED = "paused" + EXCEPTION = "exception" id: str = Field(default_factory=lambda: str(uuid.uuid4())) """node state id""" @@ -51,7 +52,11 @@ def set_finished(self, run_result: NodeRunResult) -> None: :param run_result: run result """ - if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}: + if self.status in { + RouteNodeState.Status.SUCCESS, + RouteNodeState.Status.FAILED, + RouteNodeState.Status.EXCEPTION, + }: raise Exception(f"Route state {self.id} already finished") if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: @@ -59,6 +64,9 @@ def set_finished(self, run_result: NodeRunResult) -> None: elif run_result.status == WorkflowNodeExecutionStatus.FAILED: self.status = RouteNodeState.Status.FAILED self.failed_reason = run_result.error + elif run_result.status == WorkflowNodeExecutionStatus.EXCEPTION: + self.status = RouteNodeState.Status.EXCEPTION + self.failed_reason = run_result.error else: raise Exception(f"Invalid route status {run_result.status}") diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 7cffd7bc8e1659..db1e01f14fda59 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -5,23 +5,28 @@ from collections.abc import Generator, Mapping from concurrent.futures import ThreadPoolExecutor, wait from copy import copy, deepcopy -from typing import Any, Optional +from datetime import UTC, datetime +from typing import Any, Optional, cast from flask import Flask, current_app +from configs import dify_config from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.entities.event import ( BaseIterationEvent, GraphEngineEvent, GraphRunFailedEvent, + GraphRunPartialSucceededEvent, GraphRunStartedEvent, GraphRunSucceededEvent, + NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunRetrieverResourceEvent, + NodeRunRetryEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, @@ -35,8 +40,11 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.nodes import NodeType from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor +from core.workflow.nodes.answer.base_stream_processor import StreamProcessor from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor +from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from extensions.ext_database import db @@ -48,13 +56,18 @@ class GraphEngineThreadPool(ThreadPoolExecutor): def __init__( - self, max_workers=None, thread_name_prefix="", initializer=None, initargs=(), max_submit_count=100 + self, + max_workers=None, + thread_name_prefix="", + initializer=None, + initargs=(), + max_submit_count=dify_config.MAX_SUBMIT_COUNT, ) -> None: super().__init__(max_workers, thread_name_prefix, initializer, initargs) self.max_submit_count = max_submit_count self.submit_count = 0 - def submit(self, fn, *args, **kwargs): + def submit(self, fn, /, *args, **kwargs): self.submit_count += 1 self.check_is_full() @@ -88,7 +101,7 @@ def __init__( max_execution_time: int, thread_pool_id: Optional[str] = None, ) -> None: - thread_pool_max_submit_count = 100 + thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT thread_pool_max_workers = 10 # init thread pool @@ -128,6 +141,8 @@ def __init__( def run(self) -> Generator[GraphEngineEvent, None, None]: # trigger graph run start event yield GraphRunStartedEvent() + handle_exceptions: list[str] = [] + stream_processor: StreamProcessor try: if self.init_params.workflow_type == WorkflowType.CHAT: @@ -140,18 +155,22 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: ) # run graph - generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id)) - + generator = stream_processor.process( + self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions) + ) for item in generator: try: yield item if isinstance(item, NodeRunFailedEvent): - yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.") + yield GraphRunFailedEvent( + error=item.route_node_state.failed_reason or "Unknown error.", + exceptions_count=len(handle_exceptions), + ) return elif isinstance(item, NodeRunSucceededEvent): if item.node_type == NodeType.END: self.graph_runtime_state.outputs = ( - item.route_node_state.node_run_result.outputs + dict(item.route_node_state.node_run_result.outputs) if item.route_node_state.node_run_result and item.route_node_state.node_run_result.outputs else {} @@ -172,19 +191,24 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: ].strip() except Exception as e: logger.exception("Graph run failed") - yield GraphRunFailedEvent(error=str(e)) + yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions)) return - - # trigger graph run success event - yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs) + # count exceptions to determine partial success + if len(handle_exceptions) > 0: + yield GraphRunPartialSucceededEvent( + exceptions_count=len(handle_exceptions), outputs=self.graph_runtime_state.outputs + ) + else: + # trigger graph run success event + yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs) self._release_thread() except GraphRunFailedError as e: - yield GraphRunFailedEvent(error=e.error) + yield GraphRunFailedEvent(error=e.error, exceptions_count=len(handle_exceptions)) self._release_thread() return except Exception as e: logger.exception("Unknown Error when graph running") - yield GraphRunFailedEvent(error=str(e)) + yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions)) self._release_thread() raise e @@ -198,6 +222,7 @@ def _run( in_parallel_id: Optional[str] = None, parent_parallel_id: Optional[str] = None, parent_parallel_start_node_id: Optional[str] = None, + handle_exceptions: list[str] = [], ) -> Generator[GraphEngineEvent, None, None]: parallel_start_node_id = None if in_parallel_id: @@ -242,7 +267,7 @@ def _run( previous_node_id=previous_node_id, thread_pool_id=self.thread_pool_id, ) - + node_instance = cast(BaseNode[BaseNodeData], node_instance) try: # run node generator = self._run_node( @@ -252,6 +277,7 @@ def _run( parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + handle_exceptions=handle_exceptions, ) for item in generator: @@ -301,7 +327,12 @@ def _run( if len(edge_mappings) == 1: edge = edge_mappings[0] - + if ( + previous_route_node_state.status == RouteNodeState.Status.EXCEPTION + and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH + and edge.run_condition is None + ): + break if edge.run_condition: result = ConditionManager.get_condition_handler( init_params=self.init_params, @@ -321,7 +352,7 @@ def _run( if any(edge.run_condition for edge in edge_mappings): # if nodes has run conditions, get node id which branch to take based on the run condition results - condition_edge_mappings = {} + condition_edge_mappings: dict[str, list[GraphEdge]] = {} for edge in edge_mappings: if edge.run_condition: run_condition_hash = edge.run_condition.hash @@ -334,7 +365,10 @@ def _run( if len(sub_edge_mappings) == 0: continue - edge = sub_edge_mappings[0] + edge = cast(GraphEdge, sub_edge_mappings[0]) + if edge.run_condition is None: + logger.warning(f"Edge {edge.target_node_id} run condition is None") + continue result = ConditionManager.get_condition_handler( init_params=self.init_params, @@ -355,13 +389,14 @@ def _run( edge_mappings=sub_edge_mappings, in_parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, + handle_exceptions=handle_exceptions, ) - for item in parallel_generator: - if isinstance(item, str): - final_node_id = item + for parallel_result in parallel_generator: + if isinstance(parallel_result, str): + final_node_id = parallel_result else: - yield item + yield parallel_result break @@ -369,18 +404,25 @@ def _run( break next_node_id = final_node_id + elif ( + node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH + and node_instance.should_continue_on_error + and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION + ): + break else: parallel_generator = self._run_parallel_branches( edge_mappings=edge_mappings, in_parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, + handle_exceptions=handle_exceptions, ) - for item in parallel_generator: - if isinstance(item, str): - final_node_id = item + for generated_item in parallel_generator: + if isinstance(generated_item, str): + final_node_id = generated_item else: - yield item + yield generated_item if not final_node_id: break @@ -395,6 +437,7 @@ def _run_parallel_branches( edge_mappings: list[GraphEdge], in_parallel_id: Optional[str] = None, parallel_start_node_id: Optional[str] = None, + handle_exceptions: list[str] = [], ) -> Generator[GraphEngineEvent | str, None, None]: # if nodes has no run conditions, parallel run all nodes parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) @@ -438,6 +481,7 @@ def _run_parallel_branches( "parallel_start_node_id": edge.target_node_id, "parent_parallel_id": in_parallel_id, "parent_parallel_start_node_id": parallel_start_node_id, + "handle_exceptions": handle_exceptions, }, ) @@ -481,6 +525,7 @@ def _run_parallel_node( parallel_start_node_id: str, parent_parallel_id: Optional[str] = None, parent_parallel_start_node_id: Optional[str] = None, + handle_exceptions: list[str] = [], ) -> None: """ Run parallel nodes @@ -502,6 +547,7 @@ def _run_parallel_node( in_parallel_id=parallel_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + handle_exceptions=handle_exceptions, ) for item in generator: @@ -542,12 +588,13 @@ def _run_parallel_node( def _run_node( self, - node_instance: BaseNode, + node_instance: BaseNode[BaseNodeData], route_node_state: RouteNodeState, parallel_id: Optional[str] = None, parallel_start_node_id: Optional[str] = None, parent_parallel_id: Optional[str] = None, parent_parallel_start_node_id: Optional[str] = None, + handle_exceptions: list[str] = [], ) -> Generator[GraphEngineEvent, None, None]: """ Run node @@ -567,135 +614,208 @@ def _run_node( ) db.session.close() + max_retries = node_instance.node_data.retry_config.max_retries + retry_interval = node_instance.node_data.retry_config.retry_interval_seconds + retries = 0 + should_continue_retry = True + while should_continue_retry and retries <= max_retries: + try: + # run node + retry_start_at = datetime.now(UTC).replace(tzinfo=None) + generator = node_instance.run() + for item in generator: + if isinstance(item, GraphEngineEvent): + if isinstance(item, BaseIterationEvent): + # add parallel info to iteration event + item.parallel_id = parallel_id + item.parallel_start_node_id = parallel_start_node_id + item.parent_parallel_id = parent_parallel_id + item.parent_parallel_start_node_id = parent_parallel_start_node_id + + yield item + else: + if isinstance(item, RunCompletedEvent): + run_result = item.run_result + if run_result.status == WorkflowNodeExecutionStatus.FAILED: + if ( + retries == max_retries + and node_instance.node_type == NodeType.HTTP_REQUEST + and run_result.outputs + and not node_instance.should_continue_on_error + ): + run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED + if node_instance.should_retry and retries < max_retries: + retries += 1 + route_node_state.node_run_result = run_result + yield NodeRunRetryEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + predecessor_node_id=node_instance.previous_node_id, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + error=run_result.error or "Unknown error", + retry_index=retries, + start_at=retry_start_at, + ) + time.sleep(retry_interval) + continue + route_node_state.set_finished(run_result=run_result) + + if run_result.status == WorkflowNodeExecutionStatus.FAILED: + if node_instance.should_continue_on_error: + # if run failed, handle error + run_result = self._handle_continue_on_error( + node_instance, + item.run_result, + self.graph_runtime_state.variable_pool, + handle_exceptions=handle_exceptions, + ) + route_node_state.node_run_result = run_result + route_node_state.status = RouteNodeState.Status.EXCEPTION + if run_result.outputs: + for variable_key, variable_value in run_result.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + node_id=node_instance.node_id, + variable_key_list=[variable_key], + variable_value=variable_value, + ) + yield NodeRunExceptionEvent( + error=run_result.error or "System Error", + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + should_continue_retry = False + else: + yield NodeRunFailedEvent( + error=route_node_state.failed_reason or "Unknown error.", + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + should_continue_retry = False + elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + if node_instance.should_continue_on_error and self.graph.edge_mapping.get( + node_instance.node_id + ): + run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS + if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + # plus state total_tokens + self.graph_runtime_state.total_tokens += int( + run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] + ) - try: - # run node - generator = node_instance.run() - for item in generator: - if isinstance(item, GraphEngineEvent): - if isinstance(item, BaseIterationEvent): - # add parallel info to iteration event - item.parallel_id = parallel_id - item.parallel_start_node_id = parallel_start_node_id - item.parent_parallel_id = parent_parallel_id - item.parent_parallel_start_node_id = parent_parallel_start_node_id + if run_result.llm_usage: + # use the latest usage + self.graph_runtime_state.llm_usage += run_result.llm_usage + + # append node output variables to variable pool + if run_result.outputs: + for variable_key, variable_value in run_result.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + node_id=node_instance.node_id, + variable_key_list=[variable_key], + variable_value=variable_value, + ) + + # When setting metadata, convert to dict first + if not run_result.metadata: + run_result.metadata = {} - yield item - else: - if isinstance(item, RunCompletedEvent): - run_result = item.run_result - route_node_state.set_finished(run_result=run_result) + if parallel_id and parallel_start_node_id: + metadata_dict = dict(run_result.metadata) + metadata_dict[NodeRunMetadataKey.PARALLEL_ID] = parallel_id + metadata_dict[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id + if parent_parallel_id and parent_parallel_start_node_id: + metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id + metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( + parent_parallel_start_node_id + ) + run_result.metadata = metadata_dict + + yield NodeRunSucceededEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + should_continue_retry = False - if run_result.status == WorkflowNodeExecutionStatus.FAILED: - yield NodeRunFailedEvent( - error=route_node_state.failed_reason or "Unknown error.", + break + elif isinstance(item, RunStreamChunkEvent): + yield NodeRunStreamChunkEvent( id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, node_data=node_instance.node_data, + chunk_content=item.chunk_content, + from_variable_selector=item.from_variable_selector, route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, ) - elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: - if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - # plus state total_tokens - self.graph_runtime_state.total_tokens += int( - run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] - ) - - if run_result.llm_usage: - # use the latest usage - self.graph_runtime_state.llm_usage += run_result.llm_usage - - # append node output variables to variable pool - if run_result.outputs: - for variable_key, variable_value in run_result.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - node_id=node_instance.node_id, - variable_key_list=[variable_key], - variable_value=variable_value, - ) - - # add parallel info to run result metadata - if parallel_id and parallel_start_node_id: - if not run_result.metadata: - run_result.metadata = {} - - run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id - run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id - if parent_parallel_id and parent_parallel_start_node_id: - run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id - run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( - parent_parallel_start_node_id - ) - - yield NodeRunSucceededEvent( + elif isinstance(item, RunRetrieverResourceEvent): + yield NodeRunRetrieverResourceEvent( id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, node_data=node_instance.node_data, + retriever_resources=item.retriever_resources, + context=item.context, route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, ) - - break - elif isinstance(item, RunStreamChunkEvent): - yield NodeRunStreamChunkEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - chunk_content=item.chunk_content, - from_variable_selector=item.from_variable_selector, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - ) - elif isinstance(item, RunRetrieverResourceEvent): - yield NodeRunRetrieverResourceEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - retriever_resources=item.retriever_resources, - context=item.context, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - ) - except GenerateTaskStoppedError: - # trigger node run failed event - route_node_state.status = RouteNodeState.Status.FAILED - route_node_state.failed_reason = "Workflow stopped." - yield NodeRunFailedEvent( - error="Workflow stopped.", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - ) - return - except Exception as e: - logger.exception(f"Node {node_instance.node_data.title} run failed") - raise e - finally: - db.session.close() + except GenerateTaskStoppedError: + # trigger node run failed event + route_node_state.status = RouteNodeState.Status.FAILED + route_node_state.failed_reason = "Workflow stopped." + yield NodeRunFailedEvent( + error="Workflow stopped.", + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + return + except Exception as e: + logger.exception(f"Node {node_instance.node_data.title} run failed") + raise e + finally: + db.session.close() def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): """ @@ -735,6 +855,56 @@ def create_copy(self): new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool) return new_instance + def _handle_continue_on_error( + self, + node_instance: BaseNode[BaseNodeData], + error_result: NodeRunResult, + variable_pool: VariablePool, + handle_exceptions: list[str] = [], + ) -> NodeRunResult: + """ + handle continue on error when self._should_continue_on_error is True + + + :param error_result (NodeRunResult): error run result + :param variable_pool (VariablePool): variable pool + :return: excption run result + """ + # add error message and error type to variable pool + variable_pool.add([node_instance.node_id, "error_message"], error_result.error) + variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type) + # add error message to handle_exceptions + handle_exceptions.append(error_result.error or "") + node_error_args: dict[str, Any] = { + "status": WorkflowNodeExecutionStatus.EXCEPTION, + "error": error_result.error, + "inputs": error_result.inputs, + "metadata": { + NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy, + }, + } + + if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + return NodeRunResult( + **node_error_args, + outputs={ + **node_instance.node_data.default_value_dict, + "error_message": error_result.error, + "error_type": error_result.error_type, + }, + ) + elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH: + if self.graph.edge_mapping.get(node_instance.node_id): + node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED + return NodeRunResult( + **node_error_args, + outputs={ + "error_message": error_result.error, + "error_type": error_result.error_type, + }, + ) + return error_result + class GraphRunFailedError(Exception): def __init__(self, error: str): diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 8c78016f09a334..7d652d39f70ef4 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -6,7 +6,7 @@ TextGenerateRouteChunk, VarGenerateRouteChunk, ) -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser @@ -147,14 +147,21 @@ def _recursive_fetch_answer_dependencies( reverse_edges = reverse_edge_mapping.get(current_node_id, []) for edge in reverse_edges: source_node_id = edge.source_node_id + if source_node_id not in node_id_config_mapping: + continue source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") - if source_node_type in { - NodeType.ANSWER, - NodeType.IF_ELSE, - NodeType.QUESTION_CLASSIFIER, - NodeType.ITERATION, - NodeType.VARIABLE_ASSIGNER, - }: + source_node_data = node_id_config_mapping[source_node_id].get("data", {}) + if ( + source_node_type + in { + NodeType.ANSWER, + NodeType.IF_ELSE, + NodeType.QUESTION_CLASSIFIER, + NodeType.ITERATION, + NodeType.VARIABLE_ASSIGNER, + } + or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH + ): answer_dependencies[answer_node_id].append(source_node_id) else: cls._recursive_fetch_answer_dependencies( diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index 8a768088da660e..40213bd151f7af 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -6,6 +6,7 @@ from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, + NodeRunExceptionEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, @@ -50,7 +51,7 @@ def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generat for _ in stream_out_answer_node_ids: yield event - elif isinstance(event, NodeRunSucceededEvent): + elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent): yield event if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: # update self.route_position after all stream event finished @@ -59,11 +60,10 @@ def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generat del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] - # remove unreachable nodes self._remove_unreachable_nodes(event) # generate stream outputs - yield from self._generate_stream_outputs_when_node_finished(event) + yield from self._generate_stream_outputs_when_node_finished(cast(NodeRunSucceededEvent, event)) else: yield event @@ -130,7 +130,7 @@ def _generate_stream_outputs_when_node_finished( node_type=event.node_type, node_data=event.node_data, chunk_content=text, - from_variable_selector=value_selector, + from_variable_selector=list(value_selector), route_node_state=event.route_node_state, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index 36c3fe180a9cb2..8ffb487ec108f8 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -1,10 +1,13 @@ +import logging from abc import ABC, abstractmethod from collections.abc import Generator from core.workflow.entities.variable_pool import VariablePool -from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent +from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent from core.workflow.graph_engine.entities.graph import Graph +logger = logging.getLogger(__name__) + class StreamProcessor(ABC): def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: @@ -16,7 +19,7 @@ def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: raise NotImplementedError - def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None: + def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None: finished_node_id = event.route_node_state.node_id if finished_node_id not in self.rest_node_ids: return @@ -29,15 +32,24 @@ def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None: return if run_result.edge_source_handle: - reachable_node_ids = [] - unreachable_first_node_ids = [] + reachable_node_ids: list[str] = [] + unreachable_first_node_ids: list[str] = [] + if finished_node_id not in self.graph.edge_mapping: + logger.warning(f"node {finished_node_id} has no edge mapping") + return for edge in self.graph.edge_mapping[finished_node_id]: if ( edge.run_condition and edge.run_condition.branch_identify and run_result.edge_source_handle == edge.run_condition.branch_identify ): - reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) + # remove unreachable nodes + # FIXME: because of the code branch can combine directly, so for answer node + # we remove the node maybe shortcut the answer node, so comment this code for now + # there is not effect on the answer node and the workflow, when we have a better solution + # we can open this code. Issues: #11542 #9560 #10638 #10564 + + # reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) continue else: unreachable_first_node_ids.append(edge.target_node_id) diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index fb50fbd6e863fa..6bf8899f5d698b 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -1,13 +1,137 @@ +import json from abc import ABC -from typing import Optional +from enum import StrEnum +from typing import Any, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, model_validator + +from core.workflow.nodes.base.exc import DefaultValueTypeError +from core.workflow.nodes.enums import ErrorStrategy + + +class DefaultValueType(StrEnum): + STRING = "string" + NUMBER = "number" + OBJECT = "object" + ARRAY_NUMBER = "array[number]" + ARRAY_STRING = "array[string]" + ARRAY_OBJECT = "array[object]" + ARRAY_FILES = "array[file]" + + +NumberType = Union[int, float] + + +class DefaultValue(BaseModel): + value: Any + type: DefaultValueType + key: str + + @staticmethod + def _parse_json(value: str) -> Any: + """Unified JSON parsing handler""" + try: + return json.loads(value) + except json.JSONDecodeError: + raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") + + @staticmethod + def _validate_array(value: Any, element_type: DefaultValueType) -> bool: + """Unified array type validation""" + # FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it + return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore + + @staticmethod + def _convert_number(value: str) -> float: + """Unified number conversion handler""" + try: + return float(value) + except ValueError: + raise DefaultValueTypeError(f"Cannot convert to number: {value}") + + @model_validator(mode="after") + def validate_value_type(self) -> "DefaultValue": + if self.type is None: + raise DefaultValueTypeError("type field is required") + + # Type validation configuration + type_validators = { + DefaultValueType.STRING: { + "type": str, + "converter": lambda x: x, + }, + DefaultValueType.NUMBER: { + "type": NumberType, + "converter": self._convert_number, + }, + DefaultValueType.OBJECT: { + "type": dict, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_NUMBER: { + "type": list, + "element_type": NumberType, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_STRING: { + "type": list, + "element_type": str, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_OBJECT: { + "type": list, + "element_type": dict, + "converter": self._parse_json, + }, + } + + validator: dict[str, Any] = type_validators.get(self.type, {}) + if not validator: + if self.type == DefaultValueType.ARRAY_FILES: + # Handle files type + return self + raise DefaultValueTypeError(f"Unsupported type: {self.type}") + + # Handle string input cases + if isinstance(self.value, str) and self.type != DefaultValueType.STRING: + self.value = validator["converter"](self.value) + + # Validate base type + if not isinstance(self.value, validator["type"]): + raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}") + + # Validate array element types + if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]): + raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}") + + return self + + +class RetryConfig(BaseModel): + """node retry config""" + + max_retries: int = 0 # max retry times + retry_interval: int = 0 # retry interval in milliseconds + retry_enabled: bool = False # whether retry is enabled + + @property + def retry_interval_seconds(self) -> float: + return self.retry_interval / 1000 class BaseNodeData(ABC, BaseModel): title: str desc: Optional[str] = None + error_strategy: Optional[ErrorStrategy] = None + default_value: Optional[list[DefaultValue]] = None version: str = "1" + retry_config: RetryConfig = RetryConfig() + + @property + def default_value_dict(self): + if self.default_value: + return {item.key: item.value for item in self.default_value} + return {} class BaseIterationNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/base/exc.py b/api/core/workflow/nodes/base/exc.py new file mode 100644 index 00000000000000..aeecf406403e6d --- /dev/null +++ b/api/core/workflow/nodes/base/exc.py @@ -0,0 +1,10 @@ +class BaseNodeError(ValueError): + """Base class for node errors.""" + + pass + + +class DefaultValueTypeError(BaseNodeError): + """Raised when the default value type is invalid.""" + + pass diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index d0fbed31cd1e20..b799e7426616e7 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast from core.workflow.entities.node_entities import NodeRunResult -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from models.workflow import WorkflowNodeExecutionStatus @@ -75,6 +75,7 @@ def run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]: result = NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=str(e), + error_type="WorkflowNodeError", ) if isinstance(result, NodeRunResult): @@ -137,3 +138,21 @@ def node_type(self) -> NodeType: :return: """ return self._node_type + + @property + def should_continue_on_error(self) -> bool: + """judge if should continue on error + + Returns: + bool: if should continue on error + """ + return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE + + @property + def should_retry(self) -> bool: + """judge if should retry + + Returns: + bool: if should retry + """ + return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index ce283e38ec9b12..2f82bf8c382b55 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Optional, Union +from typing import Any, Optional from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage @@ -59,24 +59,25 @@ def _run(self) -> NodeRunResult: ) # Transform result - result = self._transform_result(result, self.node_data.outputs) + result = self._transform_result(result=result, output_schema=self.node_data.outputs) except (CodeExecutionError, CodeNodeError) as e: - return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ + ) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) - def _check_string(self, value: str, variable: str) -> str: + def _check_string(self, value: str | None, variable: str) -> str | None: """ Check string :param value: value :param variable: variable :return: """ + if value is None: + return None if not isinstance(value, str): - if value is None: - return None - else: - raise OutputValidationError(f"Output variable `{variable}` must be a string") + raise OutputValidationError(f"Output variable `{variable}` must be a string") if len(value) > dify_config.CODE_MAX_STRING_LENGTH: raise OutputValidationError( @@ -86,18 +87,17 @@ def _check_string(self, value: str, variable: str) -> str: return value.replace("\x00", "") - def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]: + def _check_number(self, value: int | float | None, variable: str) -> int | float | None: """ Check number :param value: value :param variable: variable :return: """ + if value is None: + return None if not isinstance(value, int | float): - if value is None: - return None - else: - raise OutputValidationError(f"Output variable `{variable}` must be a number") + raise OutputValidationError(f"Output variable `{variable}` must be a number") if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER: raise OutputValidationError( @@ -116,18 +116,16 @@ def _check_number(self, value: Union[int, float], variable: str) -> Union[int, f return value def _transform_result( - self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1 - ) -> dict: - """ - Transform result - :param result: result - :param output_schema: output schema - :return: - """ + self, + result: Mapping[str, Any], + output_schema: Optional[dict[str, CodeNodeData.Output]], + prefix: str = "", + depth: int = 1, + ): if depth > dify_config.CODE_MAX_DEPTH: raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.") - transformed_result = {} + transformed_result: dict[str, Any] = {} if output_schema is None: # validate output thought instance type for output_name, output_value in result.items(): diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index e78183baf12389..a4540358883210 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -14,7 +14,7 @@ class CodeNodeData(BaseNodeData): class Output(BaseModel): type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] - children: Optional[dict[str, "Output"]] = None + children: Optional[dict[str, "CodeNodeData.Output"]] = None class Dependency(BaseModel): name: str diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index d490a2eb03aff9..0b1dc611c59da2 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -1,17 +1,15 @@ import csv import io import json +import logging +import os +import tempfile +from typing import cast import docx import pandas as pd import pypdfium2 # type: ignore import yaml # type: ignore -from unstructured.partition.api import partition_via_api -from unstructured.partition.email import partition_email -from unstructured.partition.epub import partition_epub -from unstructured.partition.msg import partition_msg -from unstructured.partition.ppt import partition_ppt -from unstructured.partition.pptx import partition_pptx from configs import dify_config from core.file import File, FileTransferMethod, file_manager @@ -26,6 +24,8 @@ from .entities import DocumentExtractorNodeData from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError +logger = logging.getLogger(__name__) + class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): """ @@ -160,7 +160,7 @@ def _extract_text_from_yaml(file_content: bytes) -> str: """Extract the content from yaml file""" try: yaml_data = yaml.safe_load_all(file_content.decode("utf-8", "ignore")) - return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) + return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)) except (UnicodeDecodeError, yaml.YAMLError) as e: raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e @@ -181,10 +181,43 @@ def _extract_text_from_pdf(file_content: bytes) -> str: def _extract_text_from_doc(file_content: bytes) -> str: + """ + Extract text from a DOC/DOCX file. + For now support only paragraph and table add more if needed + """ try: doc_file = io.BytesIO(file_content) doc = docx.Document(doc_file) - return "\n".join([paragraph.text for paragraph in doc.paragraphs]) + text = [] + # Process paragraphs + for paragraph in doc.paragraphs: + if paragraph.text.strip(): + text.append(paragraph.text) + + # Process tables + for table in doc.tables: + # Table header + try: + # table maybe cause errors so ignore it. + if len(table.rows) > 0 and table.rows[0].cells is not None: + # Check if any cell in the table has text + has_content = False + for row in table.rows: + if any(cell.text.strip() for cell in row.cells): + has_content = True + break + + if has_content: + markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n" + markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n" + for row in table.rows[1:]: + markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n" + text.append(markdown_table) + except Exception as e: + logger.warning(f"Failed to extract table from DOC/DOCX: {e}") + continue + + return "\n".join(text) except Exception as e: raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e @@ -197,9 +230,9 @@ def _download_file_content(file: File) -> bytes: raise FileDownloadError("Missing URL for remote file") response = ssrf_proxy.get(file.remote_url) response.raise_for_status() - return response.content + return cast(bytes, response.content) else: - return file_manager.download(file) + return cast(bytes, file_manager.download(file)) except Exception as e: raise FileDownloadError(f"Error downloading file: {str(e)}") from e @@ -254,6 +287,8 @@ def _extract_text_from_excel(file_content: bytes) -> str: def _extract_text_from_ppt(file_content: bytes) -> str: + from unstructured.partition.ppt import partition_ppt + try: with io.BytesIO(file_content) as file: elements = partition_ppt(file=file) @@ -263,15 +298,24 @@ def _extract_text_from_ppt(file_content: bytes) -> str: def _extract_text_from_pptx(file_content: bytes) -> str: + from unstructured.partition.api import partition_via_api + from unstructured.partition.pptx import partition_pptx + try: - with io.BytesIO(file_content) as file: - if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY: - elements = partition_via_api( - file=file, - api_url=dify_config.UNSTRUCTURED_API_URL, - api_key=dify_config.UNSTRUCTURED_API_KEY, - ) - else: + if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY: + with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file: + temp_file.write(file_content) + temp_file.flush() + with open(temp_file.name, "rb") as file: + elements = partition_via_api( + file=file, + metadata_filename=temp_file.name, + api_url=dify_config.UNSTRUCTURED_API_URL, + api_key=dify_config.UNSTRUCTURED_API_KEY, + ) + os.unlink(temp_file.name) + else: + with io.BytesIO(file_content) as file: elements = partition_pptx(file=file) return "\n".join([getattr(element, "text", "") for element in elements]) except Exception as e: @@ -279,6 +323,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str: def _extract_text_from_epub(file_content: bytes) -> str: + from unstructured.partition.epub import partition_epub + try: with io.BytesIO(file_content) as file: elements = partition_epub(file=file) @@ -288,6 +334,8 @@ def _extract_text_from_epub(file_content: bytes) -> str: def _extract_text_from_eml(file_content: bytes) -> str: + from unstructured.partition.email import partition_email + try: with io.BytesIO(file_content) as file: elements = partition_email(file=file) @@ -297,6 +345,8 @@ def _extract_text_from_eml(file_content: bytes) -> str: def _extract_text_from_msg(file_content: bytes) -> str: + from unstructured.partition.msg import partition_msg + try: with io.BytesIO(file_content) as file: elements = partition_msg(file=file) diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py index ea8b6b50420c99..b3678a82b73959 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -67,7 +67,7 @@ def extract_stream_variable_selector_from_node_data( and node_type == NodeType.LLM.value and variable_selector.value_selector[1] == "text" ): - value_selectors.append(variable_selector.value_selector) + value_selectors.append(list(variable_selector.value_selector)) return value_selectors @@ -119,8 +119,7 @@ def _recursive_fetch_end_dependencies( current_node_id: str, end_node_id: str, node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]], - # type: ignore[name-defined] + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] end_dependencies: dict[str, list[str]], ) -> None: """ @@ -135,6 +134,8 @@ def _recursive_fetch_end_dependencies( reverse_edges = reverse_edge_mapping.get(current_node_id, []) for edge in reverse_edges: source_node_id = edge.source_node_id + if source_node_id not in node_id_config_mapping: + continue source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") if source_node_type in { NodeType.IF_ELSE.value, diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py index 1aecf863ac5fb9..a770eb951f6c8c 100644 --- a/api/core/workflow/nodes/end/end_stream_processor.py +++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -23,7 +23,7 @@ def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: self.route_position[end_node_id] = 0 self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} self.has_output = False - self.output_node_ids = set() + self.output_node_ids: set[str] = set() def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: for event in generator: diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index 44be403ee6ab91..7970a49aa42df4 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -22,3 +22,17 @@ class NodeType(StrEnum): VARIABLE_ASSIGNER = "assigner" DOCUMENT_EXTRACTOR = "document-extractor" LIST_OPERATOR = "list-operator" + + +class ErrorStrategy(StrEnum): + FAIL_BRANCH = "fail-branch" + DEFAULT_VALUE = "default-value" + + +class FailBranchSourceHandle(StrEnum): + FAILED = "fail-branch" + SUCCESS = "success-branch" + + +CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST] +RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE diff --git a/api/core/workflow/nodes/event/__init__.py b/api/core/workflow/nodes/event/__init__.py index 5e3b31e48baa9e..08c47d5e57387b 100644 --- a/api/core/workflow/nodes/event/__init__.py +++ b/api/core/workflow/nodes/event/__init__.py @@ -1,4 +1,10 @@ -from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from .event import ( + ModelInvokeCompletedEvent, + RunCompletedEvent, + RunRetrieverResourceEvent, + RunRetryEvent, + RunStreamChunkEvent, +) from .types import NodeEvent __all__ = [ @@ -6,5 +12,6 @@ "NodeEvent", "RunCompletedEvent", "RunRetrieverResourceEvent", + "RunRetryEvent", "RunStreamChunkEvent", ] diff --git a/api/core/workflow/nodes/event/event.py b/api/core/workflow/nodes/event/event.py index b7034561bf6713..9fea3fbda3141f 100644 --- a/api/core/workflow/nodes/event/event.py +++ b/api/core/workflow/nodes/event/event.py @@ -1,7 +1,10 @@ +from datetime import datetime + from pydantic import BaseModel, Field from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.node_entities import NodeRunResult +from models.workflow import WorkflowNodeExecutionStatus class RunCompletedEvent(BaseModel): @@ -26,3 +29,19 @@ class ModelInvokeCompletedEvent(BaseModel): text: str usage: LLMUsage finish_reason: str | None = None + + +class RunRetryEvent(BaseModel): + """Node Run Retry event""" + + error: str = Field(..., description="error") + retry_index: int = Field(..., description="Retry attempt number") + start_at: datetime = Field(..., description="Retry start time") + + +class SingleStepRetryEvent(NodeRunResult): + """Single step retry event""" + + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RETRY + + elapsed_time: float = Field(..., description="elapsed time") diff --git a/api/core/workflow/nodes/http_request/exc.py b/api/core/workflow/nodes/http_request/exc.py index 7a5ab7dbc1c1fa..a815f277becf9b 100644 --- a/api/core/workflow/nodes/http_request/exc.py +++ b/api/core/workflow/nodes/http_request/exc.py @@ -16,3 +16,7 @@ class InvalidHttpMethodError(HttpRequestNodeError): class ResponseSizeError(HttpRequestNodeError): """Raised when the response size exceeds the allowed threshold.""" + + +class RequestBodyError(HttpRequestNodeError): + """Raised when the request body is invalid.""" diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 22ad2a39f62fa4..cdfdc6e6d51b77 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -21,7 +21,9 @@ from .exc import ( AuthorizationConfigError, FileFetchError, + HttpRequestNodeError, InvalidHttpMethodError, + RequestBodyError, ResponseSizeError, ) @@ -36,7 +38,7 @@ class Executor: method: Literal["get", "head", "post", "put", "delete", "patch"] url: str - params: Mapping[str, str] | None + params: list[tuple[str, str]] | None content: str | bytes | None data: Mapping[str, Any] | None files: Mapping[str, tuple[str | None, bytes, str]] | None @@ -44,6 +46,7 @@ class Executor: headers: dict[str, str] auth: HttpRequestNodeAuthorization timeout: HttpRequestNodeTimeout + max_retries: int boundary: str @@ -53,6 +56,7 @@ def __init__( node_data: HttpRequestNodeData, timeout: HttpRequestNodeTimeout, variable_pool: VariablePool, + max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES, ): # If authorization API key is present, convert the API key using the variable pool if node_data.authorization.type == "api-key": @@ -66,12 +70,13 @@ def __init__( self.method = node_data.method self.auth = node_data.authorization self.timeout = timeout - self.params = {} + self.params = [] self.headers = {} self.content = None self.files = None self.data = None self.json = None + self.max_retries = max_retries # init template self.variable_pool = variable_pool @@ -88,14 +93,48 @@ def _init_url(self): self.url = self.variable_pool.convert_template(self.node_data.url).text def _init_params(self): - params = _plain_text_to_dict(self.node_data.params) - for key in params: - params[key] = self.variable_pool.convert_template(params[key]).text - self.params = params + """ + Almost same as _init_headers(), difference: + 1. response a list tuple to support same key, like 'aa=1&aa=2' + 2. param value may have '\n', we need to splitlines then extract the variable value. + """ + result = [] + for line in self.node_data.params.splitlines(): + if not (line := line.strip()): + continue + + key, *value = line.split(":", 1) + if not (key := key.strip()): + continue + + value_str = value[0].strip() if value else "" + result.append( + (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text) + ) + + self.params = result def _init_headers(self): + """ + Convert the header string of frontend to a dictionary. + + Each line in the header string represents a key-value pair. + Keys and values are separated by ':'. + Empty values are allowed. + + Examples: + 'aa:bb\n cc:dd' -> {'aa': 'bb', 'cc': 'dd'} + 'aa:\n cc:dd\n' -> {'aa': '', 'cc': 'dd'} + 'aa\n cc : dd' -> {'aa': '', 'cc': 'dd'} + + """ headers = self.variable_pool.convert_template(self.node_data.headers).text - self.headers = _plain_text_to_dict(headers) + self.headers = { + key.strip(): (value[0].strip() if value else "") + for line in headers.splitlines() + if line.strip() + for key, *value in [line.split(":", 1)] + } def _init_body(self): body = self.node_data.body @@ -105,13 +144,19 @@ def _init_body(self): case "none": self.content = "" case "raw-text": + if len(data) != 1: + raise RequestBodyError("raw-text body type should have exactly one item") self.content = self.variable_pool.convert_template(data[0].value).text case "json": + if len(data) != 1: + raise RequestBodyError("json body type should have exactly one item") json_string = self.variable_pool.convert_template(data[0].value).text json_object = json.loads(json_string, strict=False) self.json = json_object # self.json = self._parse_object_contains_variables(json_object) case "binary": + if len(data) != 1: + raise RequestBodyError("binary body type should have exactly one item") file_selector = data[0].file file_variable = self.variable_pool.get_file(file_selector) if file_variable is None: @@ -137,9 +182,10 @@ def _init_body(self): self.variable_pool.convert_template(item.key).text: item.file for item in filter(lambda item: item.type == "file", data) } + files: dict[str, Any] = {} files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()} files = {k: v for k, v in files.items() if v is not None} - files = {k: variable.value for k, variable in files.items()} + files = {k: variable.value for k, variable in files.items() if variable is not None} files = { k: (v.filename, file_manager.download(v), v.mime_type or "application/octet-stream") for k, v in files.items() @@ -206,11 +252,15 @@ def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: "params": self.params, "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), "follow_redirects": True, + "max_retries": self.max_retries, } # request_args = {k: v for k, v in request_args.items() if v is not None} - - response = getattr(ssrf_proxy, self.method)(**request_args) - return response + try: + response = getattr(ssrf_proxy, self.method)(**request_args) + except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e: + raise HttpRequestNodeError(str(e)) + # FIXME: fix type ignore, this maybe httpx type issue + return response # type: ignore def invoke(self) -> Response: # assemble headers @@ -252,66 +302,41 @@ def to_log(self): continue raw += f"{k}: {v}\r\n" - body = "" + body_string = "" if self.files: for k, v in self.files.items(): - body += f"--{boundary}\r\n" - body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n' - body += f"{v[1]}\r\n" - body += f"--{boundary}--\r\n" + body_string += f"--{boundary}\r\n" + body_string += f'Content-Disposition: form-data; name="{k}"\r\n\r\n' + body_string += f"{v[1]}\r\n" + body_string += f"--{boundary}--\r\n" elif self.node_data.body: if self.content: if isinstance(self.content, str): - body = self.content + body_string = self.content elif isinstance(self.content, bytes): - body = self.content.decode("utf-8", errors="replace") + body_string = self.content.decode("utf-8", errors="replace") elif self.data and self.node_data.body.type == "x-www-form-urlencoded": - body = urlencode(self.data) + body_string = urlencode(self.data) elif self.data and self.node_data.body.type == "form-data": for key, value in self.data.items(): - body += f"--{boundary}\r\n" - body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' - body += f"{value}\r\n" - body += f"--{boundary}--\r\n" + body_string += f"--{boundary}\r\n" + body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' + body_string += f"{value}\r\n" + body_string += f"--{boundary}--\r\n" elif self.json: - body = json.dumps(self.json) + body_string = json.dumps(self.json) elif self.node_data.body.type == "raw-text": - body = self.node_data.body.data[0].value - if body: - raw += f"Content-Length: {len(body)}\r\n" + if len(self.node_data.body.data) != 1: + raise RequestBodyError("raw-text body type should have exactly one item") + body_string = self.node_data.body.data[0].value + if body_string: + raw += f"Content-Length: {len(body_string)}\r\n" raw += "\r\n" # Empty line between headers and body - raw += body + raw += body_string return raw -def _plain_text_to_dict(text: str, /) -> dict[str, str]: - """ - Convert a string of key-value pairs to a dictionary. - - Each line in the input string represents a key-value pair. - Keys and values are separated by ':'. - Empty values are allowed. - - Examples: - 'aa:bb\n cc:dd' -> {'aa': 'bb', 'cc': 'dd'} - 'aa:\n cc:dd\n' -> {'aa': '', 'cc': 'dd'} - 'aa\n cc : dd' -> {'aa': '', 'cc': 'dd'} - - Args: - convert_text (str): The input string to convert. - - Returns: - dict[str, str]: A dictionary of key-value pairs. - """ - return { - key.strip(): (value[0].strip() if value else "") - for line in text.splitlines() - if line.strip() - for key, *value in [line.split(":", 1)] - } - - def _generate_random_string(n: int) -> str: """ Generate a random string of lowercase ASCII letters. diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 2a92a16ede84e0..861119f26cb088 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -1,6 +1,7 @@ import logging +import mimetypes from collections.abc import Mapping, Sequence -from typing import Any +from typing import Any, Optional from configs import dify_config from core.file import File, FileTransferMethod @@ -19,7 +20,7 @@ HttpRequestNodeTimeout, Response, ) -from .exc import HttpRequestNodeError +from .exc import HttpRequestNodeError, RequestBodyError HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -35,7 +36,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): _node_type = NodeType.HTTP_REQUEST @classmethod - def get_default_config(cls, filters: dict | None = None) -> dict: + def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict: return { "type": "http-request", "config": { @@ -51,6 +52,11 @@ def get_default_config(cls, filters: dict | None = None) -> dict: "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, }, }, + "retry_config": { + "max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES, + "retry_interval": 0.5 * (2**2), + "retry_enabled": True, + }, } def _run(self) -> NodeRunResult: @@ -60,11 +66,27 @@ def _run(self) -> NodeRunResult: node_data=self.node_data, timeout=self._get_request_timeout(self.node_data), variable_pool=self.graph_runtime_state.variable_pool, + max_retries=0, ) process_data["request"] = http_executor.to_log() response = http_executor.invoke() files = self.extract_files(url=http_executor.url, response=response) + if not response.response.is_success and (self.should_continue_on_error or self.should_retry): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + outputs={ + "status_code": response.status_code, + "body": response.text if not files else "", + "headers": response.headers, + "files": files, + }, + process_data={ + "request": http_executor.to_log(), + }, + error=f"Request failed with status code {response.status_code}", + error_type="HTTPResponseCodeError", + ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={ @@ -83,6 +105,7 @@ def _run(self) -> NodeRunResult: status=WorkflowNodeExecutionStatus.FAILED, error=str(e), process_data=process_data, + error_type=type(e).__name__, ) @staticmethod @@ -113,9 +136,13 @@ def _extract_variable_selector_to_variable_mapping( data = node_data.body.data match body_type: case "binary": + if len(data) != 1: + raise RequestBodyError("invalid body data, should have only one item") selector = data[0].file selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector)) case "json" | "raw-text": + if len(data) != 1: + raise RequestBodyError("invalid body data, should have only one item") selectors += variable_template_parser.extract_selectors_from_template(data[0].key) selectors += variable_template_parser.extract_selectors_from_template(data[0].value) case "x-www-form-urlencoded": @@ -133,27 +160,31 @@ def _extract_variable_selector_to_variable_mapping( ) mapping = {} - for selector in selectors: - mapping[node_id + "." + selector.variable] = selector.value_selector + for selector_iter in selectors: + mapping[node_id + "." + selector_iter.variable] = selector_iter.value_selector return mapping def extract_files(self, url: str, response: Response) -> list[File]: """ - Extract files from response + Extract files from response by checking both Content-Type header and URL """ files = [] is_file = response.is_file content_type = response.content_type content = response.content - if is_file and content_type: + if is_file: + # Guess file extension from URL or Content-Type header + filename = url.split("?")[0].split("/")[-1] or "" + mime_type = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream" + tool_file = ToolFileManager.create_file_by_raw( user_id=self.user_id, tenant_id=self.tenant_id, conversation_id=None, file_binary=content, - mimetype=content_type, + mimetype=mime_type, ) mapping = { diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index bba6ac20d3712b..f1289558fffa82 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -9,7 +9,7 @@ from flask import Flask, current_app from configs import dify_config -from core.model_runtime.utils.encoders import jsonable_encoder +from core.variables import ArrayVariable, IntegerVariable, NoneVariable from core.workflow.entities.node_entities import ( NodeRunMetadataKey, NodeRunResult, @@ -75,12 +75,15 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """ Run the node. """ - iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) + variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) - if not iterator_list_segment: - raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found") + if not variable: + raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found") - if len(iterator_list_segment.value) == 0: + if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable): + raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") + + if isinstance(variable, NoneVariable) or len(variable.value) == 0: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -89,7 +92,7 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: ) return - iterator_list_value = iterator_list_segment.to_object() + iterator_list_value = variable.to_object() if not isinstance(iterator_list_value, list): raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") @@ -155,28 +158,31 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: iteration_node_data=self.node_data, index=0, pre_iteration_output=None, + duration=None, ) iter_run_map: dict[str, float] = {} outputs: list[Any] = [None] * len(iterator_list_value) try: if self.node_data.is_parallel: futures: list[Future] = [] - q = Queue() - thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100) + q: Queue = Queue() + thread_pool = GraphEngineThreadPool( + max_workers=self.node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT + ) for index, item in enumerate(iterator_list_value): future: Future = thread_pool.submit( self._run_single_iter_parallel, - current_app._get_current_object(), - q, - iterator_list_value, - inputs, - outputs, - start_at, - graph_engine, - iteration_graph, - index, - item, - iter_run_map, + flask_app=current_app._get_current_object(), # type: ignore + q=q, + iterator_list_value=iterator_list_value, + inputs=inputs, + outputs=outputs, + start_at=start_at, + graph_engine=graph_engine, + iteration_graph=iteration_graph, + index=index, + item=item, + iter_run_map=iter_run_map, ) future.add_done_callback(thread_pool.task_done_callback) futures.append(future) @@ -208,17 +214,22 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: else: for _ in range(len(iterator_list_value)): yield from self._run_single_iter( - iterator_list_value, - variable_pool, - inputs, - outputs, - start_at, - graph_engine, - iteration_graph, - iter_run_map, + iterator_list_value=iterator_list_value, + variable_pool=variable_pool, + inputs=inputs, + outputs=outputs, + start_at=start_at, + graph_engine=graph_engine, + iteration_graph=iteration_graph, + iter_run_map=iter_run_map, ) if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: outputs = [output for output in outputs if output is not None] + + # Flatten the list of lists + if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs): + outputs = [item for sublist in outputs for item in sublist] + yield IterationRunSucceededEvent( iteration_id=self.id, iteration_node_id=self.node_id, @@ -226,7 +237,7 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - outputs={"output": jsonable_encoder(outputs)}, + outputs={"output": outputs}, steps=len(iterator_list_value), metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, ) @@ -234,8 +245,11 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": jsonable_encoder(outputs)}, - metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map}, + outputs={"output": outputs}, + metadata={ + NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map, + NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, + }, ) ) except IterationNodeError as e: @@ -248,7 +262,7 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - outputs={"output": jsonable_encoder(outputs)}, + outputs={"output": outputs}, steps=len(iterator_list_value), metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, error=str(e), @@ -280,7 +294,7 @@ def _extract_variable_selector_to_variable_mapping( :param node_data: node data :return: """ - variable_mapping = { + variable_mapping: dict[str, Sequence[str]] = { f"{node_id}.input_selector": node_data.iterator_selector, } @@ -308,7 +322,7 @@ def _extract_variable_selector_to_variable_mapping( sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( graph_config=graph_config, config=sub_node_config ) - sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping) + sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) except NotImplementedError: sub_node_variable_mapping = {} @@ -329,8 +343,12 @@ def _extract_variable_selector_to_variable_mapping( return variable_mapping def _handle_event_metadata( - self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str - ) -> NodeRunStartedEvent | BaseNodeEvent: + self, + *, + event: BaseNodeEvent | InNodeEvent, + iter_run_index: int, + parallel_mode_run_id: str | None, + ) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent: """ add iteration metadata to event. """ @@ -343,21 +361,25 @@ def _handle_event_metadata( metadata = event.route_node_state.node_run_result.metadata if not metadata: metadata = {} - if NodeRunMetadataKey.ITERATION_ID not in metadata: - metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id - if self.node_data.is_parallel: - metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id - else: - metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index + metadata = { + **metadata, + NodeRunMetadataKey.ITERATION_ID: self.node_id, + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID + if self.node_data.is_parallel + else NodeRunMetadataKey.ITERATION_INDEX: parallel_mode_run_id + if self.node_data.is_parallel + else iter_run_index, + } event.route_node_state.node_run_result.metadata = metadata return event def _run_single_iter( self, - iterator_list_value: list[str], + *, + iterator_list_value: Sequence[str], variable_pool: VariablePool, - inputs: dict[str, list], + inputs: Mapping[str, list], outputs: list, start_at: datetime, graph_engine: "GraphEngine", @@ -373,12 +395,12 @@ def _run_single_iter( try: rst = graph_engine.run() # get current iteration index - current_index = variable_pool.get([self.node_id, "index"]).value + index_variable = variable_pool.get([self.node_id, "index"]) + if not isinstance(index_variable, IntegerVariable): + raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found") + current_index = index_variable.value iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}" next_index = int(current_index) + 1 - - if current_index is None: - raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found") for event in rst: if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: event.in_iteration_id = self.node_id @@ -391,7 +413,9 @@ def _run_single_iter( continue if isinstance(event, NodeRunSucceededEvent): - yield self._handle_event_metadata(event, current_index, parallel_mode_run_id) + yield self._handle_event_metadata( + event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id + ) elif isinstance(event, BaseGraphEvent): if isinstance(event, GraphRunFailedEvent): # iteration run failed @@ -404,7 +428,7 @@ def _run_single_iter( parallel_mode_run_id=parallel_mode_run_id, start_at=start_at, inputs=inputs, - outputs={"output": jsonable_encoder(outputs)}, + outputs={"output": outputs}, steps=len(iterator_list_value), metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, error=event.error, @@ -417,7 +441,7 @@ def _run_single_iter( iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - outputs={"output": jsonable_encoder(outputs)}, + outputs={"output": outputs}, steps=len(iterator_list_value), metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, error=event.error, @@ -429,9 +453,11 @@ def _run_single_iter( ) ) return - else: - event = cast(InNodeEvent, event) - metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id) + elif isinstance(event, InNodeEvent): + # event = cast(InNodeEvent, event) + metadata_event = self._handle_event_metadata( + event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id + ) if isinstance(event, NodeRunFailedEvent): if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: yield NodeInIterationFailedEvent( @@ -513,7 +539,7 @@ def _run_single_iter( iteration_node_data=self.node_data, index=next_index, parallel_mode_run_id=parallel_mode_run_id, - pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None, + pre_iteration_output=current_iteration_output or None, duration=duration, ) @@ -540,10 +566,11 @@ def _run_single_iter( def _run_single_iter_parallel( self, + *, flask_app: Flask, q: Queue, - iterator_list_value: list[str], - inputs: dict[str, list], + iterator_list_value: Sequence[str], + inputs: Mapping[str, list], outputs: list, start_at: datetime, graph_engine: "GraphEngine", @@ -551,7 +578,7 @@ def _run_single_iter_parallel( index: int, item: Any, iter_run_map: dict[str, float], - ) -> Generator[NodeEvent | InNodeEvent, None, None]: + ): """ run single iteration in parallel mode """ @@ -573,3 +600,4 @@ def _run_single_iter_parallel( parallel_mode_run_id=parallel_mode_run_id, ): q.put(event) + graph_engine.graph_runtime_state.total_tokens += graph_engine_copy.graph_runtime_state.total_tokens diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 8c5a9b5ecb8708..bfd93c074dd6d5 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -70,7 +70,20 @@ def _run(self) -> NodeRunResult: except KnowledgeRetrievalNodeError as e: logger.warning("Error when running knowledge retrieval node") - return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e), + error_type=type(e).__name__, + ) + # Temporary handle all exceptions from DatasetRetrieval class here. + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e), + error_type=type(e).__name__, + ) def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: available_datasets = [] @@ -134,6 +147,8 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: planning_strategy=planning_strategy, ) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: + if node_data.multiple_retrieval_config is None: + raise ValueError("multiple_retrieval_config is required") if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": if node_data.multiple_retrieval_config.reranking_model: reranking_model = { @@ -144,6 +159,8 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: reranking_model = None weights = None elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": + if node_data.multiple_retrieval_config.weights is None: + raise ValueError("weights is required") reranking_model = None vector_setting = node_data.multiple_retrieval_config.weights.vector_setting weights = { @@ -160,18 +177,20 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: reranking_model = None weights = None all_documents = dataset_retrieval.multiple_retrieve( - self.app_id, - self.tenant_id, - self.user_id, - self.user_from.value, - available_datasets, - query, - node_data.multiple_retrieval_config.top_k, - node_data.multiple_retrieval_config.score_threshold, - node_data.multiple_retrieval_config.reranking_mode, - reranking_model, - weights, - node_data.multiple_retrieval_config.reranking_enable, + app_id=self.app_id, + tenant_id=self.tenant_id, + user_id=self.user_id, + user_from=self.user_from.value, + available_datasets=available_datasets, + query=query, + top_k=node_data.multiple_retrieval_config.top_k, + score_threshold=node_data.multiple_retrieval_config.score_threshold + if node_data.multiple_retrieval_config.score_threshold is not None + else 0.0, + reranking_mode=node_data.multiple_retrieval_config.reranking_mode, + reranking_model=reranking_model, + weights=weights, + reranking_enable=node_data.multiple_retrieval_config.reranking_enable, ) dify_documents = [item for item in all_documents if item.provider == "dify"] external_documents = [item for item in all_documents if item.provider == "external"] @@ -192,7 +211,7 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: "content": item.page_content, } retrieval_resource_list.append(source) - document_score_list = {} + document_score_list: dict[str, float] = {} # deal with dify documents if dify_documents: document_score_list = {} @@ -247,7 +266,9 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: retrieval_resource_list.append(source) if retrieval_resource_list: retrieval_resource_list = sorted( - retrieval_resource_list, key=lambda x: x.get("metadata").get("score") or 0.0, reverse=True + retrieval_resource_list, + key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, + reverse=True, ) position = 1 for item in retrieval_resource_list: @@ -282,6 +303,8 @@ def _fetch_model_config( :param node_data: node data :return: """ + if node_data.single_retrieval_config is None: + raise ValueError("single_retrieval_config is required") model_name = node_data.single_retrieval_config.model.name provider_name = node_data.single_retrieval_config.model.provider diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 79066cece4f93c..432c57294ecbe9 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Sequence -from typing import Literal, Union +from typing import Any, Literal, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment @@ -17,9 +17,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): _node_type = NodeType.LIST_OPERATOR def _run(self): - inputs = {} - process_data = {} - outputs = {} + inputs: dict[str, list] = {} + process_data: dict[str, list] = {} + outputs: dict[str, Any] = {} variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) if variable is None: @@ -93,6 +93,8 @@ def _run(self): def _apply_filter( self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + filter_func: Callable[[Any], bool] + result: list[Any] = [] for condition in self.node_data.filter_by.conditions: if isinstance(variable, ArrayStringSegment): if not isinstance(condition.value, str): @@ -236,6 +238,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: + extract_func: Callable[[File], Any] if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str): extract_func = _get_file_extract_string_func(key=key) return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) @@ -249,47 +252,47 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str raise InvalidKeyError(f"Invalid key: {key}") -def _contains(value: str): +def _contains(value: str) -> Callable[[str], bool]: return lambda x: value in x -def _startswith(value: str): +def _startswith(value: str) -> Callable[[str], bool]: return lambda x: x.startswith(value) -def _endswith(value: str): +def _endswith(value: str) -> Callable[[str], bool]: return lambda x: x.endswith(value) -def _is(value: str): +def _is(value: str) -> Callable[[str], bool]: return lambda x: x is value -def _in(value: str | Sequence[str]): +def _in(value: str | Sequence[str]) -> Callable[[str], bool]: return lambda x: x in value -def _eq(value: int | float): +def _eq(value: int | float) -> Callable[[int | float], bool]: return lambda x: x == value -def _ne(value: int | float): +def _ne(value: int | float) -> Callable[[int | float], bool]: return lambda x: x != value -def _lt(value: int | float): +def _lt(value: int | float) -> Callable[[int | float], bool]: return lambda x: x < value -def _le(value: int | float): +def _le(value: int | float) -> Callable[[int | float], bool]: return lambda x: x <= value -def _gt(value: int | float): +def _gt(value: int | float) -> Callable[[int | float], bool]: return lambda x: x > value -def _ge(value: int | float): +def _ge(value: int | float) -> Callable[[int | float], bool]: return lambda x: x >= value @@ -302,6 +305,7 @@ def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]): def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]): + extract_func: Callable[[File], Any] if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}: extract_func = _get_file_extract_string_func(key=order_by) return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 19a66087f7d175..505068104c2c2d 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -50,6 +50,7 @@ def convert_none_jinja2_variables(cls, v: Any): class LLMNodeChatModelMessage(ChatModelMessage): + text: str = "" jinja2_text: Optional[str] = None diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 8ab0d8b2eb20ad..6909b30c9e82ca 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -88,8 +88,8 @@ class LLMNode(BaseNode[LLMNodeData]): _node_data_cls = LLMNodeData _node_type = NodeType.LLM - def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]: - node_inputs = None + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: + node_inputs: Optional[dict[str, Any]] = None process_data = None try: @@ -145,8 +145,8 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] query = query_variable.text prompt_messages, stop = self._fetch_prompt_messages( - user_query=query, - user_files=files, + sys_query=query, + sys_files=files, context=context, memory=memory, model_config=model_config, @@ -193,9 +193,9 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] error=str(e), inputs=node_inputs, process_data=process_data, + error_type=type(e).__name__, ) ) - return except Exception as e: yield RunCompletedEvent( run_result=NodeRunResult( @@ -205,7 +205,6 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] process_data=process_data, ) ) - return outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} @@ -301,7 +300,7 @@ def _transform_chat_messages( return messages def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]: - variables = {} + variables: dict[str, Any] = {} if not node_data.prompt_config: return variables @@ -318,7 +317,7 @@ def parse_dict(input_dict: Mapping[str, Any]) -> str: """ # check if it's a context structure if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict: - return input_dict["content"] + return str(input_dict["content"]) # else, parse the dict try: @@ -544,8 +543,8 @@ def _fetch_memory( def _fetch_prompt_messages( self, *, - user_query: str | None = None, - user_files: Sequence["File"], + sys_query: str | None = None, + sys_files: Sequence["File"], context: str | None = None, memory: TokenBufferMemory | None = None, model_config: ModelConfigWithCredentialsEntity, @@ -556,12 +555,13 @@ def _fetch_prompt_messages( variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: - prompt_messages = [] + # FIXME: fix the type error cause prompt_messages is type quick a few times + prompt_messages: list[Any] = [] if isinstance(prompt_template, list): # For chat model prompt_messages.extend( - _handle_list_messages( + self._handle_list_messages( messages=prompt_template, context=context, jinja2_variables=jinja2_variables, @@ -580,14 +580,14 @@ def _fetch_prompt_messages( prompt_messages.extend(memory_messages) # Add current query to the prompt messages - if user_query: + if sys_query: message = LLMNodeChatModelMessage( - text=user_query, + text=sys_query, role=PromptMessageRole.USER, edition_type="basic", ) prompt_messages.extend( - _handle_list_messages( + self._handle_list_messages( messages=[message], context="", jinja2_variables=[], @@ -615,24 +615,46 @@ def _fetch_prompt_messages( ) # Insert histories into the prompt prompt_content = prompt_messages[0].content - if "#histories#" in prompt_content: - prompt_content = prompt_content.replace("#histories#", memory_text) + # For issue #11247 - Check if prompt content is a string or a list + prompt_content_type = type(prompt_content) + if prompt_content_type == str: + if "#histories#" in prompt_content: + prompt_content = prompt_content.replace("#histories#", memory_text) + else: + prompt_content = memory_text + "\n" + prompt_content + prompt_messages[0].content = prompt_content + elif prompt_content_type == list: + for content_item in prompt_content: + if content_item.type == PromptMessageContentType.TEXT: + if "#histories#" in content_item.data: + content_item.data = content_item.data.replace("#histories#", memory_text) + else: + content_item.data = memory_text + "\n" + content_item.data else: - prompt_content = memory_text + "\n" + prompt_content - prompt_messages[0].content = prompt_content + raise ValueError("Invalid prompt content type") # Add current query to the prompt message - if user_query: - prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query) - prompt_messages[0].content = prompt_content + if sys_query: + if prompt_content_type == str: + prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query) + prompt_messages[0].content = prompt_content + elif prompt_content_type == list: + for content_item in prompt_content: + if content_item.type == PromptMessageContentType.TEXT: + content_item.data = sys_query + "\n" + content_item.data + else: + raise ValueError("Invalid prompt content type") else: raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) - if vision_enabled and user_files: + # The sys_files will be deprecated later + if vision_enabled and sys_files: file_prompts = [] - for file in user_files: + for file in sys_files: file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) file_prompts.append(file_prompt) + # If last prompt is a user prompt, add files into its contents, + # otherwise append a new user prompt if ( len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage) @@ -642,7 +664,7 @@ def _fetch_prompt_messages( else: prompt_messages.append(UserPromptMessage(content=file_prompts)) - # Filter prompt messages + # Remove empty messages and filter unsupported content filtered_prompt_messages = [] for prompt_message in prompt_messages: if isinstance(prompt_message.content, list): @@ -760,7 +782,7 @@ def _extract_variable_selector_to_variable_mapping( else: raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}") - variable_mapping = {} + variable_mapping: dict[str, Any] = {} for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector @@ -826,6 +848,68 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: }, } + def _handle_list_messages( + self, + *, + messages: Sequence[LLMNodeChatModelMessage], + context: Optional[str], + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + vision_detail_config: ImagePromptMessageContent.DETAIL, + ) -> Sequence[PromptMessage]: + prompt_messages: list[PromptMessage] = [] + for message in messages: + if message.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=message.jinja2_text or "", + jinjia2_variables=jinja2_variables, + variable_pool=variable_pool, + ) + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=result_text)], role=message.role + ) + prompt_messages.append(prompt_message) + else: + # Get segment group from basic message + if context: + template = message.text.replace("{#context#}", context) + else: + template = message.text + segment_group = variable_pool.convert_template(template) + + # Process segments for images + file_contents = [] + for segment in segment_group.value: + if isinstance(segment, ArrayFileSegment): + for file in segment.value: + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + elif isinstance(segment, FileSegment): + file = segment.value + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + + # Create message with text from all segments + plain_text = segment_group.text + if plain_text: + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=plain_text)], role=message.role + ) + prompt_messages.append(prompt_message) + + if file_contents: + # Create message with image contents + prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) + prompt_messages.append(prompt_message) + + return prompt_messages + def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole): match role: @@ -860,68 +944,6 @@ def _render_jinja2_message( return result_text -def _handle_list_messages( - *, - messages: Sequence[LLMNodeChatModelMessage], - context: Optional[str], - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - vision_detail_config: ImagePromptMessageContent.DETAIL, -) -> Sequence[PromptMessage]: - prompt_messages = [] - for message in messages: - if message.edition_type == "jinja2": - result_text = _render_jinja2_message( - template=message.jinja2_text or "", - jinjia2_variables=jinja2_variables, - variable_pool=variable_pool, - ) - prompt_message = _combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], role=message.role - ) - prompt_messages.append(prompt_message) - else: - # Get segment group from basic message - if context: - template = message.text.replace("{#context#}", context) - else: - template = message.text - segment_group = variable_pool.convert_template(template) - - # Process segments for images - file_contents = [] - for segment in segment_group.value: - if isinstance(segment, ArrayFileSegment): - for file in segment.value: - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_content = file_manager.to_prompt_message_content( - file, image_detail_config=vision_detail_config - ) - file_contents.append(file_content) - if isinstance(segment, FileSegment): - file = segment.value - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_content = file_manager.to_prompt_message_content( - file, image_detail_config=vision_detail_config - ) - file_contents.append(file_content) - - # Create message with text from all segments - plain_text = segment_group.text - if plain_text: - prompt_message = _combine_message_content_with_role( - contents=[TextPromptMessageContent(data=plain_text)], role=message.role - ) - prompt_messages.append(prompt_message) - - if file_contents: - # Create message with image contents - prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) - prompt_messages.append(prompt_message) - - return prompt_messages - - def _calculate_rest_token( *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity ) -> int: @@ -958,7 +980,7 @@ def _handle_memory_chat_mode( memory_config: MemoryConfig | None, model_config: ModelConfigWithCredentialsEntity, ) -> Sequence[PromptMessage]: - memory_messages = [] + memory_messages: Sequence[PromptMessage] = [] # Get messages from memory for chat model if memory and memory_config: rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 6fdff966026b63..a366c287c2ac56 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -14,8 +14,8 @@ class LoopNode(BaseNode[LoopNodeData]): _node_data_cls = LoopNodeData _node_type = NodeType.LOOP - def _run(self) -> LoopState: - return super()._run() + def _run(self) -> LoopState: # type: ignore + return super()._run() # type: ignore @classmethod def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: @@ -28,7 +28,7 @@ def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: # TODO waiting for implementation return [ - Condition( + Condition( # type: ignore variable_selector=[node_id, "index"], comparison_operator="≤", value_type="value_selector", diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index a001b44dc7dfee..369eb13b04e8c4 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -25,7 +25,7 @@ def validate_name(cls, value) -> str: raise ValueError("Parameter name is required") if value in {"__reason", "__is_success"}: raise ValueError("Invalid parameter name, __reason and __is_success are reserved") - return value + return str(value) class ParameterExtractorNodeData(BaseNodeData): @@ -52,7 +52,7 @@ def get_parameter_json_schema(self) -> dict: :return: parameter json schema """ - parameters = {"type": "object", "properties": {}, "required": []} + parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []} for parameter in self.parameters: parameter_schema: dict[str, Any] = {"description": parameter.description} diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 5b960ea6151954..9c88047f2c8e57 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -63,7 +63,8 @@ class ParameterExtractorNode(LLMNode): Parameter Extractor Node. """ - _node_data_cls = ParameterExtractorNodeData + # FIXME: figure out why here is different from super class + _node_data_cls = ParameterExtractorNodeData # type: ignore _node_type = NodeType.PARAMETER_EXTRACTOR _model_instance: Optional[ModelInstance] = None @@ -179,6 +180,15 @@ def _run(self): error=str(e), metadata={}, ) + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=inputs, + process_data=process_data, + outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)}, + error=str(e), + metadata={}, + ) error = None @@ -244,6 +254,9 @@ def _invoke( # deduct quota self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) + if text is None: + text = "" + return text, usage, tool_call def _generate_function_call_prompt( @@ -596,9 +609,10 @@ def extract_json(text): json_str = extract_json(result[idx:]) if json_str: try: - return json.loads(json_str) + return cast(dict, json.loads(json_str)) except Exception: pass + return None def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]: """ @@ -607,13 +621,13 @@ def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCal if not tool_call or not tool_call.function.arguments: return None - return json.loads(tool_call.function.arguments) + return cast(dict, json.loads(tool_call.function.arguments)) def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict: """ Generate default result. """ - result = {} + result: dict[str, Any] = {} for parameter in data.parameters: if parameter.type == "number": result[parameter.name] = 0 @@ -763,7 +777,7 @@ def _extract_variable_selector_to_variable_mapping( *, graph_config: Mapping[str, Any], node_id: str, - node_data: ParameterExtractorNodeData, + node_data: ParameterExtractorNodeData, # type: ignore ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -772,6 +786,7 @@ def _extract_variable_selector_to_variable_mapping( :param node_data: node data :return: """ + # FIXME: fix the type error later variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} if node_data.instruction: diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py index e603add1704544..6c3155ac9a54e3 100644 --- a/api/core/workflow/nodes/parameter_extractor/prompts.py +++ b/api/core/workflow/nodes/parameter_extractor/prompts.py @@ -1,3 +1,5 @@ +from typing import Any + FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters" FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy. @@ -35,7 +37,7 @@ """ # noqa: E501 -FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [ +FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [ { "user": { "query": "What is the weather today in SF?", diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index e855ab2d2b0659..0ec44eefacf52f 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,10 +1,8 @@ import json -import logging from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.llm_generator.output_parser.errors import OutputParserError from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole @@ -36,12 +34,9 @@ QUESTION_CLASSIFIER_USER_PROMPT_3, ) -if TYPE_CHECKING: - from core.file import File - class QuestionClassifierNode(LLMNode): - _node_data_cls = QuestionClassifierNodeData + _node_data_cls = QuestionClassifierNodeData # type: ignore _node_type = NodeType.QUESTION_CLASSIFIER def _run(self): @@ -63,7 +58,7 @@ def _run(self): node_data.instruction = node_data.instruction or "" node_data.instruction = variable_pool.convert_template(node_data.instruction).text - files: Sequence[File] = ( + files = ( self._fetch_files( selector=node_data.vision.configs.variable_selector, ) @@ -86,37 +81,38 @@ def _run(self): ) prompt_messages, stop = self._fetch_prompt_messages( prompt_template=prompt_template, - user_query=query, + sys_query=query, memory=memory, model_config=model_config, - user_files=files, + sys_files=files, vision_enabled=node_data.vision.enabled, vision_detail=node_data.vision.configs.detail, variable_pool=variable_pool, jinja2_variables=[], ) - # handle invoke result - generator = self._invoke_llm( - node_data_model=node_data.model, - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop, - ) - result_text = "" usage = LLMUsage.empty_usage() finish_reason = None - for event in generator: - if isinstance(event, ModelInvokeCompletedEvent): - result_text = event.text - usage = event.usage - finish_reason = event.finish_reason - break - category_name = node_data.classes[0].name - category_id = node_data.classes[0].id try: + # handle invoke result + generator = self._invoke_llm( + node_data_model=node_data.model, + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=stop, + ) + + for event in generator: + if isinstance(event, ModelInvokeCompletedEvent): + result_text = event.text + usage = event.usage + finish_reason = event.finish_reason + break + + category_name = node_data.classes[0].name + category_id = node_data.classes[0].id result_text_json = parse_and_check_json_markdown(result_text, []) # result_text_json = json.loads(result_text.strip('```JSON\n')) if "category_name" in result_text_json and "category_id" in result_text_json: @@ -127,10 +123,6 @@ def _run(self): if category_id_result in category_ids: category_name = classes_map[category_id_result] category_id = category_id_result - - except OutputParserError: - logging.exception(f"Failed to parse result text: {result_text}") - try: process_data = { "model_mode": model_config.mode, "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( @@ -139,7 +131,7 @@ def _run(self): "usage": jsonable_encoder(usage), "finish_reason": finish_reason, } - outputs = {"class_name": category_name} + outputs = {"class_name": category_name, "class_id": category_id} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -154,7 +146,6 @@ def _run(self): }, llm_usage=usage, ) - except ValueError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -174,7 +165,7 @@ def _extract_variable_selector_to_variable_mapping( *, graph_config: Mapping[str, Any], node_id: str, - node_data: QuestionClassifierNodeData, + node_data: Any, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -183,6 +174,7 @@ def _extract_variable_selector_to_variable_mapping( :param node_data: node data :return: """ + node_data = cast(QuestionClassifierNodeData, node_data) variable_mapping = {"query": node_data.query_variable_selector} variable_selectors = [] if node_data.instruction: diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 951e5330a324e8..01d07e494944b4 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,5 +1,6 @@ from collections.abc import Mapping, Sequence from typing import Any +from uuid import UUID from sqlalchemy import select from sqlalchemy.orm import Session @@ -8,7 +9,6 @@ from core.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.tool_engine import ToolEngine -from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_pool import VariablePool @@ -45,6 +45,8 @@ def _run(self) -> NodeRunResult: # get tool runtime try: + from core.tools.tool_manager import ToolManager + tool_runtime = ToolManager.get_workflow_tool_runtime( self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from ) @@ -56,6 +58,7 @@ def _run(self) -> NodeRunResult: NodeRunMetadataKey.TOOL_INFO: tool_info, }, error=f"Failed to get tool runtime: {str(e)}", + error_type=type(e).__name__, ) # get parameters @@ -89,6 +92,17 @@ def _run(self) -> NodeRunResult: NodeRunMetadataKey.TOOL_INFO: tool_info, }, error=f"Failed to invoke tool: {str(e)}", + error_type=type(e).__name__, + ) + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + metadata={ + NodeRunMetadataKey.TOOL_INFO: tool_info, + }, + error=f"Failed to invoke tool: {str(e)}", + error_type="UnknownError", ) # convert tool messages @@ -129,7 +143,7 @@ def _generate_parameters( """ tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} - result = {} + result: dict[str, Any] = {} for parameter_name in node_data.tool_parameters: parameter = tool_parameters_dictionary.get(parameter_name) if not parameter: @@ -219,6 +233,10 @@ def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) url = str(response.message) transfer_method = FileTransferMethod.TOOL_FILE tool_file_id = url.split("/")[-1].split(".")[0] + try: + UUID(tool_file_id) + except ValueError: + raise ToolFileError(f"cannot extract tool file id from url {url}") with Session(db.engine) as session: stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) @@ -247,9 +265,9 @@ def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> """ return "\n".join( [ - f"{message.message}" + str(message.message) if message.type == ToolInvokeMessage.MessageType.TEXT - else f"Link: {message.message}" + else f"Link: {str(message.message)}" for message in tool_response if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK} ] diff --git a/api/core/workflow/nodes/variable_assigner/common/exc.py b/api/core/workflow/nodes/variable_assigner/common/exc.py index a1178fb0203593..f8dbedc2901c9f 100644 --- a/api/core/workflow/nodes/variable_assigner/common/exc.py +++ b/api/core/workflow/nodes/variable_assigner/common/exc.py @@ -1,4 +1,4 @@ -class VariableOperatorNodeError(Exception): +class VariableOperatorNodeError(ValueError): """Base error type, don't use directly.""" pass diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 8eb4bd5c2da573..9acc76f326eec9 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -36,6 +36,8 @@ def _run(self) -> NodeRunResult: case WriteMode.CLEAR: income_value = get_zero_value(original_variable.value_type) + if income_value is None: + raise VariableOperatorNodeError("income value not found") updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) case _: diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index d73c7442029225..0c4aae827c0a0f 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,5 +1,5 @@ import json -from typing import Any +from typing import Any, cast from core.variables import SegmentType, Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID @@ -29,7 +29,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): def _run(self) -> NodeRunResult: inputs = self.node_data.model_dump() - process_data = {} + process_data: dict[str, Any] = {} # NOTE: This node has no outputs updated_variables: list[Variable] = [] @@ -119,7 +119,7 @@ def _run(self) -> NodeRunResult: else: conversation_id = conversation_id.value common_helpers.update_conversation_variable( - conversation_id=conversation_id, + conversation_id=cast(str, conversation_id), variable=variable, ) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 811e40c11e5407..b14c6fafbd9fdc 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -129,11 +129,11 @@ def single_step_run( :return: """ # fetch node info from workflow graph - graph = workflow.graph_dict - if not graph: + workflow_graph = workflow.graph_dict + if not workflow_graph: raise ValueError("workflow graph not found") - nodes = graph.get("nodes") + nodes = workflow_graph.get("nodes") if not nodes: raise ValueError("nodes not found in workflow graph") @@ -196,7 +196,8 @@ def single_step_run( @staticmethod def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: - return WorkflowEntry._handle_special_values(value) + result = WorkflowEntry._handle_special_values(value) + return result if isinstance(result, Mapping) or result is None else dict(result) @staticmethod def _handle_special_values(value: Any) -> Any: @@ -208,10 +209,10 @@ def _handle_special_values(value: Any) -> Any: res[k] = WorkflowEntry._handle_special_values(v) return res if isinstance(value, list): - res = [] + res_list = [] for item in value: - res.append(WorkflowEntry._handle_special_values(item)) - return res + res_list.append(WorkflowEntry._handle_special_values(item)) + return res_list if isinstance(value, File): return value.to_dict() return value diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 2b6a8dd3d0570d..881263171fa145 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -34,7 +34,6 @@ else --workers ${SERVER_WORKER_AMOUNT:-1} \ --worker-class ${SERVER_WORKER_CLASS:-gevent} \ --timeout ${GUNICORN_TIMEOUT:-200} \ - --preload \ app:app fi fi diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 24fa013697994c..8a677f6b6fc017 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -14,7 +14,7 @@ @document_index_created.connect def handle(sender, **kwargs): dataset_id = sender - document_ids = kwargs.get("document_ids") + document_ids = kwargs.get("document_ids", []) documents = [] start_at = time.perf_counter() for document_id in document_ids: diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index 1515661b2d45b8..5e7caf8cbed71e 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -8,18 +8,19 @@ def handle(sender, **kwargs): """Create site record when an app is created.""" app = sender account = kwargs.get("account") - site = Site( - app_id=app.id, - title=app.name, - icon_type=app.icon_type, - icon=app.icon, - icon_background=app.icon_background, - default_language=account.interface_language, - customize_token_strategy="not_allow", - code=Site.generate_code(16), - created_by=app.created_by, - updated_by=app.updated_by, - ) + if account is not None: + site = Site( + app_id=app.id, + title=app.name, + icon_type=app.icon_type, + icon=app.icon, + icon_background=app.icon_background, + default_language=account.interface_language, + customize_token_strategy="not_allow", + code=Site.generate_code(16), + created_by=app.created_by, + updated_by=app.updated_by, + ) - db.session.add(site) - db.session.commit() + db.session.add(site) + db.session.commit() diff --git a/api/events/event_handlers/deduct_quota_when_message_created.py b/api/events/event_handlers/deduct_quota_when_message_created.py index 843a2320968ced..1ed37efba0b3be 100644 --- a/api/events/event_handlers/deduct_quota_when_message_created.py +++ b/api/events/event_handlers/deduct_quota_when_message_created.py @@ -44,7 +44,7 @@ def handle(sender, **kwargs): else: used_quota = 1 - if used_quota is not None: + if used_quota is not None and system_configuration.current_quota_type is not None: db.session.query(Provider).filter( Provider.tenant_id == application_generate_entity.app_config.tenant_id, Provider.provider_name == model_config.provider, diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 9c5955c8c5a1a5..f89fae24a56378 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -8,7 +8,10 @@ @app_draft_workflow_was_synced.connect def handle(sender, **kwargs): app = sender - for node_data in kwargs.get("synced_draft_workflow").graph_dict.get("nodes", []): + synced_draft_workflow = kwargs.get("synced_draft_workflow") + if synced_draft_workflow is None: + return + for node_data in synced_draft_workflow.graph_dict.get("nodes", []): if node_data.get("data", {}).get("type") == NodeType.TOOL.value: try: tool_entity = ToolEntity(**node_data["data"]) diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index de7c0f4dfeb74f..408ed31096d2a0 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -8,16 +8,18 @@ def handle(sender, **kwargs): app = sender app_model_config = kwargs.get("app_model_config") + if app_model_config is None: + return dataset_ids = get_dataset_ids_from_model_config(app_model_config) app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() - removed_dataset_ids = [] + removed_dataset_ids: set[int] = set() if not app_dataset_joins: added_dataset_ids = dataset_ids else: - old_dataset_ids = set() + old_dataset_ids: set[int] = set() old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) added_dataset_ids = dataset_ids - old_dataset_ids @@ -37,8 +39,8 @@ def handle(sender, **kwargs): db.session.commit() -def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set: - dataset_ids = set() +def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[int]: + dataset_ids: set[int] = set() if not app_model_config: return dataset_ids diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 453395e8d7dc1c..7a31c82f6adbc2 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -17,11 +17,11 @@ def handle(sender, **kwargs): dataset_ids = get_dataset_ids_from_workflow(published_workflow) app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() - removed_dataset_ids = [] + removed_dataset_ids: set[int] = set() if not app_dataset_joins: added_dataset_ids = dataset_ids else: - old_dataset_ids = set() + old_dataset_ids: set[int] = set() old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) added_dataset_ids = dataset_ids - old_dataset_ids @@ -41,8 +41,8 @@ def handle(sender, **kwargs): db.session.commit() -def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set: - dataset_ids = set() +def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[int]: + dataset_ids: set[int] = set() graph = published_workflow.graph_dict if not graph: return dataset_ids @@ -60,7 +60,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set: for node in knowledge_retrieval_nodes: try: node_data = KnowledgeRetrievalNodeData(**node.get("data", {})) - dataset_ids.update(node_data.dataset_ids) + dataset_ids.update(int(dataset_id) for dataset_id in node_data.dataset_ids) except Exception as e: continue diff --git a/api/tests/unit_tests/oss/local/__init__.py b/api/extensions/__init__.py similarity index 100% rename from api/tests/unit_tests/oss/local/__init__.py rename to api/extensions/__init__.py diff --git a/api/extensions/ext_app_metrics.py b/api/extensions/ext_app_metrics.py index de1cdfeb984e86..b7d412d68deda1 100644 --- a/api/extensions/ext_app_metrics.py +++ b/api/extensions/ext_app_metrics.py @@ -54,12 +54,14 @@ def pool_stat(): from extensions.ext_database import db engine = db.engine + # TODO: Fix the type error + # FIXME maybe its sqlalchemy issue return { "pid": os.getpid(), - "pool_size": engine.pool.size(), - "checked_in_connections": engine.pool.checkedin(), - "checked_out_connections": engine.pool.checkedout(), - "overflow_connections": engine.pool.overflow(), - "connection_timeout": engine.pool.timeout(), - "recycle_time": db.engine.pool._recycle, + "pool_size": engine.pool.size(), # type: ignore + "checked_in_connections": engine.pool.checkedin(), # type: ignore + "checked_out_connections": engine.pool.checkedout(), # type: ignore + "overflow_connections": engine.pool.overflow(), # type: ignore + "connection_timeout": engine.pool.timeout(), # type: ignore + "recycle_time": db.engine.pool._recycle, # type: ignore } diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 9dbc4b93d46266..30f216ff95612b 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,8 +1,8 @@ from datetime import timedelta import pytz -from celery import Celery, Task -from celery.schedules import crontab +from celery import Celery, Task # type: ignore +from celery.schedules import crontab # type: ignore from configs import dify_config from dify_app import DifyApp @@ -47,7 +47,7 @@ def __call__(self, *args: object, **kwargs: object) -> object: worker_log_format=dify_config.LOG_FORMAT, worker_task_log_format=dify_config.LOG_FORMAT, worker_hijack_root_logger=False, - timezone=pytz.timezone(dify_config.LOG_TZ), + timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"), ) if dify_config.BROKER_USE_SSL: diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py index 9c3a663af417ae..26ff6427bef1cc 100644 --- a/api/extensions/ext_compress.py +++ b/api/extensions/ext_compress.py @@ -7,7 +7,7 @@ def is_enabled() -> bool: def init_app(app: DifyApp): - from flask_compress import Compress + from flask_compress import Compress # type: ignore compress = Compress() compress.init_app(app) diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py index e293afa1115e8b..93842a303683bb 100644 --- a/api/extensions/ext_database.py +++ b/api/extensions/ext_database.py @@ -1,18 +1,5 @@ -from flask_sqlalchemy import SQLAlchemy -from sqlalchemy import MetaData - from dify_app import DifyApp - -POSTGRES_INDEXES_NAMING_CONVENTION = { - "ix": "%(column_0_label)s_idx", - "uq": "%(table_name)s_%(column_0_name)s_key", - "ck": "%(table_name)s_%(constraint_name)s_check", - "fk": "%(table_name)s_%(column_0_name)s_fkey", - "pk": "%(table_name)s_pkey", -} - -metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) -db = SQLAlchemy(metadata=metadata) +from models import db def init_app(app: DifyApp): diff --git a/api/extensions/ext_import_modules.py b/api/extensions/ext_import_modules.py index eefdfd38236662..9566f430b647fe 100644 --- a/api/extensions/ext_import_modules.py +++ b/api/extensions/ext_import_modules.py @@ -3,4 +3,3 @@ def init_app(app: DifyApp): from events import event_handlers # noqa: F401 - from models import account, dataset, model, source, task, tool, tools, web # noqa: F401 diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index 738d5c7bd2b2f5..e1c459e8c17fd0 100644 --- a/api/extensions/ext_logging.py +++ b/api/extensions/ext_logging.py @@ -1,14 +1,17 @@ import logging import os import sys +import uuid from logging.handlers import RotatingFileHandler +import flask + from configs import dify_config from dify_app import DifyApp def init_app(app: DifyApp): - log_handlers = [] + log_handlers: list[logging.Handler] = [] log_file = dify_config.LOG_FILE if log_file: log_dir = os.path.dirname(log_file) @@ -22,11 +25,14 @@ def init_app(app: DifyApp): ) # Always add StreamHandler to log to console - log_handlers.append(logging.StreamHandler(sys.stdout)) + sh = logging.StreamHandler(sys.stdout) + sh.addFilter(RequestIdFilter()) + log_formatter = logging.Formatter(fmt=dify_config.LOG_FORMAT) + sh.setFormatter(log_formatter) + log_handlers.append(sh) logging.basicConfig( level=dify_config.LOG_LEVEL, - format=dify_config.LOG_FORMAT, datefmt=dify_config.LOG_DATEFORMAT, handlers=log_handlers, force=True, @@ -43,4 +49,24 @@ def time_converter(seconds): return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() for handler in logging.root.handlers: - handler.formatter.converter = time_converter + if handler.formatter: + handler.formatter.converter = time_converter + + +def get_request_id(): + if getattr(flask.g, "request_id", None): + return flask.g.request_id + + new_uuid = uuid.uuid4().hex[:10] + flask.g.request_id = new_uuid + + return new_uuid + + +class RequestIdFilter(logging.Filter): + # This is a logging filter that makes the request ID available for use in + # the logging format. Note that we're checking if we're in a request + # context, as we may want to log things before Flask is fully loaded. + def filter(self, record): + record.req_id = get_request_id() if flask.has_request_context() else "" + return True diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index b2955307144d67..10fb89eb7370ee 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -1,6 +1,6 @@ import json -import flask_login +import flask_login # type: ignore from flask import Response, request from flask_login import user_loaded_from_request, user_logged_in from werkzeug.exceptions import Unauthorized diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index 468aedd47ea90b..9240ebe7fcba73 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -26,7 +26,7 @@ def init_app(self, app: Flask): match mail_type: case "resend": - import resend + import resend # type: ignore api_key = dify_config.RESEND_API_KEY if not api_key: @@ -48,9 +48,9 @@ def init_app(self, app: Flask): self._client = SMTPClient( server=dify_config.SMTP_SERVER, port=dify_config.SMTP_PORT, - username=dify_config.SMTP_USERNAME, - password=dify_config.SMTP_PASSWORD, - _from=dify_config.MAIL_DEFAULT_SEND_FROM, + username=dify_config.SMTP_USERNAME or "", + password=dify_config.SMTP_PASSWORD or "", + _from=dify_config.MAIL_DEFAULT_SEND_FROM or "", use_tls=dify_config.SMTP_USE_TLS, opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS, ) diff --git a/api/extensions/ext_migrate.py b/api/extensions/ext_migrate.py index 6d8f35c30d9c65..5f862181fa8540 100644 --- a/api/extensions/ext_migrate.py +++ b/api/extensions/ext_migrate.py @@ -2,7 +2,7 @@ def init_app(app: DifyApp): - import flask_migrate + import flask_migrate # type: ignore from extensions.ext_database import db diff --git a/api/extensions/ext_proxy_fix.py b/api/extensions/ext_proxy_fix.py index 3b895ac95b5029..514e0658257293 100644 --- a/api/extensions/ext_proxy_fix.py +++ b/api/extensions/ext_proxy_fix.py @@ -6,4 +6,4 @@ def init_app(app: DifyApp): if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED: from werkzeug.middleware.proxy_fix import ProxyFix - app.wsgi_app = ProxyFix(app.wsgi_app) + app.wsgi_app = ProxyFix(app.wsgi_app) # type: ignore diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 8016356a3e2961..3a74aace6a34cf 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -6,7 +6,7 @@ def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import openai import sentry_sdk - from langfuse import parse_error + from langfuse import parse_error # type: ignore from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException @@ -27,6 +27,7 @@ def before_send(event, hint): ignore_errors=[ HTTPException, ValueError, + FileNotFoundError, openai.APIStatusError, InvokeRateLimitError, parse_error.defaultErrorResponse, diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index 6c30b7a257045a..588bdb2d2717e0 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -1,6 +1,6 @@ import logging -from collections.abc import Generator -from typing import Union +from collections.abc import Callable, Generator +from typing import Literal, Union, overload from flask import Flask @@ -9,23 +9,30 @@ from extensions.storage.base_storage import BaseStorage from extensions.storage.storage_type import StorageType +logger = logging.getLogger(__name__) -class Storage: - def __init__(self): - self.storage_runner = None +class Storage: def init_app(self, app: Flask): storage_factory = self.get_storage_factory(dify_config.STORAGE_TYPE) with app.app_context(): self.storage_runner = storage_factory() @staticmethod - def get_storage_factory(storage_type: str) -> type[BaseStorage]: + def get_storage_factory(storage_type: str) -> Callable[[], BaseStorage]: match storage_type: case StorageType.S3: from extensions.storage.aws_s3_storage import AwsS3Storage return AwsS3Storage + case StorageType.OPENDAL: + from extensions.storage.opendal_storage import OpenDALStorage + + return lambda: OpenDALStorage(dify_config.OPENDAL_SCHEME) + case StorageType.LOCAL: + from extensions.storage.opendal_storage import OpenDALStorage + + return lambda: OpenDALStorage(scheme="fs", root=dify_config.STORAGE_LOCAL_PATH) case StorageType.AZURE_BLOB: from extensions.storage.azure_blob_storage import AzureBlobStorage @@ -62,18 +69,22 @@ def get_storage_factory(storage_type: str) -> type[BaseStorage]: from extensions.storage.supabase_storage import SupabaseStorage return SupabaseStorage - case StorageType.LOCAL | _: - from extensions.storage.local_fs_storage import LocalFsStorage - - return LocalFsStorage + case _: + raise ValueError(f"unsupported storage type {storage_type}") def save(self, filename, data): try: self.storage_runner.save(filename, data) except Exception as e: - logging.exception(f"Failed to save file {filename}") + logger.exception(f"Failed to save file {filename}") raise e + @overload + def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ... + + @overload + def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ... + def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]: try: if stream: @@ -81,42 +92,42 @@ def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Genera else: return self.load_once(filename) except Exception as e: - logging.exception(f"Failed to load file {filename}") + logger.exception(f"Failed to load file {filename}") raise e def load_once(self, filename: str) -> bytes: try: return self.storage_runner.load_once(filename) except Exception as e: - logging.exception(f"Failed to load_once file {filename}") + logger.exception(f"Failed to load_once file {filename}") raise e def load_stream(self, filename: str) -> Generator: try: return self.storage_runner.load_stream(filename) except Exception as e: - logging.exception(f"Failed to load_stream file {filename}") + logger.exception(f"Failed to load_stream file {filename}") raise e def download(self, filename, target_filepath): try: self.storage_runner.download(filename, target_filepath) except Exception as e: - logging.exception(f"Failed to download file {filename}") + logger.exception(f"Failed to download file {filename}") raise e def exists(self, filename): try: return self.storage_runner.exists(filename) except Exception as e: - logging.exception(f"Failed to check file exists {filename}") + logger.exception(f"Failed to check file exists {filename}") raise e def delete(self, filename): try: return self.storage_runner.delete(filename) except Exception as e: - logging.exception(f"Failed to delete file {filename}") + logger.exception(f"Failed to delete file {filename}") raise e diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py index 58c917dbd386bc..00bf5d4f93ae3b 100644 --- a/api/extensions/storage/aliyun_oss_storage.py +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -1,7 +1,7 @@ import posixpath from collections.abc import Generator -import oss2 as aliyun_s3 +import oss2 as aliyun_s3 # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -33,7 +33,7 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: obj = self.client.get_object(self.__wrapper_folder_filename(filename)) - data = obj.read() + data: bytes = obj.read() return data def load_stream(self, filename: str) -> Generator: @@ -41,14 +41,14 @@ def load_stream(self, filename: str) -> Generator: while chunk := obj.read(4096): yield chunk - def download(self, filename, target_filepath): + def download(self, filename: str, target_filepath): self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath) - def exists(self, filename): + def exists(self, filename: str): return self.client.object_exists(self.__wrapper_folder_filename(filename)) - def delete(self, filename): + def delete(self, filename: str): self.client.delete_object(self.__wrapper_folder_filename(filename)) - def __wrapper_folder_filename(self, filename) -> str: + def __wrapper_folder_filename(self, filename: str) -> str: return posixpath.join(self.folder, filename) if self.folder else filename diff --git a/api/extensions/storage/aws_s3_storage.py b/api/extensions/storage/aws_s3_storage.py index ab2d0fba3b19f3..7b6b2eedd62bf2 100644 --- a/api/extensions/storage/aws_s3_storage.py +++ b/api/extensions/storage/aws_s3_storage.py @@ -1,9 +1,9 @@ import logging from collections.abc import Generator -import boto3 -from botocore.client import Config -from botocore.exceptions import ClientError +import boto3 # type: ignore +from botocore.client import Config # type: ignore +from botocore.exceptions import ClientError # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -53,7 +53,7 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: try: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() + data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") @@ -67,7 +67,9 @@ def load_stream(self, filename: str) -> Generator: yield from response["Body"].iter_chunks() except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": - raise FileNotFoundError("File not found") + raise FileNotFoundError("file not found") + elif "reached max retries" in str(ex): + raise ValueError("please do not request the same file too frequently") else: raise diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index b26caa8671b6df..2f8532f4f8f653 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -27,7 +27,7 @@ def load_once(self, filename: str) -> bytes: client = self._sync_client() blob = client.get_container_client(container=self.bucket_name) blob = blob.get_blob_client(blob=filename) - data = blob.download_blob().readall() + data: bytes = blob.download_blob().readall() return data def load_stream(self, filename: str) -> Generator: @@ -63,11 +63,11 @@ def _sync_client(self): sas_token = cache_result.decode("utf-8") else: sas_token = generate_account_sas( - account_name=self.account_name, - account_key=self.account_key, + account_name=self.account_name or "", + account_key=self.account_key or "", resource_types=ResourceTypes(service=True, container=True, object=True), permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1), ) redis_client.set(cache_key, sas_token, ex=3000) - return BlobServiceClient(account_url=self.account_url, credential=sas_token) + return BlobServiceClient(account_url=self.account_url or "", credential=sas_token) diff --git a/api/extensions/storage/baidu_obs_storage.py b/api/extensions/storage/baidu_obs_storage.py index e0d2140e91272c..b94efa08be7613 100644 --- a/api/extensions/storage/baidu_obs_storage.py +++ b/api/extensions/storage/baidu_obs_storage.py @@ -2,9 +2,9 @@ import hashlib from collections.abc import Generator -from baidubce.auth.bce_credentials import BceCredentials -from baidubce.bce_client_configuration import BceClientConfiguration -from baidubce.services.bos.bos_client import BosClient +from baidubce.auth.bce_credentials import BceCredentials # type: ignore +from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore +from baidubce.services.bos.bos_client import BosClient # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -36,7 +36,8 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: response = self.client.get_object(bucket_name=self.bucket_name, key=filename) - return response.data.read() + data: bytes = response.data.read() + return data def load_stream(self, filename: str) -> Generator: response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data diff --git a/api/extensions/storage/base_storage.py b/api/extensions/storage/base_storage.py index 50abab8537ffa4..0dedd7ff8cc325 100644 --- a/api/extensions/storage/base_storage.py +++ b/api/extensions/storage/base_storage.py @@ -7,9 +7,6 @@ class BaseStorage(ABC): """Interface for file storage.""" - def __init__(self): # noqa: B027 - pass - @abstractmethod def save(self, filename, data): raise NotImplementedError diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py index 26b662d2f04daf..705639f42e716f 100644 --- a/api/extensions/storage/google_cloud_storage.py +++ b/api/extensions/storage/google_cloud_storage.py @@ -3,7 +3,7 @@ import json from collections.abc import Generator -from google.cloud import storage as google_cloud_storage +from google.cloud import storage as google_cloud_storage # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -35,7 +35,7 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) - data = blob.download_as_bytes() + data: bytes = blob.download_as_bytes() return data def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py index 20be70ef83dd7a..07f1d199701be4 100644 --- a/api/extensions/storage/huawei_obs_storage.py +++ b/api/extensions/storage/huawei_obs_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from obs import ObsClient +from obs import ObsClient # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -23,7 +23,7 @@ def save(self, filename, data): self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data) def load_once(self, filename: str) -> bytes: - data = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read() + data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read() return data def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/local_fs_storage.py b/api/extensions/storage/local_fs_storage.py deleted file mode 100644 index 5a495ca4d41042..00000000000000 --- a/api/extensions/storage/local_fs_storage.py +++ /dev/null @@ -1,62 +0,0 @@ -import os -import shutil -from collections.abc import Generator -from pathlib import Path - -from flask import current_app - -from configs import dify_config -from extensions.storage.base_storage import BaseStorage - - -class LocalFsStorage(BaseStorage): - """Implementation for local filesystem storage.""" - - def __init__(self): - super().__init__() - folder = dify_config.STORAGE_LOCAL_PATH - if not os.path.isabs(folder): - folder = os.path.join(current_app.root_path, folder) - self.folder = folder - - def _build_filepath(self, filename: str) -> str: - """Build the full file path based on the folder and filename.""" - if not self.folder or self.folder.endswith("/"): - return self.folder + filename - else: - return self.folder + "/" + filename - - def save(self, filename, data): - filepath = self._build_filepath(filename) - folder = os.path.dirname(filepath) - os.makedirs(folder, exist_ok=True) - Path(os.path.join(os.getcwd(), filepath)).write_bytes(data) - - def load_once(self, filename: str) -> bytes: - filepath = self._build_filepath(filename) - if not os.path.exists(filepath): - raise FileNotFoundError("File not found") - return Path(filepath).read_bytes() - - def load_stream(self, filename: str) -> Generator: - filepath = self._build_filepath(filename) - if not os.path.exists(filepath): - raise FileNotFoundError("File not found") - with open(filepath, "rb") as f: - while chunk := f.read(4096): # Read in chunks of 4KB - yield chunk - - def download(self, filename, target_filepath): - filepath = self._build_filepath(filename) - if not os.path.exists(filepath): - raise FileNotFoundError("File not found") - shutil.copyfile(filepath, target_filepath) - - def exists(self, filename): - filepath = self._build_filepath(filename) - return os.path.exists(filepath) - - def delete(self, filename): - filepath = self._build_filepath(filename) - if os.path.exists(filepath): - os.remove(filepath) diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py new file mode 100644 index 00000000000000..b78fc94dae7843 --- /dev/null +++ b/api/extensions/storage/opendal_storage.py @@ -0,0 +1,89 @@ +import logging +import os +from collections.abc import Generator +from pathlib import Path + +import opendal # type: ignore[import] +from dotenv import dotenv_values + +from extensions.storage.base_storage import BaseStorage + +logger = logging.getLogger(__name__) + + +def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str = "OPENDAL_"): + kwargs = {} + config_prefix = prefix + scheme.upper() + "_" + for key, value in os.environ.items(): + if key.startswith(config_prefix): + kwargs[key[len(config_prefix) :].lower()] = value + + file_env_vars: dict = dotenv_values(env_file_path) or {} + for key, value in file_env_vars.items(): + if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value: + kwargs[key[len(config_prefix) :].lower()] = value + + return kwargs + + +class OpenDALStorage(BaseStorage): + def __init__(self, scheme: str, **kwargs): + kwargs = kwargs or _get_opendal_kwargs(scheme=scheme) + + if scheme == "fs": + root = kwargs.get("root", "storage") + Path(root).mkdir(parents=True, exist_ok=True) + + self.op = opendal.Operator(scheme=scheme, **kwargs) + logger.debug(f"opendal operator created with scheme {scheme}") + retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True) + self.op = self.op.layer(retry_layer) + logger.debug("added retry layer to opendal operator") + + def save(self, filename: str, data: bytes) -> None: + self.op.write(path=filename, bs=data) + logger.debug(f"file {filename} saved") + + def load_once(self, filename: str) -> bytes: + if not self.exists(filename): + raise FileNotFoundError("File not found") + + content: bytes = self.op.read(path=filename) + logger.debug(f"file {filename} loaded") + return content + + def load_stream(self, filename: str) -> Generator: + if not self.exists(filename): + raise FileNotFoundError("File not found") + + batch_size = 4096 + file = self.op.open(path=filename, mode="rb") + while chunk := file.read(batch_size): + yield chunk + logger.debug(f"file {filename} loaded as stream") + + def download(self, filename: str, target_filepath: str): + if not self.exists(filename): + raise FileNotFoundError("File not found") + + with Path(target_filepath).open("wb") as f: + f.write(self.op.read(path=filename)) + logger.debug(f"file {filename} downloaded to {target_filepath}") + + def exists(self, filename: str) -> bool: + # FIXME this is a workaround for opendal python-binding do not have a exists method and no better + # error handler here when opendal python-binding has a exists method, we should use it + # more https://github.com/apache/opendal/blob/main/bindings/python/src/operator.rs + try: + res: bool = self.op.stat(path=filename).mode.is_file() + logger.debug(f"file {filename} checked") + return res + except Exception: + return False + + def delete(self, filename: str): + if self.exists(filename): + self.op.delete(path=filename) + logger.debug(f"file {filename} deleted") + return + logger.debug(f"file {filename} not found, skip delete") diff --git a/api/extensions/storage/oracle_oci_storage.py b/api/extensions/storage/oracle_oci_storage.py index b59f83b8de90bf..82829f7fd50d65 100644 --- a/api/extensions/storage/oracle_oci_storage.py +++ b/api/extensions/storage/oracle_oci_storage.py @@ -1,7 +1,7 @@ from collections.abc import Generator -import boto3 -from botocore.exceptions import ClientError +import boto3 # type: ignore +from botocore.exceptions import ClientError # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -27,7 +27,7 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: try: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() + data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") diff --git a/api/extensions/storage/storage_type.py b/api/extensions/storage/storage_type.py index e7fa405afacdcc..0a891e36cf17a9 100644 --- a/api/extensions/storage/storage_type.py +++ b/api/extensions/storage/storage_type.py @@ -9,6 +9,7 @@ class StorageType(StrEnum): HUAWEI_OBS = "huawei-obs" LOCAL = "local" OCI_STORAGE = "oci-storage" + OPENDAL = "opendal" S3 = "s3" TENCENT_COS = "tencent-cos" VOLCENGINE_TOS = "volcengine-tos" diff --git a/api/extensions/storage/supabase_storage.py b/api/extensions/storage/supabase_storage.py index 9f7c69a9ae6312..711c3f72117c86 100644 --- a/api/extensions/storage/supabase_storage.py +++ b/api/extensions/storage/supabase_storage.py @@ -32,7 +32,7 @@ def save(self, filename, data): self.client.storage.from_(self.bucket_name).upload(filename, data) def load_once(self, filename: str) -> bytes: - content = self.client.storage.from_(self.bucket_name).download(filename) + content: bytes = self.client.storage.from_(self.bucket_name).download(filename) return content def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py index 13a6c9239c2d1e..9cdd3e67f75aab 100644 --- a/api/extensions/storage/tencent_cos_storage.py +++ b/api/extensions/storage/tencent_cos_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from qcloud_cos import CosConfig, CosS3Client +from qcloud_cos import CosConfig, CosS3Client # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -25,7 +25,7 @@ def save(self, filename, data): self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename) def load_once(self, filename: str) -> bytes: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read() + data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read() return data def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/volcengine_tos_storage.py b/api/extensions/storage/volcengine_tos_storage.py index de82be04ea87b7..55fe6545ec3d2d 100644 --- a/api/extensions/storage/volcengine_tos_storage.py +++ b/api/extensions/storage/volcengine_tos_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -import tos +import tos # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -24,6 +24,8 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: data = self.client.get_object(bucket=self.bucket_name, key=filename).read() + if not isinstance(data, bytes): + raise TypeError("Expected bytes, got {}".format(type(data).__name__)) return data def load_stream(self, filename: str) -> Generator: diff --git a/api/factories/__init__.py b/api/factories/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 8538775a67242b..856cf62e3ed243 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -64,7 +64,7 @@ def build_from_mapping( if not build_func: raise ValueError(f"Invalid file transfer method: {transfer_method}") - file = build_func( + file: File = build_func( mapping=mapping, tenant_id=tenant_id, transfer_method=transfer_method, @@ -72,7 +72,7 @@ def build_from_mapping( if config and not _is_file_valid_with_config( input_file_type=mapping.get("type", FileType.CUSTOM), - file_extension=file.extension, + file_extension=file.extension or "", file_transfer_method=file.transfer_method, config=config, ): @@ -116,8 +116,11 @@ def _build_from_local_file( tenant_id: str, transfer_method: FileTransferMethod, ) -> File: + upload_file_id = mapping.get("upload_file_id") + if not upload_file_id: + raise ValueError("Invalid upload file id") stmt = select(UploadFile).where( - UploadFile.id == mapping.get("upload_file_id"), + UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) @@ -139,6 +142,7 @@ def _build_from_local_file( remote_url=row.source_url, related_id=mapping.get("upload_file_id"), size=row.size, + storage_key=row.key, ) @@ -168,6 +172,7 @@ def _build_from_remote_url( mime_type=mime_type, extension=extension, size=file_size, + storage_key="", ) @@ -220,6 +225,7 @@ def _build_from_tool_file( extension=extension, mime_type=tool_file.mimetype, size=tool_file.size, + storage_key=tool_file.file_key, ) @@ -275,6 +281,7 @@ def _get_file_type_by_extension(extension: str) -> FileType | None: return FileType.AUDIO elif extension in DOCUMENT_EXTENSIONS: return FileType.DOCUMENT + return None def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 16a578728aa16e..bbca8448ec0662 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any +from typing import Any, cast from uuid import uuid4 from configs import dify_config @@ -84,6 +84,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen raise VariableError("missing value type") if (value := mapping.get("value")) is None: raise VariableError("missing value") + # FIXME: using Any here, fix it later + result: Any match value_type: case SegmentType.STRING: result = StringVariable.model_validate(mapping) @@ -109,7 +111,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") if not result.selector: result = result.model_copy(update={"selector": selector}) - return result + return cast(Variable, result) def build_segment(value: Any, /) -> Segment: @@ -164,10 +166,13 @@ def segment_to_variable( raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] - return variable_class( - id=id, - name=name, - description=description, - value=segment.value, - selector=selector, + return cast( + Variable, + variable_class( + id=id, + name=name, + description=description, + value=segment.value, + selector=selector, + ), ) diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index 379dcc6d16fe56..1c58b3a2579087 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py index a85d4a34dbe7b1..d40407bfcc6193 100644 --- a/api/fields/api_based_extension_fields.py +++ b/api/fields/api_based_extension_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index abb27fdad17d63..73800eab853cd3 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.workflow_fields import workflow_partial_fields from libs.helper import AppIconUrlField, TimestampField diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 5bd21be80779a4..c54554a6de8405 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.member_fields import simple_account_fields from libs.helper import TimestampField @@ -85,7 +85,7 @@ def format(self, value): } feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer} - +status_count_fields = {"success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer} model_config_fields = { "opening_statement": fields.String, "suggested_questions": fields.Raw, @@ -166,6 +166,7 @@ def format(self, value): "message_count": fields.Integer, "user_feedback_stats": fields.Nested(feedback_stat_fields), "admin_feedback_stats": fields.Nested(feedback_stat_fields), + "status_count": fields.Nested(status_count_fields), } conversation_with_summary_pagination_fields = { diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index 983e50e73ceb9f..c6385efb5a3cf1 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py index 071071376fe6c8..608672121e2b50 100644 --- a/api/fields/data_source_fields.py +++ b/api/fields/data_source_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 533e3a0837b815..a74e6f54fb3858 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index a83ec7bc97adee..2b2ac6243f4da5 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.dataset_fields import dataset_fields from libs.helper import TimestampField diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index 99e529f9d1c076..aefa0b27580ca7 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore simple_end_user_fields = { "id": fields.String, diff --git a/api/fields/external_dataset_fields.py b/api/fields/external_dataset_fields.py index 2281460fe22146..9cc4e14a0575d7 100644 --- a/api/fields/external_dataset_fields.py +++ b/api/fields/external_dataset_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index afaacc0568ea0c..f896c15f0fec70 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index f36e80f8d493d5..aaafcab8ab6ba0 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index e0b3e340f67b8c..16f265b9bb6d07 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import AppIconUrlField, TimestampField diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 1cf8e408d13d32..0c854c640c3f98 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 5f6e7884a69c5e..0571faab08c134 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.conversation_fields import message_file_fields from libs.helper import TimestampField diff --git a/api/fields/raws.py b/api/fields/raws.py index 15ec16ab13e4a8..493d4b6cce7d31 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from core.file import File diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 2dd4cb45be409b..4413af31607897 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index 9af4fc57dd061c..986cd725f70910 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,3 +1,3 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String} diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index a53b54624915c2..c45b33597b3978 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 0d860d6f406502..bd093d4063bc2e 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from core.helper import encrypter from core.variables import SecretVariable, SegmentType, Variable diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 1413adf7196879..ef59c57ec37957 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields @@ -14,6 +14,7 @@ "total_steps": fields.Integer, "created_at": TimestampField, "finished_at": TimestampField, + "exceptions_count": fields.Integer, } workflow_run_for_list_fields = { @@ -27,6 +28,8 @@ "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), "created_at": TimestampField, "finished_at": TimestampField, + "exceptions_count": fields.Integer, + "retry_index": fields.Integer, } advanced_chat_workflow_run_for_list_fields = { @@ -42,6 +45,8 @@ "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), "created_at": TimestampField, "finished_at": TimestampField, + "exceptions_count": fields.Integer, + "retry_index": fields.Integer, } advanced_chat_workflow_run_pagination_fields = { @@ -73,8 +78,22 @@ "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), "created_at": TimestampField, "finished_at": TimestampField, + "exceptions_count": fields.Integer, } +retry_event_field = { + "elapsed_time": fields.Float, + "status": fields.String, + "inputs": fields.Raw(attribute="inputs"), + "process_data": fields.Raw(attribute="process_data"), + "outputs": fields.Raw(attribute="outputs"), + "metadata": fields.Raw(attribute="metadata"), + "llm_usage": fields.Raw(attribute="llm_usage"), + "error": fields.String, + "retry_index": fields.Integer, +} + + workflow_run_node_execution_fields = { "id": fields.String, "index": fields.Integer, diff --git a/api/libs/external_api.py b/api/libs/external_api.py index 179617ac0a6588..922d2d9cd33324 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -1,8 +1,9 @@ import re import sys +from typing import Any from flask import current_app, got_request_exception -from flask_restful import Api, http_status_message +from flask_restful import Api, http_status_message # type: ignore from werkzeug.datastructures import Headers from werkzeug.exceptions import HTTPException @@ -84,7 +85,7 @@ def handle_error(self, e): # record the exception in the logs when we have a server error of status code: 500 if status_code and status_code >= 500: - exc_info = sys.exc_info() + exc_info: Any = sys.exc_info() if exc_info[1] is None: exc_info = None current_app.log_exception(exc_info) @@ -100,7 +101,7 @@ def handle_error(self, e): resp = self.make_response(data, status_code, headers, fallback_mediatype=fallback_mediatype) elif status_code == 400: if isinstance(data.get("message"), dict): - param_key, param_value = list(data.get("message").items())[0] + param_key, param_value = list(data.get("message", {}).items())[0] data = {"code": "invalid_param", "message": param_value, "params": param_key} else: if "code" not in data: diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index 83f9c74e339e17..2dae87e1710bf6 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -23,7 +23,7 @@ import Crypto.Hash.SHA1 import Crypto.Util.number -import gmpy2 +import gmpy2 # type: ignore from Crypto import Random from Crypto.Signature.pss import MGF1 from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes @@ -191,12 +191,12 @@ def decrypt(self, ciphertext): # Step 3g one_pos = hLen + db[hLen:].find(b"\x01") lHash1 = db[:hLen] - invalid = bord(y) | int(one_pos < hLen) + invalid = bord(y) | int(one_pos < hLen) # type: ignore hash_compare = strxor(lHash1, lHash) for x in hash_compare: - invalid |= bord(x) + invalid |= bord(x) # type: ignore for x in db[hLen:one_pos]: - invalid |= bord(x) + invalid |= bord(x) # type: ignore if invalid != 0: raise ValueError("Incorrect decryption.") # Step 4 diff --git a/api/libs/helper.py b/api/libs/helper.py index 7652d73c8b8db6..eaa4efdb714355 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -9,11 +9,11 @@ from collections.abc import Generator, Mapping from datetime import datetime from hashlib import sha256 -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast from zoneinfo import available_timezones from flask import Response, stream_with_context -from flask_restful import fields +from flask_restful import fields # type: ignore from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator @@ -168,11 +168,11 @@ def generate_string(n): def extract_remote_ip(request) -> str: if request.headers.get("CF-Connecting-IP"): - return request.headers.get("Cf-Connecting-Ip") + return cast(str, request.headers.get("Cf-Connecting-Ip")) elif request.headers.getlist("X-Forwarded-For"): - return request.headers.getlist("X-Forwarded-For")[0] + return cast(str, request.headers.getlist("X-Forwarded-For")[0]) else: - return request.remote_addr + return cast(str, request.remote_addr) def generate_text_hash(text: str) -> str: @@ -221,12 +221,14 @@ def generate_token( token_data.update(additional_data) expiry_minutes = dify_config.model_dump().get(f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES") + if expiry_minutes is None: + raise ValueError(f"Expiry minutes for {token_type} token is not set") token_key = cls._get_token_key(token, token_type) expiry_time = int(expiry_minutes * 60) redis_client.setex(token_key, expiry_time, json.dumps(token_data)) if account_id: - cls._set_current_token_for_account(account.id, token, token_type, expiry_minutes) + cls._set_current_token_for_account(account_id, token, token_type, expiry_minutes) return token @@ -246,13 +248,13 @@ def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]] if token_data_json is None: logging.warning(f"{token_type} token {token} not found with key {key}") return None - token_data = json.loads(token_data_json) + token_data: Optional[dict[str, Any]] = json.loads(token_data_json) return token_data @classmethod def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]: key = cls._get_account_token_key(account_id, token_type) - current_token = redis_client.get(key) + current_token: Optional[str] = redis_client.get(key) return current_token @classmethod diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 41c5d20c4b08b9..9ab53b6294db93 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -10,6 +10,7 @@ def parse_json_markdown(json_string: str) -> dict: ends = ["```", "``", "`", "}"] end_index = -1 start_index = 0 + parsed: dict = {} for s in starts: start_index = json_string.find(s) if start_index != -1: @@ -27,7 +28,7 @@ def parse_json_markdown(json_string: str) -> dict: extracted_content = json_string[start_index:end_index].strip() parsed = json.loads(extracted_content) else: - raise Exception("Could not find JSON block in the output.") + raise ValueError("could not find json block in the output.") return parsed @@ -36,10 +37,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: try: json_obj = parse_json_markdown(text) except json.JSONDecodeError as e: - raise OutputParserError(f"Got invalid JSON object. Error: {e}") + raise OutputParserError(f"got invalid json object. error: {e}") for key in expected_keys: if key not in json_obj: raise OutputParserError( - f"Got invalid return object. Expected key `{key}` to be present, but got {json_obj}" + f"got invalid return object. expected key `{key}` to be present, but got {json_obj}" ) return json_obj diff --git a/api/libs/login.py b/api/libs/login.py index 0ea191a185785d..5395534a6df93a 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -1,8 +1,9 @@ from functools import wraps +from typing import Any from flask import current_app, g, has_request_context, request -from flask_login import user_logged_in -from flask_login.config import EXEMPT_METHODS +from flask_login import user_logged_in # type: ignore +from flask_login.config import EXEMPT_METHODS # type: ignore from werkzeug.exceptions import Unauthorized from werkzeug.local import LocalProxy @@ -12,7 +13,7 @@ #: A proxy for the current user. If no user is logged in, this will be an #: anonymous user -current_user = LocalProxy(lambda: _get_user()) +current_user: Any = LocalProxy(lambda: _get_user()) def login_required(func): @@ -79,12 +80,12 @@ def decorated_view(*args, **kwargs): # Login admin if account: account.current_tenant = tenant - current_app.login_manager._update_request_context_with_user(account) - user_logged_in.send(current_app._get_current_object(), user=_get_user()) + current_app.login_manager._update_request_context_with_user(account) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: pass elif not current_user.is_authenticated: - return current_app.login_manager.unauthorized() + return current_app.login_manager.unauthorized() # type: ignore # flask 1.x compatibility # current_app.ensure_sync is only available in Flask >= 2.0 @@ -98,7 +99,7 @@ def decorated_view(*args, **kwargs): def _get_user(): if has_request_context(): if "_login_user" not in g: - current_app.login_manager._load_user() + current_app.login_manager._load_user() # type: ignore return g._login_user diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 6b6919de24f90f..df75b550195298 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -77,9 +77,9 @@ def get_raw_user_info(self, token: str): email_response = requests.get(self._EMAIL_INFO_URL, headers=headers) email_info = email_response.json() - primary_email = next((email for email in email_info if email["primary"] == True), None) + primary_email: dict = next((email for email in email_info if email["primary"] == True), {}) - return {**user_info, "email": primary_email["email"]} + return {**user_info, "email": primary_email.get("email", "")} def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: email = raw_info.get("email") @@ -130,4 +130,4 @@ def get_raw_user_info(self, token: str): return response.json() def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: - return OAuthUserInfo(id=str(raw_info["sub"]), name=None, email=raw_info["email"]) + return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"]) diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 48249e4a353e10..0c872a0066d127 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -1,8 +1,9 @@ import datetime import urllib.parse +from typing import Any import requests -from flask_login import current_user +from flask_login import current_user # type: ignore from extensions.ext_database import db from models.source import DataSourceOauthBinding @@ -226,7 +227,7 @@ def notion_page_search(self, access_token: str): has_more = True while has_more: - data = { + data: dict[str, Any] = { "filter": {"value": "page", "property": "object"}, **({"start_cursor": next_cursor} if next_cursor else {}), } @@ -253,6 +254,8 @@ def notion_block_parent_page_id(self, access_token: str, block_id: str): } response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) response_json = response.json() + if response.status_code != 200: + raise ValueError(f"Error fetching block parent page ID: {response_json.message}") parent = response_json["parent"] parent_type = parent["type"] if parent_type == "block_id": @@ -279,7 +282,7 @@ def notion_database_search(self, access_token: str): has_more = True while has_more: - data = { + data: dict[str, Any] = { "filter": {"value": "database", "property": "object"}, **({"start_cursor": next_cursor} if next_cursor else {}), } diff --git a/api/libs/threadings_utils.py b/api/libs/threadings_utils.py index d356def418ab1d..e4d63fd3142ce2 100644 --- a/api/libs/threadings_utils.py +++ b/api/libs/threadings_utils.py @@ -9,8 +9,8 @@ def apply_gevent_threading_patch(): :return: """ if not dify_config.DEBUG: - from gevent import monkey - from grpc.experimental import gevent as grpc_gevent + from gevent import monkey # type: ignore + from grpc.experimental import gevent as grpc_gevent # type: ignore # gevent monkey.patch_all() diff --git a/api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py b/api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py new file mode 100644 index 00000000000000..8c576339bae8cf --- /dev/null +++ b/api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py @@ -0,0 +1,33 @@ +"""add exceptions_count field to WorkflowRun model + +Revision ID: cf8f4fc45278 +Revises: 01d6889832f7 +Create Date: 2024-11-28 05:53:21.576178 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'cf8f4fc45278' +down_revision = '01d6889832f7' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.add_column(sa.Column('exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_column('exceptions_count') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py b/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py new file mode 100644 index 00000000000000..881a9e3c1e06b6 --- /dev/null +++ b/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py @@ -0,0 +1,39 @@ +"""remove unused tool_providers + +Revision ID: 11b07f66c737 +Revises: cf8f4fc45278 +Create Date: 2024-12-19 17:46:25.780116 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '11b07f66c737' +down_revision = 'cf8f4fc45278' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('tool_providers') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tool_providers', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False), + sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False), + sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False), + sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False), + sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'), + sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') + ) + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py b/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py new file mode 100644 index 00000000000000..814dec423c63c4 --- /dev/null +++ b/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py @@ -0,0 +1,37 @@ +"""add retry_index field to node-execution model +Revision ID: e1944c35e15e +Revises: 11b07f66c737 +Create Date: 2024-12-20 06:28:30.287197 +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'e1944c35e15e' +down_revision = '11b07f66c737' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + # We don't need these fields anymore, but this file is already merged into the main branch, + # so we need to keep this file for the sake of history, and this change will be reverted in the next migration. + # with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + # batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True)) + + pass + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + # with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + # batch_op.drop_column('retry_index') + pass + + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py b/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py new file mode 100644 index 00000000000000..ea129d15f7e6e6 --- /dev/null +++ b/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py @@ -0,0 +1,34 @@ +"""remove workflow_node_executions.retry_index if exists + +Revision ID: d7999dfa4aae +Revises: e1944c35e15e +Create Date: 2024-12-23 11:54:15.344543 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy import inspect + + +# revision identifiers, used by Alembic. +revision = 'd7999dfa4aae' +down_revision = 'e1944c35e15e' +branch_labels = None +depends_on = None + + +def upgrade(): + # Check if column exists before attempting to remove it + conn = op.get_bind() + inspector = inspect(conn) + has_column = 'retry_index' in [col['name'] for col in inspector.get_columns('workflow_node_executions')] + + if has_column: + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.drop_column('retry_index') + + +def downgrade(): + # No downgrade needed as we don't want to restore the column + pass diff --git a/api/models/__init__.py b/api/models/__init__.py index 61a38870cf9b91..b0b9880ca42a9d 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1,53 +1,187 @@ -from .account import Account, AccountIntegrate, InvitationCode, Tenant -from .dataset import Dataset, DatasetProcessRule, Document, DocumentSegment +from .account import ( + Account, + AccountIntegrate, + AccountStatus, + InvitationCode, + Tenant, + TenantAccountJoin, + TenantAccountJoinRole, + TenantAccountRole, + TenantStatus, +) +from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from .dataset import ( + AppDatasetJoin, + Dataset, + DatasetCollectionBinding, + DatasetKeywordTable, + DatasetPermission, + DatasetPermissionEnum, + DatasetProcessRule, + DatasetQuery, + Document, + DocumentSegment, + Embedding, + ExternalKnowledgeApis, + ExternalKnowledgeBindings, + TidbAuthBinding, + Whitelist, +) +from .engine import db +from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom from .model import ( + ApiRequest, ApiToken, App, + AppAnnotationHitHistory, + AppAnnotationSetting, AppMode, + AppModelConfig, Conversation, + DatasetRetrieverResource, + DifySetup, EndUser, + IconType, InstalledApp, Message, + MessageAgentThought, MessageAnnotation, + MessageChain, + MessageFeedback, MessageFile, + OperationLog, RecommendedApp, Site, + Tag, + TagBinding, + TraceAppConfig, UploadFile, ) -from .source import DataSourceOauthBinding -from .tools import ToolFile +from .provider import ( + LoadBalancingModelConfig, + Provider, + ProviderModel, + ProviderModelSetting, + ProviderOrder, + ProviderQuotaType, + ProviderType, + TenantDefaultModel, + TenantPreferredModelProvider, +) +from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding +from .task import CeleryTask, CeleryTaskSet +from .tools import ( + ApiToolProvider, + BuiltinToolProvider, + PublishedAppTool, + ToolConversationVariables, + ToolFile, + ToolLabelBinding, + ToolModelInvoke, + WorkflowToolProvider, +) +from .web import PinnedConversation, SavedMessage from .workflow import ( ConversationVariable, Workflow, WorkflowAppLog, + WorkflowAppLogCreatedFrom, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, WorkflowRun, + WorkflowRunStatus, + WorkflowType, ) __all__ = [ + "APIBasedExtension", + "APIBasedExtensionPoint", "Account", "AccountIntegrate", + "AccountStatus", + "ApiRequest", "ApiToken", + "ApiToolProvider", # Added "App", + "AppAnnotationHitHistory", + "AppAnnotationSetting", + "AppDatasetJoin", "AppMode", + "AppModelConfig", + "BuiltinToolProvider", # Added + "CeleryTask", + "CeleryTaskSet", "Conversation", "ConversationVariable", + "CreatedByRole", + "DataSourceApiKeyAuthBinding", "DataSourceOauthBinding", "Dataset", + "DatasetCollectionBinding", + "DatasetKeywordTable", + "DatasetPermission", + "DatasetPermissionEnum", "DatasetProcessRule", + "DatasetQuery", + "DatasetRetrieverResource", + "DifySetup", "Document", "DocumentSegment", + "Embedding", "EndUser", + "ExternalKnowledgeApis", + "ExternalKnowledgeBindings", + "IconType", "InstalledApp", "InvitationCode", + "LoadBalancingModelConfig", "Message", + "MessageAgentThought", "MessageAnnotation", + "MessageChain", + "MessageFeedback", "MessageFile", + "OperationLog", + "PinnedConversation", + "Provider", + "ProviderModel", + "ProviderModelSetting", + "ProviderOrder", + "ProviderQuotaType", + "ProviderType", + "PublishedAppTool", "RecommendedApp", + "SavedMessage", "Site", + "Tag", + "TagBinding", "Tenant", + "TenantAccountJoin", + "TenantAccountJoinRole", + "TenantAccountRole", + "TenantDefaultModel", + "TenantPreferredModelProvider", + "TenantStatus", + "TidbAuthBinding", + "ToolConversationVariables", "ToolFile", + "ToolLabelBinding", + "ToolModelInvoke", + "TraceAppConfig", "UploadFile", + "UserFrom", + "Whitelist", "Workflow", "WorkflowAppLog", + "WorkflowAppLogCreatedFrom", + "WorkflowNodeExecution", + "WorkflowNodeExecutionStatus", + "WorkflowNodeExecutionTriggeredFrom", "WorkflowRun", + "WorkflowRunStatus", + "WorkflowRunTriggeredFrom", + "WorkflowToolProvider", + "WorkflowType", + "db", ] diff --git a/api/models/account.py b/api/models/account.py index 951e836dec1873..88c96da1a149d5 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -1,10 +1,10 @@ import enum import json -from flask_login import UserMixin - -from extensions.ext_database import db +from flask_login import UserMixin # type: ignore +from sqlalchemy import func +from .engine import db from .types import StringUUID @@ -16,7 +16,7 @@ class AccountStatus(enum.StrEnum): CLOSED = "closed" -class Account(UserMixin, db.Model): +class Account(UserMixin, db.Model): # type: ignore[name-defined] __tablename__ = "accounts" __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) @@ -31,11 +31,11 @@ class Account(UserMixin, db.Model): timezone = db.Column(db.String(255)) last_login_at = db.Column(db.DateTime) last_login_ip = db.Column(db.String(255)) - last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + last_active_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying")) initialized_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def is_password_set(self): @@ -43,7 +43,8 @@ def is_password_set(self): @property def current_tenant(self): - return self._current_tenant + # FIXME: fix the type error later, because the type is important maybe cause some bugs + return self._current_tenant # type: ignore @current_tenant.setter def current_tenant(self, value: "Tenant"): @@ -52,7 +53,8 @@ def current_tenant(self, value: "Tenant"): if ta: tenant.current_role = ta.role else: - tenant = None + # FIXME: fix the type error later, because the type is important maybe cause some bugs + tenant = None # type: ignore self._current_tenant = tenant @property @@ -89,7 +91,7 @@ def get_status(self) -> AccountStatus: return AccountStatus(status_str) @classmethod - def get_by_openid(cls, provider: str, open_id: str) -> db.Model: + def get_by_openid(cls, provider: str, open_id: str): account_integrate = ( db.session.query(AccountIntegrate) .filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) @@ -99,11 +101,6 @@ def get_by_openid(cls, provider: str, open_id: str) -> db.Model: return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none() return None - def get_integrates(self) -> list[db.Model]: - ai = db.Model - return db.session.query(ai).filter(ai.account_id == self.id).all() - - # check current_user.current_tenant.current_role in ['admin', 'owner'] @property def is_admin_or_owner(self): return TenantAccountRole.is_privileged_role(self._current_tenant.current_role) @@ -139,7 +136,7 @@ class TenantAccountRole(enum.StrEnum): @staticmethod def is_valid_role(role: str) -> bool: - return role and role in { + return role in { TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, @@ -149,15 +146,15 @@ def is_valid_role(role: str) -> bool: @staticmethod def is_privileged_role(role: str) -> bool: - return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} + return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} @staticmethod def is_admin_role(role: str) -> bool: - return role and role == TenantAccountRole.ADMIN + return role == TenantAccountRole.ADMIN @staticmethod def is_non_owner_role(role: str) -> bool: - return role and role in { + return role in { TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, TenantAccountRole.NORMAL, @@ -166,11 +163,11 @@ def is_non_owner_role(role: str) -> bool: @staticmethod def is_editing_role(role: str) -> bool: - return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} + return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} @staticmethod def is_dataset_edit_role(role: str) -> bool: - return role and role in { + return role in { TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, @@ -178,7 +175,7 @@ def is_dataset_edit_role(role: str) -> bool: } -class Tenant(db.Model): +class Tenant(db.Model): # type: ignore[name-defined] __tablename__ = "tenants" __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) @@ -188,8 +185,8 @@ class Tenant(db.Model): plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) custom_config = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def get_accounts(self) -> list[Account]: return ( @@ -214,7 +211,7 @@ class TenantAccountJoinRole(enum.Enum): DATASET_OPERATOR = "dataset_operator" -class TenantAccountJoin(db.Model): +class TenantAccountJoin(db.Model): # type: ignore[name-defined] __tablename__ = "tenant_account_joins" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), @@ -229,11 +226,11 @@ class TenantAccountJoin(db.Model): current = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) role = db.Column(db.String(16), nullable=False, server_default="normal") invited_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class AccountIntegrate(db.Model): +class AccountIntegrate(db.Model): # type: ignore[name-defined] __tablename__ = "account_integrates" __table_args__ = ( db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), @@ -246,11 +243,11 @@ class AccountIntegrate(db.Model): provider = db.Column(db.String(16), nullable=False) open_id = db.Column(db.String(255), nullable=False) encrypted_token = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class InvitationCode(db.Model): +class InvitationCode(db.Model): # type: ignore[name-defined] __tablename__ = "invitation_codes" __table_args__ = ( db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), @@ -266,4 +263,4 @@ class InvitationCode(db.Model): used_by_tenant_id = db.Column(StringUUID) used_by_account_id = db.Column(StringUUID) deprecated_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 97173747afc4b1..6b6d808710afc0 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,7 +1,8 @@ import enum -from extensions.ext_database import db +from sqlalchemy import func +from .engine import db from .types import StringUUID @@ -12,7 +13,7 @@ class APIBasedExtensionPoint(enum.Enum): APP_MODERATION_OUTPUT = "app.moderation.output" -class APIBasedExtension(db.Model): +class APIBasedExtension(db.Model): # type: ignore[name-defined] __tablename__ = "api_based_extensions" __table_args__ = ( db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), @@ -24,4 +25,4 @@ class APIBasedExtension(db.Model): name = db.Column(db.String(255), nullable=False) api_endpoint = db.Column(db.String(255), nullable=False) api_key = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/dataset.py b/api/models/dataset.py index 8ab957e875a1bf..b9b41dcf475bb1 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -9,16 +9,17 @@ import re import time from json import JSONDecodeError +from typing import Any, cast from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB from configs import dify_config from core.rag.retrieval.retrieval_methods import RetrievalMethod -from extensions.ext_database import db from extensions.ext_storage import storage from .account import Account +from .engine import db from .model import App, Tag, TagBinding, UploadFile from .types import StringUUID @@ -29,7 +30,7 @@ class DatasetPermissionEnum(enum.StrEnum): PARTIAL_TEAM = "partial_members" -class Dataset(db.Model): +class Dataset(db.Model): # type: ignore[name-defined] __tablename__ = "datasets" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_pkey"), @@ -50,9 +51,9 @@ class Dataset(db.Model): indexing_technique = db.Column(db.String(255), nullable=True) index_struct = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) embedding_model = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True) collection_binding_id = db.Column(StringUUID, nullable=True) @@ -200,7 +201,7 @@ def gen_collection_name_by_id(dataset_id: str) -> str: return f"Vector_index_{normalized_dataset_id}_Node" -class DatasetProcessRule(db.Model): +class DatasetProcessRule(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_process_rules" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), @@ -212,11 +213,11 @@ class DatasetProcessRule(db.Model): mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) rules = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) MODES = ["automatic", "custom"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] - AUTOMATIC_RULES = { + AUTOMATIC_RULES: dict[str, Any] = { "pre_processing_rules": [ {"id": "remove_extra_spaces", "enabled": True}, {"id": "remove_urls_emails", "enabled": False}, @@ -242,7 +243,7 @@ def rules_dict(self): return None -class Document(db.Model): +class Document(db.Model): # type: ignore[name-defined] __tablename__ = "documents" __table_args__ = ( db.PrimaryKeyConstraint("id", name="document_pkey"), @@ -264,7 +265,7 @@ class Document(db.Model): created_from = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) created_api_request_id = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) # start processing processing_started_at = db.Column(db.DateTime, nullable=True) @@ -303,7 +304,7 @@ class Document(db.Model): archived_reason = db.Column(db.String(255), nullable=True) archived_by = db.Column(StringUUID, nullable=True) archived_at = db.Column(db.DateTime, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) doc_type = db.Column(db.String(40), nullable=True) doc_metadata = db.Column(db.JSON, nullable=True) doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) @@ -492,7 +493,7 @@ def from_dict(cls, data: dict): ) -class DocumentSegment(db.Model): +class DocumentSegment(db.Model): # type: ignore[name-defined] __tablename__ = "document_segments" __table_args__ = ( db.PrimaryKeyConstraint("id", name="document_segment_pkey"), @@ -527,9 +528,9 @@ class DocumentSegment(db.Model): disabled_by = db.Column(StringUUID, nullable=True) status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) indexing_at = db.Column(db.DateTime, nullable=True) completed_at = db.Column(db.DateTime, nullable=True) error = db.Column(db.Text, nullable=True) @@ -604,7 +605,7 @@ def get_sign_content(self): return text -class AppDatasetJoin(db.Model): +class AppDatasetJoin(db.Model): # type: ignore[name-defined] __tablename__ = "app_dataset_joins" __table_args__ = ( db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), @@ -621,7 +622,7 @@ def app(self): return db.session.get(App, self.app_id) -class DatasetQuery(db.Model): +class DatasetQuery(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_queries" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), @@ -638,7 +639,7 @@ class DatasetQuery(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) -class DatasetKeywordTable(db.Model): +class DatasetKeywordTable(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_keyword_tables" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), @@ -683,7 +684,7 @@ def object_hook(self, dct): return None -class Embedding(db.Model): +class Embedding(db.Model): # type: ignore[name-defined] __tablename__ = "embeddings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="embedding_pkey"), @@ -697,17 +698,17 @@ class Embedding(db.Model): ) hash = db.Column(db.String(64), nullable=False) embedding = db.Column(db.LargeBinary, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying")) def set_embedding(self, embedding_data: list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) def get_embedding(self) -> list[float]: - return pickle.loads(self.embedding) + return cast(list[float], pickle.loads(self.embedding)) -class DatasetCollectionBinding(db.Model): +class DatasetCollectionBinding(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_collection_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), @@ -719,10 +720,10 @@ class DatasetCollectionBinding(db.Model): model_name = db.Column(db.String(255), nullable=False) type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) collection_name = db.Column(db.String(64), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TidbAuthBinding(db.Model): +class TidbAuthBinding(db.Model): # type: ignore[name-defined] __tablename__ = "tidb_auth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), @@ -739,10 +740,10 @@ class TidbAuthBinding(db.Model): status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING")) account = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class Whitelist(db.Model): +class Whitelist(db.Model): # type: ignore[name-defined] __tablename__ = "whitelists" __table_args__ = ( db.PrimaryKeyConstraint("id", name="whitelists_pkey"), @@ -751,10 +752,10 @@ class Whitelist(db.Model): id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=True) category = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class DatasetPermission(db.Model): +class DatasetPermission(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_permissions" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), @@ -768,10 +769,10 @@ class DatasetPermission(db.Model): account_id = db.Column(StringUUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False) has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class ExternalKnowledgeApis(db.Model): +class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined] __tablename__ = "external_knowledge_apis" __table_args__ = ( db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"), @@ -785,9 +786,9 @@ class ExternalKnowledgeApis(db.Model): tenant_id = db.Column(StringUUID, nullable=False) settings = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def to_dict(self): return { @@ -824,7 +825,7 @@ def dataset_bindings(self): return dataset_bindings -class ExternalKnowledgeBindings(db.Model): +class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined] __tablename__ = "external_knowledge_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), @@ -840,6 +841,6 @@ class ExternalKnowledgeBindings(db.Model): dataset_id = db.Column(StringUUID, nullable=False) external_knowledge_id = db.Column(db.Text, nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/engine.py b/api/models/engine.py new file mode 100644 index 00000000000000..dda93bc9415cfc --- /dev/null +++ b/api/models/engine.py @@ -0,0 +1,13 @@ +from flask_sqlalchemy import SQLAlchemy +from sqlalchemy import MetaData + +POSTGRES_INDEXES_NAMING_CONVENTION = { + "ix": "%(column_0_label)s_idx", + "uq": "%(table_name)s_%(column_0_name)s_key", + "ck": "%(table_name)s_%(constraint_name)s_check", + "fk": "%(table_name)s_%(column_0_name)s_fkey", + "pk": "%(table_name)s_pkey", +} + +metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) +db = SQLAlchemy(metadata=metadata) diff --git a/api/models/model.py b/api/models/model.py index 03b8e0bea553aa..2a593f08298199 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -4,11 +4,11 @@ from collections.abc import Mapping from datetime import datetime from enum import Enum, StrEnum -from typing import Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional, cast import sqlalchemy as sa from flask import request -from flask_login import UserMixin +from flask_login import UserMixin # type: ignore from sqlalchemy import Float, func, text from sqlalchemy.orm import Mapped, mapped_column @@ -16,20 +16,24 @@ from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from core.file import helpers as file_helpers from core.file.tool_file_parser import ToolFileParser -from extensions.ext_database import db from libs.helper import generate_string from models.enums import CreatedByRole +from models.workflow import WorkflowRunStatus from .account import Account, Tenant +from .engine import db from .types import StringUUID +if TYPE_CHECKING: + from .workflow import Workflow -class DifySetup(db.Model): + +class DifySetup(db.Model): # type: ignore[name-defined] __tablename__ = "dify_setups" __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) version = db.Column(db.String(255), nullable=False) - setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + setup_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class AppMode(StrEnum): @@ -59,7 +63,7 @@ class IconType(Enum): EMOJI = "emoji" -class App(db.Model): +class App(db.Model): # type: ignore[name-defined] __tablename__ = "apps" __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) @@ -82,11 +86,11 @@ class App(db.Model): is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) tracing = db.Column(db.Text, nullable=True) - max_active_requests = db.Column(db.Integer, nullable=True) + max_active_requests: Mapped[Optional[int]] = mapped_column(nullable=True) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @property @@ -150,7 +154,7 @@ def mode_compatible_with_agent(self) -> str: if self.mode == AppMode.CHAT.value and self.is_agent: return AppMode.AGENT_CHAT.value - return self.mode + return str(self.mode) @property def deleted_tools(self) -> list: @@ -215,7 +219,7 @@ def tags(self): return tags or [] -class AppModelConfig(db.Model): +class AppModelConfig(db.Model): # type: ignore[name-defined] __tablename__ = "app_model_configs" __table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id")) @@ -225,9 +229,9 @@ class AppModelConfig(db.Model): model_id = db.Column(db.String(255), nullable=True) configs = db.Column(db.JSON, nullable=True) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) opening_statement = db.Column(db.Text) suggested_questions = db.Column(db.Text) suggested_questions_after_answer = db.Column(db.Text) @@ -318,7 +322,7 @@ def external_data_tools_list(self) -> list[dict]: return json.loads(self.external_data_tools) if self.external_data_tools else [] @property - def user_input_form_list(self) -> dict: + def user_input_form_list(self) -> list[dict]: return json.loads(self.user_input_form) if self.user_input_form else [] @property @@ -340,7 +344,7 @@ def completion_prompt_config_dict(self) -> dict: @property def dataset_configs_dict(self) -> dict: if self.dataset_configs: - dataset_configs = json.loads(self.dataset_configs) + dataset_configs: dict = json.loads(self.dataset_configs) if "retrieval_model" not in dataset_configs: return {"retrieval_model": "single"} else: @@ -462,7 +466,7 @@ def copy(self): return new_app_model_config -class RecommendedApp(db.Model): +class RecommendedApp(db.Model): # type: ignore[name-defined] __tablename__ = "recommended_apps" __table_args__ = ( db.PrimaryKeyConstraint("id", name="recommended_app_pkey"), @@ -481,8 +485,8 @@ class RecommendedApp(db.Model): is_listed = db.Column(db.Boolean, nullable=False, default=True) install_count = db.Column(db.Integer, nullable=False, default=0) language = db.Column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -490,7 +494,7 @@ def app(self): return app -class InstalledApp(db.Model): +class InstalledApp(db.Model): # type: ignore[name-defined] __tablename__ = "installed_apps" __table_args__ = ( db.PrimaryKeyConstraint("id", name="installed_app_pkey"), @@ -506,7 +510,7 @@ class InstalledApp(db.Model): position = db.Column(db.Integer, nullable=False, default=0) is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -519,7 +523,7 @@ def tenant(self): return tenant -class Conversation(db.Model): +class Conversation(db.Model): # type: ignore[name-defined] __tablename__ = "conversations" __table_args__ = ( db.PrimaryKeyConstraint("id", name="conversation_pkey"), @@ -547,8 +551,8 @@ class Conversation(db.Model): read_at = db.Column(db.DateTime) read_account_id = db.Column(StringUUID) dialogue_count: Mapped[int] = mapped_column(default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all") message_annotations = db.relationship( @@ -560,13 +564,29 @@ class Conversation(db.Model): @property def inputs(self): inputs = self._inputs.copy() + + # Convert file mapping to File object for key, value in inputs.items(): + # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. + from factories import file_factory + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: - inputs[key] = File.model_validate(value) + if value["transfer_method"] == FileTransferMethod.TOOL_FILE: + value["tool_file_id"] = value["related_id"] + elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE: + value["upload_file_id"] = value["related_id"] + inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) elif isinstance(value, list) and all( isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value ): - inputs[key] = [File.model_validate(item) for item in value] + inputs[key] = [] + for item in value: + if item["transfer_method"] == FileTransferMethod.TOOL_FILE: + item["tool_file_id"] = item["related_id"] + elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE: + item["upload_file_id"] = item["related_id"] + inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) + return inputs @inputs.setter @@ -582,6 +602,8 @@ def inputs(self, value: Mapping[str, Any]): @property def model_config(self): model_config = {} + app_model_config: Optional[AppModelConfig] = None + if self.mode == AppMode.ADVANCED_CHAT.value: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) @@ -593,6 +615,7 @@ def model_config(self): if "model" in override_model_configs: app_model_config = AppModelConfig() app_model_config = app_model_config.from_model_config_dict(override_model_configs) + assert app_model_config is not None, "app model config not found" model_config = app_model_config.to_dict() else: model_config["configs"] = override_model_configs @@ -679,6 +702,31 @@ def admin_feedback_stats(self): return {"like": like, "dislike": dislike} + @property + def status_count(self): + messages = db.session.query(Message).filter(Message.conversation_id == self.id).all() + status_counts = { + WorkflowRunStatus.RUNNING: 0, + WorkflowRunStatus.SUCCEEDED: 0, + WorkflowRunStatus.FAILED: 0, + WorkflowRunStatus.STOPPED: 0, + WorkflowRunStatus.PARTIAL_SUCCESSED: 0, + } + + for message in messages: + if message.workflow_run: + status_counts[message.workflow_run.status] += 1 + + return ( + { + "success": status_counts[WorkflowRunStatus.SUCCEEDED], + "failed": status_counts[WorkflowRunStatus.FAILED], + "partial_success": status_counts[WorkflowRunStatus.PARTIAL_SUCCESSED], + } + if messages + else None + ) + @property def first_message(self): return db.session.query(Message).filter(Message.conversation_id == self.id).first() @@ -710,7 +758,7 @@ def in_debug_mode(self): return self.override_model_configs is not None -class Message(db.Model): +class Message(db.Model): # type: ignore[name-defined] __tablename__ = "messages" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_pkey"), @@ -749,8 +797,8 @@ class Message(db.Model): from_source = db.Column(db.String(255), nullable=False) from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) workflow_run_id = db.Column(StringUUID) @@ -758,12 +806,25 @@ class Message(db.Model): def inputs(self): inputs = self._inputs.copy() for key, value in inputs.items(): + # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. + from factories import file_factory + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: - inputs[key] = File.model_validate(value) + if value["transfer_method"] == FileTransferMethod.TOOL_FILE: + value["tool_file_id"] = value["related_id"] + elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE: + value["upload_file_id"] = value["related_id"] + inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) elif isinstance(value, list) and all( isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value ): - inputs[key] = [File.model_validate(item) for item in value] + inputs[key] = [] + for item in value: + if item["transfer_method"] == FileTransferMethod.TOOL_FILE: + item["tool_file_id"] = item["related_id"] + elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE: + item["upload_file_id"] = item["related_id"] + inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) return inputs @inputs.setter @@ -937,7 +998,7 @@ def message_files(self): if not current_app: raise ValueError(f"App {self.app_id} not found") - files: list[File] = [] + files = [] for message_file in message_files: if message_file.transfer_method == "local_file": if message_file.upload_file_id is None: @@ -1044,7 +1105,7 @@ def from_dict(cls, data: dict): ) -class MessageFeedback(db.Model): +class MessageFeedback(db.Model): # type: ignore[name-defined] __tablename__ = "message_feedbacks" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_feedback_pkey"), @@ -1062,8 +1123,8 @@ class MessageFeedback(db.Model): from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(StringUUID) from_account_id = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def from_account(self): @@ -1071,7 +1132,7 @@ def from_account(self): return account -class MessageFile(db.Model): +class MessageFile(db.Model): # type: ignore[name-defined] __tablename__ = "message_files" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_file_pkey"), @@ -1109,12 +1170,10 @@ def __init__( upload_file_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True) created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False) created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - created_at: Mapped[datetime] = db.Column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") - ) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class MessageAnnotation(db.Model): +class MessageAnnotation(db.Model): # type: ignore[name-defined] __tablename__ = "message_annotations" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_annotation_pkey"), @@ -1131,8 +1190,8 @@ class MessageAnnotation(db.Model): content = db.Column(db.Text, nullable=False) hit_count = db.Column(db.Integer, nullable=False, server_default=db.text("0")) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def account(self): @@ -1145,7 +1204,7 @@ def annotation_create_account(self): return account -class AppAnnotationHitHistory(db.Model): +class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined] __tablename__ = "app_annotation_hit_histories" __table_args__ = ( db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), @@ -1161,7 +1220,7 @@ class AppAnnotationHitHistory(db.Model): source = db.Column(db.Text, nullable=False) question = db.Column(db.Text, nullable=False) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) score = db.Column(Float, nullable=False, server_default=db.text("0")) message_id = db.Column(StringUUID, nullable=False) annotation_question = db.Column(db.Text, nullable=False) @@ -1183,7 +1242,7 @@ def annotation_create_account(self): return account -class AppAnnotationSetting(db.Model): +class AppAnnotationSetting(db.Model): # type: ignore[name-defined] __tablename__ = "app_annotation_settings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), @@ -1195,9 +1254,9 @@ class AppAnnotationSetting(db.Model): score_threshold = db.Column(Float, nullable=False, server_default=db.text("0")) collection_binding_id = db.Column(StringUUID, nullable=False) created_user_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_user_id = db.Column(StringUUID, nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def created_account(self): @@ -1231,7 +1290,7 @@ def collection_binding_detail(self): return collection_binding_detail -class OperationLog(db.Model): +class OperationLog(db.Model): # type: ignore[name-defined] __tablename__ = "operation_logs" __table_args__ = ( db.PrimaryKeyConstraint("id", name="operation_log_pkey"), @@ -1243,12 +1302,12 @@ class OperationLog(db.Model): account_id = db.Column(StringUUID, nullable=False) action = db.Column(db.String(255), nullable=False) content = db.Column(db.JSON) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_ip = db.Column(db.String(255), nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class EndUser(UserMixin, db.Model): +class EndUser(UserMixin, db.Model): # type: ignore[name-defined] __tablename__ = "end_users" __table_args__ = ( db.PrimaryKeyConstraint("id", name="end_user_pkey"), @@ -1264,11 +1323,11 @@ class EndUser(UserMixin, db.Model): name = db.Column(db.String(255)) is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) session_id = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class Site(db.Model): +class Site(db.Model): # type: ignore[name-defined] __tablename__ = "sites" __table_args__ = ( db.PrimaryKeyConstraint("id", name="site_pkey"), @@ -1296,9 +1355,9 @@ class Site(db.Model): prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) code = db.Column(db.String(255)) @property @@ -1325,7 +1384,7 @@ def app_base_url(self): return dify_config.APP_WEB_URL or request.url_root.rstrip("/") -class ApiToken(db.Model): +class ApiToken(db.Model): # type: ignore[name-defined] __tablename__ = "api_tokens" __table_args__ = ( db.PrimaryKeyConstraint("id", name="api_token_pkey"), @@ -1340,7 +1399,7 @@ class ApiToken(db.Model): type = db.Column(db.String(16), nullable=False) token = db.Column(db.String(255), nullable=False) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @staticmethod def generate_api_key(prefix, n): @@ -1352,7 +1411,7 @@ def generate_api_key(prefix, n): return result -class UploadFile(db.Model): +class UploadFile(db.Model): # type: ignore[name-defined] __tablename__ = "upload_files" __table_args__ = ( db.PrimaryKeyConstraint("id", name="upload_file_pkey"), @@ -1371,9 +1430,7 @@ class UploadFile(db.Model): db.String(255), nullable=False, server_default=db.text("'account'::character varying") ) created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - created_at: Mapped[datetime] = db.Column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") - ) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) used: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True) used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True) @@ -1416,7 +1473,7 @@ def __init__( self.source_url = source_url -class ApiRequest(db.Model): +class ApiRequest(db.Model): # type: ignore[name-defined] __tablename__ = "api_requests" __table_args__ = ( db.PrimaryKeyConstraint("id", name="api_request_pkey"), @@ -1430,10 +1487,10 @@ class ApiRequest(db.Model): request = db.Column(db.Text, nullable=True) response = db.Column(db.Text, nullable=True) ip = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class MessageChain(db.Model): +class MessageChain(db.Model): # type: ignore[name-defined] __tablename__ = "message_chains" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_chain_pkey"), @@ -1448,7 +1505,7 @@ class MessageChain(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) -class MessageAgentThought(db.Model): +class MessageAgentThought(db.Model): # type: ignore[name-defined] __tablename__ = "message_agent_thoughts" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"), @@ -1488,7 +1545,7 @@ class MessageAgentThought(db.Model): @property def files(self) -> list: if self.message_files: - return json.loads(self.message_files) + return cast(list[Any], json.loads(self.message_files)) else: return [] @@ -1500,7 +1557,7 @@ def tools(self) -> list[str]: def tool_labels(self) -> dict: try: if self.tool_labels_str: - return json.loads(self.tool_labels_str) + return cast(dict, json.loads(self.tool_labels_str)) else: return {} except Exception as e: @@ -1510,7 +1567,7 @@ def tool_labels(self) -> dict: def tool_meta(self) -> dict: try: if self.tool_meta_str: - return json.loads(self.tool_meta_str) + return cast(dict, json.loads(self.tool_meta_str)) else: return {} except Exception as e: @@ -1558,9 +1615,11 @@ def tool_outputs_dict(self) -> dict: except Exception as e: if self.observation: return dict.fromkeys(tools, self.observation) + else: + return {} -class DatasetRetrieverResource(db.Model): +class DatasetRetrieverResource(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_retriever_resources" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), @@ -1587,7 +1646,7 @@ class DatasetRetrieverResource(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) -class Tag(db.Model): +class Tag(db.Model): # type: ignore[name-defined] __tablename__ = "tags" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tag_pkey"), @@ -1602,10 +1661,10 @@ class Tag(db.Model): type = db.Column(db.String(16), nullable=False) name = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TagBinding(db.Model): +class TagBinding(db.Model): # type: ignore[name-defined] __tablename__ = "tag_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tag_binding_pkey"), @@ -1618,10 +1677,10 @@ class TagBinding(db.Model): tag_id = db.Column(StringUUID, nullable=True) target_id = db.Column(StringUUID, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TraceAppConfig(db.Model): +class TraceAppConfig(db.Model): # type: ignore[name-defined] __tablename__ = "trace_app_config" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), @@ -1632,8 +1691,10 @@ class TraceAppConfig(db.Model): app_id = db.Column(StringUUID, nullable=False) tracing_provider = db.Column(db.String(255), nullable=True) tracing_config = db.Column(db.JSON, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.now()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.now(), onupdate=func.now()) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column( + db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) is_active = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) @property diff --git a/api/models/provider.py b/api/models/provider.py index 644915e781084b..abe673975c1ccc 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,7 +1,8 @@ from enum import Enum -from extensions.ext_database import db +from sqlalchemy import func +from .engine import db from .types import StringUUID @@ -35,7 +36,7 @@ def value_of(value): raise ValueError(f"No matching enum found for value '{value}'") -class Provider(db.Model): +class Provider(db.Model): # type: ignore[name-defined] """ Provider model representing the API providers and their configurations. """ @@ -61,8 +62,8 @@ class Provider(db.Model): quota_limit = db.Column(db.BigInteger, nullable=True) quota_used = db.Column(db.BigInteger, default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def __repr__(self): return ( @@ -88,7 +89,7 @@ def is_enabled(self): return self.is_valid and self.token_is_set -class ProviderModel(db.Model): +class ProviderModel(db.Model): # type: ignore[name-defined] """ Provider model representing the API provider_models and their configurations. """ @@ -109,11 +110,11 @@ class ProviderModel(db.Model): model_type = db.Column(db.String(40), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TenantDefaultModel(db.Model): +class TenantDefaultModel(db.Model): # type: ignore[name-defined] __tablename__ = "tenant_default_models" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), @@ -125,11 +126,11 @@ class TenantDefaultModel(db.Model): provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TenantPreferredModelProvider(db.Model): +class TenantPreferredModelProvider(db.Model): # type: ignore[name-defined] __tablename__ = "tenant_preferred_model_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), @@ -140,11 +141,11 @@ class TenantPreferredModelProvider(db.Model): tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) preferred_provider_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class ProviderOrder(db.Model): +class ProviderOrder(db.Model): # type: ignore[name-defined] __tablename__ = "provider_orders" __table_args__ = ( db.PrimaryKeyConstraint("id", name="provider_order_pkey"), @@ -165,11 +166,11 @@ class ProviderOrder(db.Model): paid_at = db.Column(db.DateTime) pay_failed_at = db.Column(db.DateTime) refunded_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class ProviderModelSetting(db.Model): +class ProviderModelSetting(db.Model): # type: ignore[name-defined] """ Provider model settings for record the model enabled status and load balancing status. """ @@ -187,11 +188,11 @@ class ProviderModelSetting(db.Model): model_type = db.Column(db.String(40), nullable=False) enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class LoadBalancingModelConfig(db.Model): +class LoadBalancingModelConfig(db.Model): # type: ignore[name-defined] """ Configurations for load balancing models. """ @@ -210,5 +211,5 @@ class LoadBalancingModelConfig(db.Model): name = db.Column(db.String(255), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/source.py b/api/models/source.py index 07695f06e6cf00..881cfaac7d3998 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,13 +1,13 @@ import json +from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB -from extensions.ext_database import db - +from .engine import db from .types import StringUUID -class DataSourceOauthBinding(db.Model): +class DataSourceOauthBinding(db.Model): # type: ignore[name-defined] __tablename__ = "data_source_oauth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="source_binding_pkey"), @@ -20,12 +20,12 @@ class DataSourceOauthBinding(db.Model): access_token = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) source_info = db.Column(JSONB, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) -class DataSourceApiKeyAuthBinding(db.Model): +class DataSourceApiKeyAuthBinding(db.Model): # type: ignore[name-defined] __tablename__ = "data_source_api_key_auth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), @@ -38,8 +38,8 @@ class DataSourceApiKeyAuthBinding(db.Model): category = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) credentials = db.Column(db.Text, nullable=True) # JSON - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) def to_dict(self): diff --git a/api/models/task.py b/api/models/task.py index 5d89ff85acc781..0db1c632299fcb 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,11 +1,11 @@ from datetime import UTC, datetime -from celery import states +from celery import states # type: ignore -from extensions.ext_database import db +from .engine import db -class CeleryTask(db.Model): +class CeleryTask(db.Model): # type: ignore[name-defined] """Task result/status.""" __tablename__ = "celery_taskmeta" @@ -29,7 +29,7 @@ class CeleryTask(db.Model): queue = db.Column(db.String(155), nullable=True) -class CeleryTaskSet(db.Model): +class CeleryTaskSet(db.Model): # type: ignore[name-defined] """TaskSet result.""" __tablename__ = "celery_tasksetmeta" diff --git a/api/models/tool.py b/api/models/tool.py deleted file mode 100644 index a81bb65174a724..00000000000000 --- a/api/models/tool.py +++ /dev/null @@ -1,47 +0,0 @@ -import json -from enum import Enum - -from extensions.ext_database import db - -from .types import StringUUID - - -class ToolProviderName(Enum): - SERPAPI = "serpapi" - - @staticmethod - def value_of(value): - for member in ToolProviderName: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class ToolProvider(db.Model): - __tablename__ = "tool_providers" - __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_provider_pkey"), - db.UniqueConstraint("tenant_id", "tool_name", name="unique_tool_provider_tool_name"), - ) - - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - tool_name = db.Column(db.String(40), nullable=False) - encrypted_credentials = db.Column(db.Text, nullable=True) - is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - - @property - def credentials_is_set(self): - """ - Returns True if the encrypted_config is not None, indicating that the token is set. - """ - return self.encrypted_credentials is not None - - @property - def credentials(self): - """ - Returns the decrypted config. - """ - return json.loads(self.encrypted_credentials) if self.encrypted_credentials is not None else None diff --git a/api/models/tools.py b/api/models/tools.py index 4040339e026474..13a112ee83b513 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,20 +1,20 @@ import json -from typing import Optional +from typing import Any, Optional import sqlalchemy as sa -from sqlalchemy import ForeignKey +from sqlalchemy import ForeignKey, func from sqlalchemy.orm import Mapped, mapped_column from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration -from extensions.ext_database import db +from .engine import db from .model import Account, App, Tenant from .types import StringUUID -class BuiltinToolProvider(db.Model): +class BuiltinToolProvider(db.Model): # type: ignore[name-defined] """ This table stores the tool provider information for built-in tools for each tenant. """ @@ -36,15 +36,15 @@ class BuiltinToolProvider(db.Model): provider = db.Column(db.String(40), nullable=False) # credential of the tool provider encrypted_credentials = db.Column(db.Text, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def credentials(self) -> dict: - return json.loads(self.encrypted_credentials) + return dict(json.loads(self.encrypted_credentials)) -class PublishedAppTool(db.Model): +class PublishedAppTool(db.Model): # type: ignore[name-defined] """ The table stores the apps published as a tool for each person. """ @@ -74,19 +74,19 @@ class PublishedAppTool(db.Model): tool_name = db.Column(db.String(40), nullable=False) # author author = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def description_i18n(self) -> I18nObject: return I18nObject(**json.loads(self.description)) @property - def app(self) -> App: + def app(self): return db.session.query(App).filter(App.id == self.app_id).first() -class ApiToolProvider(db.Model): +class ApiToolProvider(db.Model): # type: ignore[name-defined] """ The table stores the api providers. """ @@ -120,8 +120,8 @@ class ApiToolProvider(db.Model): # custom_disclaimer custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def schema_type(self) -> ApiProviderSchemaType: @@ -133,7 +133,7 @@ def tools(self) -> list[ApiToolBundle]: @property def credentials(self) -> dict: - return json.loads(self.credentials_str) + return dict(json.loads(self.credentials_str)) @property def user(self) -> Account | None: @@ -144,7 +144,7 @@ def tenant(self) -> Tenant | None: return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() -class ToolLabelBinding(db.Model): +class ToolLabelBinding(db.Model): # type: ignore[name-defined] """ The table stores the labels for tools. """ @@ -164,7 +164,7 @@ class ToolLabelBinding(db.Model): label_name = db.Column(db.String(40), nullable=False) -class WorkflowToolProvider(db.Model): +class WorkflowToolProvider(db.Model): # type: ignore[name-defined] """ The table stores the workflow providers. """ @@ -198,12 +198,8 @@ class WorkflowToolProvider(db.Model): # privacy policy privacy_policy = db.Column(db.String(255), nullable=True, server_default="") - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - - @property - def schema_type(self) -> ApiProviderSchemaType: - return ApiProviderSchemaType.value_of(self.schema_type_str) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def user(self) -> Account | None: @@ -222,7 +218,7 @@ def app(self) -> App | None: return db.session.query(App).filter(App.id == self.app_id).first() -class ToolModelInvoke(db.Model): +class ToolModelInvoke(db.Model): # type: ignore[name-defined] """ store the invoke logs from tool invoke """ @@ -255,11 +251,11 @@ class ToolModelInvoke(db.Model): provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class ToolConversationVariables(db.Model): +class ToolConversationVariables(db.Model): # type: ignore[name-defined] """ store the conversation variables from tool invoke """ @@ -282,15 +278,15 @@ class ToolConversationVariables(db.Model): # variables pool variables_str = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property - def variables(self) -> dict: + def variables(self) -> Any: return json.loads(self.variables_str) -class ToolFile(db.Model): +class ToolFile(db.Model): # type: ignore[name-defined] __tablename__ = "tool_files" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tool_file_pkey"), diff --git a/api/models/web.py b/api/models/web.py index bc088c185d5a8b..864428fe0931b6 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,10 +1,12 @@ -from extensions.ext_database import db +from sqlalchemy import func +from sqlalchemy.orm import Mapped, mapped_column +from .engine import db from .model import Message from .types import StringUUID -class SavedMessage(db.Model): +class SavedMessage(db.Model): # type: ignore[name-defined] __tablename__ = "saved_messages" __table_args__ = ( db.PrimaryKeyConstraint("id", name="saved_message_pkey"), @@ -16,14 +18,14 @@ class SavedMessage(db.Model): message_id = db.Column(StringUUID, nullable=False) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def message(self): return db.session.query(Message).filter(Message.id == self.message_id).first() -class PinnedConversation(db.Model): +class PinnedConversation(db.Model): # type: ignore[name-defined] __tablename__ = "pinned_conversations" __table_args__ = ( db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), @@ -32,7 +34,7 @@ class PinnedConversation(db.Model): id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) - conversation_id = db.Column(StringUUID, nullable=False) + conversation_id: Mapped[str] = mapped_column(StringUUID) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/workflow.py b/api/models/workflow.py index c0e70889a88875..880e044d073a67 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence from datetime import UTC, datetime from enum import Enum, StrEnum -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import sqlalchemy as sa from sqlalchemy import func @@ -12,14 +12,17 @@ from constants import HIDDEN_VALUE from core.helper import encrypter from core.variables import SecretVariable, Variable -from extensions.ext_database import db from factories import variable_factory from libs import helper from models.enums import CreatedByRole from .account import Account +from .engine import db from .types import StringUUID +if TYPE_CHECKING: + from models.model import AppMode, Message + class WorkflowType(Enum): """ @@ -56,7 +59,7 @@ def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT -class Workflow(db.Model): +class Workflow(db.Model): # type: ignore[name-defined] """ Workflow, for `Workflow App` and `Chat App workflow mode`. @@ -103,12 +106,13 @@ class Workflow(db.Model): graph: Mapped[str] = mapped_column(sa.Text) _features: Mapped[str] = mapped_column("features", sa.TEXT) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") - ) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, default=datetime.now(tz=UTC), server_onupdate=func.current_timestamp() + db.DateTime, + nullable=False, + default=datetime.now(UTC).replace(tzinfo=None), + server_onupdate=func.current_timestamp(), ) _environment_variables: Mapped[str] = mapped_column( "environment_variables", db.Text, nullable=False, server_default="{}" @@ -181,7 +185,7 @@ def features(self, value: str) -> None: self._features = value @property - def features_dict(self) -> Mapping[str, Any]: + def features_dict(self) -> dict[str, Any]: return json.loads(self.features) if self.features else {} def user_input_form(self, to_old_structure: bool = False) -> list: @@ -198,7 +202,7 @@ def user_input_form(self, to_old_structure: bool = False) -> list: return [] # get user_input_form from start node - variables = start_node.get("data", {}).get("variables", []) + variables: list[Any] = start_node.get("data", {}).get("variables", []) if to_old_structure: old_structure_variables = [] @@ -225,8 +229,10 @@ def tool_published(self) -> bool: from models.tools import WorkflowToolProvider return ( - db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.app_id == self.app_id).first() - is not None + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id) + .count() + > 0 ) @property @@ -325,6 +331,7 @@ class WorkflowRunStatus(StrEnum): SUCCEEDED = "succeeded" FAILED = "failed" STOPPED = "stopped" + PARTIAL_SUCCESSED = "partial-succeeded" @classmethod def value_of(cls, value: str) -> "WorkflowRunStatus": @@ -340,7 +347,7 @@ def value_of(cls, value: str) -> "WorkflowRunStatus": raise ValueError(f"invalid workflow run status value {value}") -class WorkflowRun(db.Model): +class WorkflowRun(db.Model): # type: ignore[name-defined] """ Workflow Run @@ -395,16 +402,17 @@ class WorkflowRun(db.Model): version = db.Column(db.String(255), nullable=False) graph = db.Column(db.Text) inputs = db.Column(db.Text) - status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped - outputs: Mapped[str] = mapped_column(sa.Text, default="{}") + status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded + outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") error = db.Column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) total_steps = db.Column(db.Integer, server_default=db.text("0")) created_by_role = db.Column(db.String(255), nullable=False) # account, end_user created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) finished_at = db.Column(db.DateTime) + exceptions_count = db.Column(db.Integer, server_default=db.text("0")) @property def created_by_account(self): @@ -464,6 +472,7 @@ def to_dict(self): "created_by": self.created_by, "created_at": self.created_at, "finished_at": self.finished_at, + "exceptions_count": self.exceptions_count, } @classmethod @@ -489,6 +498,7 @@ def from_dict(cls, data: dict) -> "WorkflowRun": created_by=data.get("created_by"), created_at=data.get("created_at"), finished_at=data.get("finished_at"), + exceptions_count=data.get("exceptions_count"), ) @@ -522,6 +532,8 @@ class WorkflowNodeExecutionStatus(Enum): RUNNING = "running" SUCCEEDED = "succeeded" FAILED = "failed" + EXCEPTION = "exception" + RETRY = "retry" @classmethod def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": @@ -537,7 +549,7 @@ def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": raise ValueError(f"invalid workflow node execution status value {value}") -class WorkflowNodeExecution(db.Model): +class WorkflowNodeExecution(db.Model): # type: ignore[name-defined] """ Workflow Node Execution @@ -628,7 +640,7 @@ class WorkflowNodeExecution(db.Model): error = db.Column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) execution_metadata = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) finished_at = db.Column(db.DateTime) @@ -703,7 +715,7 @@ def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": raise ValueError(f"invalid workflow app log created from value {value}") -class WorkflowAppLog(db.Model): +class WorkflowAppLog(db.Model): # type: ignore[name-defined] """ Workflow App execution log, excluding workflow debugging records. @@ -746,7 +758,7 @@ class WorkflowAppLog(db.Model): created_from = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def workflow_run(self): @@ -765,14 +777,14 @@ def created_by_end_user(self): return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None -class ConversationVariable(db.Model): +class ConversationVariable(db.Model): # type: ignore[name-defined] __tablename__ = "workflow_conversation_variables" id: Mapped[str] = db.Column(StringUUID, primary_key=True) conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True) app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True) data = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=func.current_timestamp()) updated_at = db.Column( db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) diff --git a/api/mypy.ini b/api/mypy.ini new file mode 100644 index 00000000000000..2c754f9fcd7c63 --- /dev/null +++ b/api/mypy.ini @@ -0,0 +1,10 @@ +[mypy] +warn_return_any = True +warn_unused_configs = True +check_untyped_defs = True +exclude = (?x)( + core/tools/provider/builtin/ + | core/model_runtime/model_providers/ + | tests/ + | migrations/ + ) \ No newline at end of file diff --git a/api/poetry.lock b/api/poetry.lock index 4c784e53cda4b1..b42eb22dd40b8a 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -483,22 +483,23 @@ vertex = ["google-auth (>=2,<3)"] [[package]] name = "anyio" -version = "4.6.2.post1" +version = "4.7.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false python-versions = ">=3.9" files = [ - {file = "anyio-4.6.2.post1-py3-none-any.whl", hash = "sha256:6d170c36fba3bdd840c73d3868c1e777e33676a69c3a72cf0a0d5d6d8009b61d"}, - {file = "anyio-4.6.2.post1.tar.gz", hash = "sha256:4c8bc31ccdb51c7f7bd251f51c609e038d63e34219b44aa86e47576389880b4c"}, + {file = "anyio-4.7.0-py3-none-any.whl", hash = "sha256:ea60c3723ab42ba6fff7e8ccb0488c898ec538ff4df1f1d5e642c3601d07e352"}, + {file = "anyio-4.7.0.tar.gz", hash = "sha256:2f834749c602966b7d456a7567cafcb309f96482b5081d14ac93ccd457f9dd48"}, ] [package.dependencies] idna = ">=2.8" sniffio = ">=1.1" +typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""} [package.extras] -doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21.0b1)"] +doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx_rtd_theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21)"] trio = ["trio (>=0.26.1)"] [[package]] @@ -869,13 +870,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.35.74" +version = "1.35.76" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.35.74-py3-none-any.whl", hash = "sha256:9ac9d33d84dd9f05b35085de081552342a2c9ae22e3c4ee105723c9e92c07bd9"}, - {file = "botocore-1.35.74.tar.gz", hash = "sha256:de5c4fa9a24cef3a758974857b5c5820a12fad345ebf33c052a5988e88f33634"}, + {file = "botocore-1.35.76-py3-none-any.whl", hash = "sha256:b4729d12d00267b3185628f83543917b6caae292385230ab464067621aa086af"}, + {file = "botocore-1.35.76.tar.gz", hash = "sha256:a75a42ae53395796b8300c5fefb2d65a8696dc40dc85e49cf3a769e0c0202b13"}, ] [package.dependencies] @@ -2076,17 +2077,17 @@ tokenizer = ["tiktoken"] [[package]] name = "dataclass-wizard" -version = "0.32.0" +version = "0.32.1" description = "Effortlessly marshal dataclasses to/from JSON. Leverage field properties with default values. Generate dataclass schemas from JSON input." optional = false python-versions = "*" files = [ - {file = "dataclass-wizard-0.32.0.tar.gz", hash = "sha256:b9411bc91a9a0e2224ca6a599923b8e472b170acc14580b2fa6fcf343f720fe5"}, - {file = "dataclass_wizard-0.32.0-py2.py3-none-any.whl", hash = "sha256:36091a8d5b49b43178bf076c948ff5b848d36e42ad20adf78ae2d0312e1c09e4"}, + {file = "dataclass-wizard-0.32.1.tar.gz", hash = "sha256:31d44224ff8acb28abb1bbf11afa5fa73d7eeec8cb2e8f9aed374135c154b617"}, + {file = "dataclass_wizard-0.32.1-py2.py3-none-any.whl", hash = "sha256:ce2c3bbfe48b197162ffeffd74c2d4ae4ca834acf4017b001a906a07109c943b"}, ] [package.extras] -dev = ["Sphinx (==7.4.7)", "Sphinx (==8.1.3)", "bump2version (==1.0.1)", "coverage (>=6.2)", "dacite (==1.8.1)", "dataclass-factory (==2.16)", "dataclass-wizard[toml]", "dataclasses-json (==0.6.7)", "flake8 (>=3)", "jsons (==1.6.3)", "mashumaro (==3.15)", "pip (>=21.3.1)", "pydantic (==2.10.2)", "pytest (==8.3.3)", "pytest-cov (==6.0.0)", "pytest-mock (>=3.6.1)", "python-dotenv (>=1,<2)", "pytimeparse (==1.1.8)", "sphinx-autodoc-typehints (==2.5.0)", "sphinx-copybutton (==0.5.2)", "sphinx-issues (==5.0.0)", "tomli (>=2,<3)", "tomli (>=2,<3)", "tomli-w (>=1,<2)", "tox (==4.23.2)", "twine (==5.1.1)", "watchdog[watchmedo] (==6.0.0)", "wheel (==0.45.1)"] +dev = ["Sphinx (==7.4.7)", "Sphinx (==8.1.3)", "bump2version (==1.0.1)", "coverage (>=6.2)", "dacite (==1.8.1)", "dataclass-factory (==2.16)", "dataclass-wizard[toml]", "dataclasses-json (==0.6.7)", "flake8 (>=3)", "jsons (==1.6.3)", "mashumaro (==3.15)", "pip (>=21.3.1)", "pydantic (==2.10.2)", "pytest (==8.3.4)", "pytest-cov (==6.0.0)", "pytest-mock (>=3.6.1)", "python-dotenv (>=1,<2)", "pytimeparse (==1.1.8)", "sphinx-autodoc-typehints (==2.5.0)", "sphinx-copybutton (==0.5.2)", "sphinx-issues (==5.0.0)", "tomli (>=2,<3)", "tomli (>=2,<3)", "tomli-w (>=1,<2)", "tox (==4.23.2)", "twine (==6.0.1)", "watchdog[watchmedo] (==6.0.0)", "wheel (==0.45.1)"] dotenv = ["python-dotenv (>=1,<2)"] timedelta = ["pytimeparse (>=1.1.7)"] toml = ["tomli (>=2,<3)", "tomli (>=2,<3)", "tomli-w (>=1,<2)"] @@ -2810,61 +2811,61 @@ fonttools = "*" [[package]] name = "fonttools" -version = "4.55.1" +version = "4.55.2" description = "Tools to manipulate font files" optional = false python-versions = ">=3.8" files = [ - {file = "fonttools-4.55.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:c17a6f9814f83772cd6d9c9009928e1afa4ab66210a31ced721556651075a9a0"}, - {file = "fonttools-4.55.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c4d14eecc814826a01db87a40af3407c892ba49996bc6e49961e386cd78b537c"}, - {file = "fonttools-4.55.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8589f9a15dc005592b94ecdc45b4dfae9bbe9e73542e89af5a5e776e745db83b"}, - {file = "fonttools-4.55.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfee95bd9395bcd9e6c78955387554335109b6a613db71ef006020b42f761c58"}, - {file = "fonttools-4.55.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:34fa2ecc0bf1923d1a51bf2216a006de2c3c0db02c6aa1470ea50b62b8619bd5"}, - {file = "fonttools-4.55.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9c1c48483148bfb1b9ad951133ceea957faa004f6cb475b67e7bc75d482b48f8"}, - {file = "fonttools-4.55.1-cp310-cp310-win32.whl", hash = "sha256:3e2fc388ca7d023b3c45badd71016fd4185f93e51a22cfe4bd65378af7fba759"}, - {file = "fonttools-4.55.1-cp310-cp310-win_amd64.whl", hash = "sha256:c4c36c71f69d2b3ee30394b0986e5f8b2c461e7eff48dde49b08a90ded9fcdbd"}, - {file = "fonttools-4.55.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5daab3a55d460577f45bb8f5a8eca01fa6cde43ef2ab943b527991f54b735c41"}, - {file = "fonttools-4.55.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:acf1e80cf96c2fbc79e46f669d8713a9a79faaebcc68e31a9fbe600cf8027992"}, - {file = "fonttools-4.55.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e88a0329f7f88a210f09f79c088fb64f8032fc3ab65e2390a40b7d3a11773026"}, - {file = "fonttools-4.55.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03105b42259a8a94b2f0cbf1bee45f7a8a34e7b26c946a8fb89b4967e44091a8"}, - {file = "fonttools-4.55.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9af3577e821649879ab5774ad0e060af34816af556c77c6d3820345d12bf415e"}, - {file = "fonttools-4.55.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:34bd5de3d0ad085359b79a96575cd6bd1bc2976320ef24a2aa152ead36dbf656"}, - {file = "fonttools-4.55.1-cp311-cp311-win32.whl", hash = "sha256:5da92c4b637f0155a41f345fa81143c8e17425260fcb21521cb2ad4d2cea2a95"}, - {file = "fonttools-4.55.1-cp311-cp311-win_amd64.whl", hash = "sha256:f70234253d15f844e6da1178f019a931f03181463ce0c7b19648b8c370527b07"}, - {file = "fonttools-4.55.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9c372e527d58ba64b695f15f8014e97bc8826cf64d3380fc89b4196edd3c0fa8"}, - {file = "fonttools-4.55.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:845a967d3bef3245ba81fb5582dc731f6c2c8417fa211f1068c56893504bc000"}, - {file = "fonttools-4.55.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f03be82bcd4ba4418adf10e6165743f824bb09d6594c2743d7f93ea50968805b"}, - {file = "fonttools-4.55.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c42e935cf146f826f556d977660dac88f2fa3fb2efa27d5636c0b89a60c16edf"}, - {file = "fonttools-4.55.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:96328bf91e05621d8e40d9f854af7a262cb0e8313e9b38e7f3a7f3c4c0caaa8b"}, - {file = "fonttools-4.55.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:291acec4d774e8cd2d8472d88c04643a77a3324a15247951bd6cfc969799b69e"}, - {file = "fonttools-4.55.1-cp312-cp312-win32.whl", hash = "sha256:6d768d6632809aec1c3fa8f195b173386d85602334701a6894a601a4d3c80368"}, - {file = "fonttools-4.55.1-cp312-cp312-win_amd64.whl", hash = "sha256:2a3850afdb0be1f79a1e95340a2059226511675c5b68098d4e49bfbeb48a8aab"}, - {file = "fonttools-4.55.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:0c88d427eaf8bd8497b9051f56e0f5f9fb96a311aa7c72cda35e03e18d59cd16"}, - {file = "fonttools-4.55.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f062c95a725a79fd908fe8407b6ad63e230e1c7d6dece2d5d6ecaf843d6927f6"}, - {file = "fonttools-4.55.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f298c5324c45cad073475146bf560f4110ce2dc2488ff12231a343ec489f77bc"}, - {file = "fonttools-4.55.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f06dbb71344ffd85a6cb7e27970a178952f0bdd8d319ed938e64ba4bcc41700"}, - {file = "fonttools-4.55.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4c46b3525166976f5855b1f039b02433dc51eb635fb54d6a111e0c5d6e6cdc4c"}, - {file = "fonttools-4.55.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:af46f52a21e086a2f89b87bd941c9f0f91e5f769e1a5eb3b37c912228814d3e5"}, - {file = "fonttools-4.55.1-cp313-cp313-win32.whl", hash = "sha256:cd7f36335c5725a3fd724cc667c10c3f5254e779bdc5bffefebb33cf5a75ecb1"}, - {file = "fonttools-4.55.1-cp313-cp313-win_amd64.whl", hash = "sha256:5d6394897710ccac7f74df48492d7f02b9586ff0588c66a2c218844e90534b22"}, - {file = "fonttools-4.55.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:52c4f4b383c56e1a4fe8dab1b63c2269ba9eab0695d2d8e033fa037e61e6f1ef"}, - {file = "fonttools-4.55.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d83892dafdbd62b56545c77b6bd4fa49eef6ec1d6b95e042ee2c930503d1831e"}, - {file = "fonttools-4.55.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:604d5bf16f811fcaaaec2dde139f7ce958462487565edcd54b6fadacb2942083"}, - {file = "fonttools-4.55.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3324b92feb5fd084923a8e89a8248afd5b9f9d81ab9517d7b07cc84403bd448"}, - {file = "fonttools-4.55.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:30f8b1ca9b919c04850678d026fc330c19acaa9e3b282fcacc09a5eb3c8d20c3"}, - {file = "fonttools-4.55.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:1835c98df2cf28c86a66d234895c87df7b9325fd079a8019c5053a389ff55d23"}, - {file = "fonttools-4.55.1-cp38-cp38-win32.whl", hash = "sha256:9f202703720a7cc0049f2ed1a2047925e264384eb5cc4d34f80200d7b17f1b6a"}, - {file = "fonttools-4.55.1-cp38-cp38-win_amd64.whl", hash = "sha256:2efff20aed0338d37c2ff58766bd67f4b9607ded61cf3d6baf1b3e25ea74e119"}, - {file = "fonttools-4.55.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:3032d9bf010c395e6eca2851666cafb1f4ecde85d420188555e928ad0144326e"}, - {file = "fonttools-4.55.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0794055588c30ffe25426048e8a7c0a5271942727cd61fc939391e37f4d580d5"}, - {file = "fonttools-4.55.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:13ba980e3ffd3206b8c63a365f90dc10eeec27da946d5ee5373c3a325a46d77c"}, - {file = "fonttools-4.55.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d7063babd7434a17a5e355e87de9b2306c85a5c19c7da0794be15c58aab0c39"}, - {file = "fonttools-4.55.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:ed84c15144015a58ef550dd6312884c9fb31a2dbc31a6467bcdafd63be7db476"}, - {file = "fonttools-4.55.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:e89419d88b0bbfdb55209e03a17afa2d20db3c2fa0d785543c9d0875668195d5"}, - {file = "fonttools-4.55.1-cp39-cp39-win32.whl", hash = "sha256:6eb781e401b93cda99356bc043ababead2a5096550984d8a4ecf3d5c9f859dc2"}, - {file = "fonttools-4.55.1-cp39-cp39-win_amd64.whl", hash = "sha256:db1031acf04523c5a51c3e1ae19c21a1c32bc5f820a477dd4659a02f9cb82002"}, - {file = "fonttools-4.55.1-py3-none-any.whl", hash = "sha256:4bcfb11f90f48b48c366dd638d773a52fca0d1b9e056dc01df766bf5835baa08"}, - {file = "fonttools-4.55.1.tar.gz", hash = "sha256:85bb2e985718b0df96afc659abfe194c171726054314b019dbbfed31581673c7"}, + {file = "fonttools-4.55.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:bef0f8603834643b1a6419d57902f18e7d950ec1a998fb70410635c598dc1a1e"}, + {file = "fonttools-4.55.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:944228b86d472612d3b48bcc83b31c25c2271e63fdc74539adfcfa7a96d487fb"}, + {file = "fonttools-4.55.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f0e55f5da594b85f269cfbecd2f6bd3e07d0abba68870bc3f34854de4fa4678"}, + {file = "fonttools-4.55.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b1a6e576db0c83c1b91925bf1363478c4bb968dbe8433147332fb5782ce6190"}, + {file = "fonttools-4.55.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:616368b15716781bc84df5c2191dc0540137aaef56c2771eb4b89b90933f347a"}, + {file = "fonttools-4.55.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7bbae4f3915225c2c37670da68e2bf18a21206060ad31dfb95fec91ef641caa7"}, + {file = "fonttools-4.55.2-cp310-cp310-win32.whl", hash = "sha256:8b02b10648d69d67a7eb055f4d3eedf4a85deb22fb7a19fbd9acbae7c7538199"}, + {file = "fonttools-4.55.2-cp310-cp310-win_amd64.whl", hash = "sha256:bbea0ab841113ac8e8edde067e099b7288ffc6ac2dded538b131c2c0595d5f77"}, + {file = "fonttools-4.55.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d34525e8141286fa976e14806639d32294bfb38d28bbdb5f6be9f46a1cd695a6"}, + {file = "fonttools-4.55.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0ecd1c2b1c2ec46bb73685bc5473c72e16ed0930ef79bc2919ccadc43a99fb16"}, + {file = "fonttools-4.55.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9008438ad59e5a8e403a62fbefef2b2ff377eb3857d90a3f2a5f4d674ff441b2"}, + {file = "fonttools-4.55.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:131591ac8d7a47043aaf29581aba755ae151d46e49d2bf49608601efd71e8b4d"}, + {file = "fonttools-4.55.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4c83381c3e3e3d9caa25527c4300543578341f21aae89e4fbbb4debdda8d82a2"}, + {file = "fonttools-4.55.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:42aca564b575252fd9954ed0d91d97a24de24289a16ce8ff74ed0bdf5ecebf11"}, + {file = "fonttools-4.55.2-cp311-cp311-win32.whl", hash = "sha256:c6457f650ebe15baa17fc06e256227f0a47f46f80f27ec5a0b00160de8dc2c13"}, + {file = "fonttools-4.55.2-cp311-cp311-win_amd64.whl", hash = "sha256:5cfa67414d7414442a5635ff634384101c54f53bb7b0e04aa6a61b013fcce194"}, + {file = "fonttools-4.55.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:18f082445b8fe5e91c53e6184f4c1c73f3f965c8bcc614c6cd6effd573ce6c1a"}, + {file = "fonttools-4.55.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:27c0f91adbbd706e8acd1db73e3e510118e62d0ffb651864567dccc5b2339f90"}, + {file = "fonttools-4.55.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d8ccce035320d63dba0c35f52499322f5531dbe85bba1514c7cea26297e4c54"}, + {file = "fonttools-4.55.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96e126df9615df214ec7f04bebcf60076297fbc10b75c777ce58b702d7708ffb"}, + {file = "fonttools-4.55.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:508ebb42956a7a931c4092dfa2d9b4ffd4f94cea09b8211199090d2bd082506b"}, + {file = "fonttools-4.55.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c1b9de46ef7b683d50400abf9f1578eaceee271ff51c36bf4b7366f2be29f498"}, + {file = "fonttools-4.55.2-cp312-cp312-win32.whl", hash = "sha256:2df61d9fc15199cc86dad29f64dd686874a3a52dda0c2d8597d21f509f95c332"}, + {file = "fonttools-4.55.2-cp312-cp312-win_amd64.whl", hash = "sha256:d337ec087da8216a828574aa0525d869df0a2ac217a2efc1890974ddd1fbc5b9"}, + {file = "fonttools-4.55.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:10aff204e2edee1d312fa595c06f201adf8d528a3b659cfb34cd47eceaaa6a26"}, + {file = "fonttools-4.55.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:09fe922a3eff181fd07dd724cdb441fb6b9fc355fd1c0f1aa79aca60faf1fbdd"}, + {file = "fonttools-4.55.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:487e1e8b524143a799bda0169c48b44a23a6027c1bb1957d5a172a7d3a1dd704"}, + {file = "fonttools-4.55.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b1726872e09268bbedb14dc02e58b7ea31ecdd1204c6073eda4911746b44797"}, + {file = "fonttools-4.55.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6fc88cfb58b0cd7b48718c3e61dd0d0a3ee8e2c86b973342967ce09fbf1db6d4"}, + {file = "fonttools-4.55.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e857fe1859901ad8c5cab32e0eebc920adb09f413d2d73b74b677cf47b28590c"}, + {file = "fonttools-4.55.2-cp313-cp313-win32.whl", hash = "sha256:81ccd2b3a420b8050c7d9db3be0555d71662973b3ef2a1d921a2880b58957db8"}, + {file = "fonttools-4.55.2-cp313-cp313-win_amd64.whl", hash = "sha256:d559eb1744c7dcfa90ae60cb1a4b3595e898e48f4198738c321468c01180cd83"}, + {file = "fonttools-4.55.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:6b5917ef79cac8300b88fd6113003fd01bbbbea2ea060a27b95d8f77cb4c65c2"}, + {file = "fonttools-4.55.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:663eba5615d6abaaf616432354eb7ce951d518e43404371bcc2b0694ef21e8d6"}, + {file = "fonttools-4.55.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:803d5cef5fc47f44f5084d154aa3d6f069bb1b60e32390c225f897fa19b0f939"}, + {file = "fonttools-4.55.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8bc5f100de0173cc39102c0399bd6c3bd544bbdf224957933f10ee442d43cddd"}, + {file = "fonttools-4.55.2-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:3d9bbc1e380fdaf04ad9eabd8e3e6a4301eaf3487940893e9fd98537ea2e283b"}, + {file = "fonttools-4.55.2-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:42a9afedff07b6f75aa0f39b5e49922ac764580ef3efce035ca30284b2ee65c8"}, + {file = "fonttools-4.55.2-cp38-cp38-win32.whl", hash = "sha256:f1c76f423f1a241df08f87614364dff6e0b7ce23c962c1b74bd995ec7c0dad13"}, + {file = "fonttools-4.55.2-cp38-cp38-win_amd64.whl", hash = "sha256:25062b6ca03464dd5179fc2040fb19e03391b7cc49b9cc4f879312e638605c5c"}, + {file = "fonttools-4.55.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d1100d8e665fe386a79cab59446992de881ea74d0d6c191bb988642692aa2421"}, + {file = "fonttools-4.55.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:dbdc251c5e472e5ae6bc816f9b82718b8e93ff7992e7331d6cf3562b96aa268e"}, + {file = "fonttools-4.55.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0bf24d2b02dbc9376d795a63062632ff73e3e9e60c0229373f500aed7e86dd7"}, + {file = "fonttools-4.55.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4ff250ed4ff05015dfd9cf2adf7570c7a383ca80f4d9732ac484a5ed0d8453c"}, + {file = "fonttools-4.55.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:44cf2a98aa661dbdeb8c03f5e405b074e2935196780bb729888639f5276067d9"}, + {file = "fonttools-4.55.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:22ef222740eb89d189bf0612eb98fbae592c61d7efeac51bfbc2a1592d469557"}, + {file = "fonttools-4.55.2-cp39-cp39-win32.whl", hash = "sha256:93f439ca27e55f585e7aaa04a74990acd983b5f2245e41d6b79f0a8b44e684d8"}, + {file = "fonttools-4.55.2-cp39-cp39-win_amd64.whl", hash = "sha256:627cf10d6f5af5bec6324c18a2670f134c29e1b7dce3fb62e8ef88baa6cba7a9"}, + {file = "fonttools-4.55.2-py3-none-any.whl", hash = "sha256:8e2d89fbe9b08d96e22c7a81ec04a4e8d8439c31223e2dc6f2f9fc8ff14bdf9f"}, + {file = "fonttools-4.55.2.tar.gz", hash = "sha256:45947e7b3f9673f91df125d375eb57b9a23f2a603f438a1aebf3171bffa7a205"}, ] [package.extras] @@ -3827,13 +3828,13 @@ setuptools = "*" [[package]] name = "gunicorn" -version = "22.0.0" +version = "23.0.0" description = "WSGI HTTP Server for UNIX" optional = false python-versions = ">=3.7" files = [ - {file = "gunicorn-22.0.0-py3-none-any.whl", hash = "sha256:350679f91b24062c86e386e198a15438d53a7a8207235a78ba1b53df4c4378d9"}, - {file = "gunicorn-22.0.0.tar.gz", hash = "sha256:4a0b436239ff76fb33f11c07a16482c521a7e09c1ce3cc293c2330afe01bec63"}, + {file = "gunicorn-23.0.0-py3-none-any.whl", hash = "sha256:ec400d38950de4dfd418cff8328b2c8faed0edb0d517d3394e457c317908ca4d"}, + {file = "gunicorn-23.0.0.tar.gz", hash = "sha256:f014447a0101dc57e294f6c18ca6b40227a4c90e9bdb586042628030cba004ec"}, ] [package.dependencies] @@ -3874,105 +3875,120 @@ hyperframe = ">=6.0,<7" [[package]] name = "hiredis" -version = "3.0.0" +version = "3.1.0" description = "Python wrapper for hiredis" optional = false python-versions = ">=3.8" files = [ - {file = "hiredis-3.0.0-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:4b182791c41c5eb1d9ed736f0ff81694b06937ca14b0d4dadde5dadba7ff6dae"}, - {file = "hiredis-3.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:13c275b483a052dd645eb2cb60d6380f1f5215e4c22d6207e17b86be6dd87ffa"}, - {file = "hiredis-3.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c1018cc7f12824506f165027eabb302735b49e63af73eb4d5450c66c88f47026"}, - {file = "hiredis-3.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83a29cc7b21b746cb6a480189e49f49b2072812c445e66a9e38d2004d496b81c"}, - {file = "hiredis-3.0.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e241fab6332e8fb5f14af00a4a9c6aefa22f19a336c069b7ddbf28ef8341e8d6"}, - {file = "hiredis-3.0.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1fb8de899f0145d6c4d5d4bd0ee88a78eb980a7ffabd51e9889251b8f58f1785"}, - {file = "hiredis-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b23291951959141173eec10f8573538e9349fa27f47a0c34323d1970bf891ee5"}, - {file = "hiredis-3.0.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e421ac9e4b5efc11705a0d5149e641d4defdc07077f748667f359e60dc904420"}, - {file = "hiredis-3.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:77c8006c12154c37691b24ff293c077300c22944018c3ff70094a33e10c1d795"}, - {file = "hiredis-3.0.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:41afc0d3c18b59eb50970479a9c0e5544fb4b95e3a79cf2fbaece6ddefb926fe"}, - {file = "hiredis-3.0.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:04ccae6dcd9647eae6025425ab64edb4d79fde8b9e6e115ebfabc6830170e3b2"}, - {file = "hiredis-3.0.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:fe91d62b0594db5ea7d23fc2192182b1a7b6973f628a9b8b2e0a42a2be721ac6"}, - {file = "hiredis-3.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:99516d99316062824a24d145d694f5b0d030c80da693ea6f8c4ecf71a251d8bb"}, - {file = "hiredis-3.0.0-cp310-cp310-win32.whl", hash = "sha256:562eaf820de045eb487afaa37e6293fe7eceb5b25e158b5a1974b7e40bf04543"}, - {file = "hiredis-3.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:a1c81c89ed765198da27412aa21478f30d54ef69bf5e4480089d9c3f77b8f882"}, - {file = "hiredis-3.0.0-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:4664dedcd5933364756d7251a7ea86d60246ccf73a2e00912872dacbfcef8978"}, - {file = "hiredis-3.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:47de0bbccf4c8a9f99d82d225f7672b9dd690d8fd872007b933ef51a302c9fa6"}, - {file = "hiredis-3.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e43679eca508ba8240d016d8cca9d27342d70184773c15bea78a23c87a1922f1"}, - {file = "hiredis-3.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:13c345e7278c210317e77e1934b27b61394fee0dec2e8bd47e71570900f75823"}, - {file = "hiredis-3.0.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:00018f22f38530768b73ea86c11f47e8d4df65facd4e562bd78773bd1baef35e"}, - {file = "hiredis-3.0.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ea3a86405baa8eb0d3639ced6926ad03e07113de54cb00fd7510cb0db76a89d"}, - {file = "hiredis-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c073848d2b1d5561f3903879ccf4e1a70c9b1e7566c7bdcc98d082fa3e7f0a1d"}, - {file = "hiredis-3.0.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a8dffb5f5b3415a4669d25de48b617fd9d44b0bccfc4c2ab24b06406ecc9ecb"}, - {file = "hiredis-3.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:22c17c96143c2a62dfd61b13803bc5de2ac526b8768d2141c018b965d0333b66"}, - {file = "hiredis-3.0.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:c3ece960008dab66c6b8bb3a1350764677ee7c74ccd6270aaf1b1caf9ccebb46"}, - {file = "hiredis-3.0.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f75999ae00a920f7dce6ecae76fa5e8674a3110e5a75f12c7a2c75ae1af53396"}, - {file = "hiredis-3.0.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:e069967cbd5e1900aafc4b5943888f6d34937fc59bf8918a1a546cb729b4b1e4"}, - {file = "hiredis-3.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0aacc0a78e1d94d843a6d191f224a35893e6bdfeb77a4a89264155015c65f126"}, - {file = "hiredis-3.0.0-cp311-cp311-win32.whl", hash = "sha256:719c32147ba29528cb451f037bf837dcdda4ff3ddb6cdb12c4216b0973174718"}, - {file = "hiredis-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:bdc144d56333c52c853c31b4e2e52cfbdb22d3da4374c00f5f3d67c42158970f"}, - {file = "hiredis-3.0.0-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:484025d2eb8f6348f7876fc5a2ee742f568915039fcb31b478fd5c242bb0fe3a"}, - {file = "hiredis-3.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:fcdb552ffd97151dab8e7bc3ab556dfa1512556b48a367db94b5c20253a35ee1"}, - {file = "hiredis-3.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0bb6f9fd92f147ba11d338ef5c68af4fd2908739c09e51f186e1d90958c68cc1"}, - {file = "hiredis-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fa86bf9a0ed339ec9e8a9a9d0ae4dccd8671625c83f9f9f2640729b15e07fbfd"}, - {file = "hiredis-3.0.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e194a0d5df9456995d8f510eab9f529213e7326af6b94770abf8f8b7952ddcaa"}, - {file = "hiredis-3.0.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c8a1df39d74ec507d79c7a82c8063eee60bf80537cdeee652f576059b9cdd15c"}, - {file = "hiredis-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f91456507427ba36fd81b2ca11053a8e112c775325acc74e993201ea912d63e9"}, - {file = "hiredis-3.0.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9862db92ef67a8a02e0d5370f07d380e14577ecb281b79720e0d7a89aedb9ee5"}, - {file = "hiredis-3.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d10fcd9e0eeab835f492832b2a6edb5940e2f1230155f33006a8dfd3bd2c94e4"}, - {file = "hiredis-3.0.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:48727d7d405d03977d01885f317328dc21d639096308de126c2c4e9950cbd3c9"}, - {file = "hiredis-3.0.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:8e0bb6102ebe2efecf8a3292c6660a0e6fac98176af6de67f020bea1c2343717"}, - {file = "hiredis-3.0.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:df274e3abb4df40f4c7274dd3e587dfbb25691826c948bc98d5fead019dfb001"}, - {file = "hiredis-3.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:034925b5fb514f7b11aac38cd55b3fd7e9d3af23bd6497f3f20aa5b8ba58e232"}, - {file = "hiredis-3.0.0-cp312-cp312-win32.whl", hash = "sha256:120f2dda469b28d12ccff7c2230225162e174657b49cf4cd119db525414ae281"}, - {file = "hiredis-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:e584fe5f4e6681d8762982be055f1534e0170f6308a7a90f58d737bab12ff6a8"}, - {file = "hiredis-3.0.0-cp38-cp38-macosx_10_15_universal2.whl", hash = "sha256:122171ff47d96ed8dd4bba6c0e41d8afaba3e8194949f7720431a62aa29d8895"}, - {file = "hiredis-3.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:ba9fc605ac558f0de67463fb588722878641e6fa1dabcda979e8e69ff581d0bd"}, - {file = "hiredis-3.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a631e2990b8be23178f655cae8ac6c7422af478c420dd54e25f2e26c29e766f1"}, - {file = "hiredis-3.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63482db3fadebadc1d01ad33afa6045ebe2ea528eb77ccaabd33ee7d9c2bad48"}, - {file = "hiredis-3.0.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1f669212c390eebfbe03c4e20181f5970b82c5d0a0ad1df1785f7ffbe7d61150"}, - {file = "hiredis-3.0.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a6a49ef161739f8018c69b371528bdb47d7342edfdee9ddc75a4d8caddf45a6e"}, - {file = "hiredis-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98a152052b8878e5e43a2e3a14075218adafc759547c98668a21e9485882696c"}, - {file = "hiredis-3.0.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50a196af0ce657fcde9bf8a0bbe1032e22c64d8fcec2bc926a35e7ff68b3a166"}, - {file = "hiredis-3.0.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:f2f312eef8aafc2255e3585dcf94d5da116c43ef837db91db9ecdc1bc930072d"}, - {file = "hiredis-3.0.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:6ca41fa40fa019cde42c21add74aadd775e71458051a15a352eabeb12eb4d084"}, - {file = "hiredis-3.0.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:6eecb343c70629f5af55a8b3e53264e44fa04e155ef7989de13668a0cb102a90"}, - {file = "hiredis-3.0.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:c3fdad75e7837a475900a1d3a5cc09aa024293c3b0605155da2d42f41bc0e482"}, - {file = "hiredis-3.0.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:8854969e7480e8d61ed7549eb232d95082a743e94138d98d7222ba4e9f7ecacd"}, - {file = "hiredis-3.0.0-cp38-cp38-win32.whl", hash = "sha256:f114a6c86edbf17554672b050cce72abf489fe58d583c7921904d5f1c9691605"}, - {file = "hiredis-3.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:7d99b91e42217d7b4b63354b15b41ce960e27d216783e04c4a350224d55842a4"}, - {file = "hiredis-3.0.0-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:4c6efcbb5687cf8d2aedcc2c3ed4ac6feae90b8547427d417111194873b66b06"}, - {file = "hiredis-3.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:5b5cff42a522a0d81c2ae7eae5e56d0ee7365e0c4ad50c4de467d8957aff4414"}, - {file = "hiredis-3.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:82f794d564f4bc76b80c50b03267fe5d6589e93f08e66b7a2f674faa2fa76ebc"}, - {file = "hiredis-3.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7a4c1791d7aa7e192f60fe028ae409f18ccdd540f8b1e6aeb0df7816c77e4a4"}, - {file = "hiredis-3.0.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a2537b2cd98192323fce4244c8edbf11f3cac548a9d633dbbb12b48702f379f4"}, - {file = "hiredis-3.0.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8fed69bbaa307040c62195a269f82fc3edf46b510a17abb6b30a15d7dab548df"}, - {file = "hiredis-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:869f6d5537d243080f44253491bb30aa1ec3c21754003b3bddeadedeb65842b0"}, - {file = "hiredis-3.0.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d435ae89073d7cd51e6b6bf78369c412216261c9c01662e7008ff00978153729"}, - {file = "hiredis-3.0.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:204b79b30a0e6be0dc2301a4d385bb61472809f09c49f400497f1cdd5a165c66"}, - {file = "hiredis-3.0.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:3ea635101b739c12effd189cc19b2671c268abb03013fd1f6321ca29df3ca625"}, - {file = "hiredis-3.0.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:f359175197fd833c8dd7a8c288f1516be45415bb5c939862ab60c2918e1e1943"}, - {file = "hiredis-3.0.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:ac6d929cb33dd12ad3424b75725975f0a54b5b12dbff95f2a2d660c510aa106d"}, - {file = "hiredis-3.0.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:100431e04d25a522ef2c3b94f294c4219c4de3bfc7d557b6253296145a144c11"}, - {file = "hiredis-3.0.0-cp39-cp39-win32.whl", hash = "sha256:e1a9c14ae9573d172dc050a6f63a644457df5d01ec4d35a6a0f097f812930f83"}, - {file = "hiredis-3.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:54a6dd7b478e6eb01ce15b3bb5bf771e108c6c148315bf194eb2ab776a3cac4d"}, - {file = "hiredis-3.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:50da7a9edf371441dfcc56288d790985ee9840d982750580710a9789b8f4a290"}, - {file = "hiredis-3.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9b285ef6bf1581310b0d5e8f6ce64f790a1c40e89c660e1320b35f7515433672"}, - {file = "hiredis-3.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0dcfa684966f25b335072115de2f920228a3c2caf79d4bfa2b30f6e4f674a948"}, - {file = "hiredis-3.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a41be8af1fd78ca97bc948d789a09b730d1e7587d07ca53af05758f31f4b985d"}, - {file = "hiredis-3.0.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:038756db735e417ab36ee6fd7725ce412385ed2bd0767e8179a4755ea11b804f"}, - {file = "hiredis-3.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:fcecbd39bd42cef905c0b51c9689c39d0cc8b88b1671e7f40d4fb213423aef3a"}, - {file = "hiredis-3.0.0-pp38-pypy38_pp73-macosx_10_15_x86_64.whl", hash = "sha256:a131377493a59fb0f5eaeb2afd49c6540cafcfba5b0b3752bed707be9e7c4eaf"}, - {file = "hiredis-3.0.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:3d22c53f0ec5c18ecb3d92aa9420563b1c5d657d53f01356114978107b00b860"}, - {file = "hiredis-3.0.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c8a91e9520fbc65a799943e5c970ffbcd67905744d8becf2e75f9f0a5e8414f0"}, - {file = "hiredis-3.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3dc8043959b50141df58ab4f398e8ae84c6f9e673a2c9407be65fc789138f4a6"}, - {file = "hiredis-3.0.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:51b99cfac514173d7b8abdfe10338193e8a0eccdfe1870b646009d2fb7cbe4b5"}, - {file = "hiredis-3.0.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:fa1fcad89d8a41d8dc10b1e54951ec1e161deabd84ed5a2c95c3c7213bdb3514"}, - {file = "hiredis-3.0.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:898636a06d9bf575d2c594129085ad6b713414038276a4bfc5db7646b8a5be78"}, - {file = "hiredis-3.0.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:466f836dbcf86de3f9692097a7a01533dc9926986022c6617dc364a402b265c5"}, - {file = "hiredis-3.0.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23142a8af92a13fc1e3f2ca1d940df3dcf2af1d176be41fe8d89e30a837a0b60"}, - {file = "hiredis-3.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:793c80a3d6b0b0e8196a2d5de37a08330125668c8012922685e17aa9108c33ac"}, - {file = "hiredis-3.0.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:467d28112c7faa29b7db743f40803d927c8591e9da02b6ce3d5fadc170a542a2"}, - {file = "hiredis-3.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:dc384874a719c767b50a30750f937af18842ee5e288afba95a5a3ed703b1515a"}, - {file = "hiredis-3.0.0.tar.gz", hash = "sha256:fed8581ae26345dea1f1e0d1a96e05041a727a45e7d8d459164583e23c6ac441"}, + {file = "hiredis-3.1.0-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:2892db9db21f0cf7cc298d09f85d3e1f6dc4c4c24463ab67f79bc7a006d51867"}, + {file = "hiredis-3.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:93cfa6cc25ee2ceb0be81dc61eca9995160b9e16bdb7cca4a00607d57e998918"}, + {file = "hiredis-3.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2af62070aa9433802cae7be7364d5e82f76462c6a2ae34e53008b637aaa9a156"}, + {file = "hiredis-3.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:072c162260ebb1d892683107da22d0d5da7a1414739eae4e185cac22fe89627f"}, + {file = "hiredis-3.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c6b232c43e89755ba332c2745ddab059c0bc1a0f01448a3a14d506f8448b1ce6"}, + {file = "hiredis-3.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eb5316c9a65c4dde80796aa245b76011bab64eb84461a77b0a61c1bf2970bcc9"}, + {file = "hiredis-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e812a4e656bbd1c1c15c844b28259c49e26bb384837e44e8d2aa55412c91d2f7"}, + {file = "hiredis-3.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:93a6c9230e5a5565847130c0e1005c8d3aa5ca681feb0ed542c4651323d32feb"}, + {file = "hiredis-3.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a5f65e89ce50a94d9490d5442a649c6116f53f216c8c14eb37cf9637956482b2"}, + {file = "hiredis-3.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:9b2d6e33601c67c074c367fdccdd6033e642284e7a56adc130f18f724c378ca8"}, + {file = "hiredis-3.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:bad3b1e0c83849910f28c95953417106f539277035a4b515d1425f93947bc28f"}, + {file = "hiredis-3.1.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:9646de31f5994e6218311dcf216e971703dbf804c510fd3f84ddb9813c495824"}, + {file = "hiredis-3.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:59a9230f3aa38a33d09d8171400de202f575d7a38869e5ce2947829bca6fe359"}, + {file = "hiredis-3.1.0-cp310-cp310-win32.whl", hash = "sha256:0322d70f3328b97da14b6e98b18f0090a12ed8a8bf7ae20932e2eb9d1bb0aa2c"}, + {file = "hiredis-3.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:802474c18e878b3f9905e160a8b7df87d57885758083eda76c5978265acb41aa"}, + {file = "hiredis-3.1.0-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:c339ff4b4739b2a40da463763dd566129762f72926bca611ad9a457a9fe64abd"}, + {file = "hiredis-3.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:0ffa2552f704a45954627697a378fc2f559004e53055b82f00daf30bd4305330"}, + {file = "hiredis-3.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9acf7f0e7106f631cd618eb60ec9bbd6e43045addd5310f66ba1177209567e59"}, + {file = "hiredis-3.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea4f5ecf9dbea93c827486f59c606684c3496ea71c7ba9a8131932780696e61a"}, + {file = "hiredis-3.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:39efab176fca3d5111075f6ba56cd864f18db46d858289d39360c5672e0e5c3e"}, + {file = "hiredis-3.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1110eae007f30e70a058d743e369c24430327cd01fd97d99519d6794a58dd587"}, + {file = "hiredis-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b390f63191bcccbb6044d4c118acdf4fa55f38e5658ac4cfd5a33a6f0c07659"}, + {file = "hiredis-3.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:72a98ccc7b8ec9ce0100ecf59f45f05d2023606e8e3676b07a316d1c1c364072"}, + {file = "hiredis-3.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7c76e751fd1e2f221dec09cdc24040ee486886e943d5d7ffc256e8cf15c75e51"}, + {file = "hiredis-3.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:7d3880f213b6f14e9c69ce52beffd1748eecc8669698c4782761887273b6e1bd"}, + {file = "hiredis-3.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:87c2b3fe7e7c96eba376506a76e11514e07e848f737b254e0973e4b5c3a491e9"}, + {file = "hiredis-3.1.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:d3cfb4089e96f8f8ee9554da93148a9261aa6612ad2cc202c1a494c7b712e31f"}, + {file = "hiredis-3.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4f12018e5c5f866a1c3f7017cb2d88e5c6f9440df2281e48865a2b6c40f247f4"}, + {file = "hiredis-3.1.0-cp311-cp311-win32.whl", hash = "sha256:107b66ce977bb2dff8f2239e68344360a75d05fed3d9fa0570ac4d3020ce2396"}, + {file = "hiredis-3.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:8f1240bde53d3d1676f0aba61b3661560dc9a681cae24d9de33e650864029aa4"}, + {file = "hiredis-3.1.0-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:f7c7f89e0bc4246115754e2eda078a111282f6d6ecc6fb458557b724fe6f2aac"}, + {file = "hiredis-3.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:3dbf9163296fa45fbddcfc4c5900f10e9ddadda37117dbfb641e327e536b53e0"}, + {file = "hiredis-3.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:af46a4be0e82df470f68f35316fa16cd1e134d1c5092fc1082e1aad64cce716d"}, + {file = "hiredis-3.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc63d698c43aea500a84d8b083f830c03808b6cf3933ae4d35a27f0a3d881652"}, + {file = "hiredis-3.1.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:676b3d88674134bfaaf70dac181d1790b0f33b3187bfb9da9221e17e0e624f83"}, + {file = "hiredis-3.1.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aed10d9df1e2fb0011db2713ac64497462e9c2c0208b648c97569da772b959ca"}, + {file = "hiredis-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b5bd8adfe8742e331a94cccd782bffea251fa70d9a709e71f4510f50794d700"}, + {file = "hiredis-3.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9fc4e35b4afb0af6da55495dd0742ad32ab88150428a6ecdbb3085cbd60714e8"}, + {file = "hiredis-3.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:89b83e76eb00ab0464e7b0752a3ffcb02626e742e9509bc141424a9c3202e8dc"}, + {file = "hiredis-3.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:98ebf08c907836b70a8f40e030df8ab6f174dc7f6fa765251d813e89f14069d8"}, + {file = "hiredis-3.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:6c840b9cec086328f2ee2cfee0038b5d6bbb514bac7b5e579da6e346eaac056c"}, + {file = "hiredis-3.1.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:c5c44e9fa6f4462d0330cb5f5d46fa652512fc86b41d4d1974d0356f263e9105"}, + {file = "hiredis-3.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e665b14ab50aa175cfa306fcb00fffd4e3ff02ceb36ca6a4df00b1246d6a73c4"}, + {file = "hiredis-3.1.0-cp312-cp312-win32.whl", hash = "sha256:bd33db977ac7af97e8d035ffadb163b00546be22e5f1297b2123f5f9bf0f8a21"}, + {file = "hiredis-3.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:37aed4aa9348600145e2d019c7be27855e503ecc4906c6976ff2f3b52e3d5d97"}, + {file = "hiredis-3.1.0-cp313-cp313-macosx_10_15_universal2.whl", hash = "sha256:b87cddd8107487863fed6994de51e5594a0be267b0b19e213694e99cdd614623"}, + {file = "hiredis-3.1.0-cp313-cp313-macosx_10_15_x86_64.whl", hash = "sha256:d302deff8cb63a7feffc1844e4dafc8076e566bbf10c5aaaf0f4fe791b8a6bd0"}, + {file = "hiredis-3.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4a018340c073cf88cb635b2bedff96619df2f666018c655e7911f46fa2c1c178"}, + {file = "hiredis-3.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1e8ba6414ac1ae536129e18c069f3eb497df5a74e136e3566471620a4fa5f95"}, + {file = "hiredis-3.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a86b9fef256c2beb162244791fdc025aa55f936d6358e86e2020e512fe2e4972"}, + {file = "hiredis-3.1.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7acdc68e29a446ad17aadaff19c981a36b3bd8c894c3520412c8a7ab1c3e0de7"}, + {file = "hiredis-3.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7e06baea05de57e1e7548064f505a6964e992674fe61b8f274afe2ac93b6371"}, + {file = "hiredis-3.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:35b5fc061c8a0dbfdb440053280504d6aaa8d9726bd4d1d0e1cfcbbdf0d60b73"}, + {file = "hiredis-3.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c89d2dcb271d24c44f02264233b75d5db8c58831190fa92456a90b87fa17b748"}, + {file = "hiredis-3.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:aa36688c10a08f626fddcf68c2b1b91b0e90b070c26e550a4151a877f5c2d431"}, + {file = "hiredis-3.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f3982a9c16c1c4bc05a00b65d01ffb8d80ea1a7b6b533be2f1a769d3e989d2c0"}, + {file = "hiredis-3.1.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:d1a6f889514ee2452300c9a06862fceedef22a2891f1c421a27b1ba52ef130b2"}, + {file = "hiredis-3.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8a45ff7915392a55d9386bb235ea1d1eb9960615f301979f02143fc20036b699"}, + {file = "hiredis-3.1.0-cp313-cp313-win32.whl", hash = "sha256:539e5bb725b62b76a5319a4e68fc7085f01349abc2316ef3df608ea0883c51d2"}, + {file = "hiredis-3.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:9020fd7e58f489fda6a928c31355add0e665fd6b87b21954e675cf9943eafa32"}, + {file = "hiredis-3.1.0-cp38-cp38-macosx_10_15_universal2.whl", hash = "sha256:b621a89fc29b3f4b01be6640ec81a6a94b5382bc78fecb876408d57a071e45aa"}, + {file = "hiredis-3.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:363e21fba55e1a26349dc9ca7da6b14332123879b6359bcee4a9acecb40ca33b"}, + {file = "hiredis-3.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c156156798729eadc9ab76ffee96c88b93cc1c3b493f4dd0a4341f53939194ee"}, + {file = "hiredis-3.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e38d8a325f9a6afac1b1c72d996d1add9e1b99696ce9410538ba5e9aa8fdba02"}, + {file = "hiredis-3.1.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3004ef7436feb7bfa61c0b36d422b8fb8c29aaa1a514c9405f0fdee5e9694dd3"}, + {file = "hiredis-3.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13f5b16f97d0bbd1c04ce367c49097d1214d60e11f9fee7ef2a9b54e0a6645c8"}, + {file = "hiredis-3.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:230dd0e77cb0f525f58a1306a7b4aaf078037fc5229110922332ca46f90821bb"}, + {file = "hiredis-3.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d968116caddd19d63120d1298e62b1bbc694db3360ed0d5df8c3a97edbc12552"}, + {file = "hiredis-3.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:511e36a6fa41d3efab3cd5cd70ac388ed825993b9e66fa3b0e47cf27a2f5ffee"}, + {file = "hiredis-3.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:c5cd20804e3cb0d31e7d899d8dd091f569c33fe40d4bade670a067ab7d31c2ac"}, + {file = "hiredis-3.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:09e89e7d34cfe5ca8f7a869fca827d1af0afe8aaddb26b38c01058730edb79ad"}, + {file = "hiredis-3.1.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:570cbf31413c77fe5e7c157f2943ca4400493ddd9cf2184731cfcafc753becd7"}, + {file = "hiredis-3.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:b9b4da8162cf289781732d6a5ba01d820c42c05943fcdb7de307d03639961db3"}, + {file = "hiredis-3.1.0-cp38-cp38-win32.whl", hash = "sha256:bc117a04bcb461d3bb1b2c5b417aee3442e1e8aa33ebc800481431f4c09fe0c5"}, + {file = "hiredis-3.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:34f3f5f0354db2d6797a6fb08d2c036a50af62a1d919d122c1c784304ef49347"}, + {file = "hiredis-3.1.0-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:a26fa888025badb5563f283cc19594c215a413e905729e59a5f7cf3f46d66c32"}, + {file = "hiredis-3.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:f50763cd819d4a52a47b5966d4bb47dee34b637c5fa6402509800eee6ecb61e6"}, + {file = "hiredis-3.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b6d1c9e1fce5e0a94072667ae2bf0142b89ebbb1917d3531184e060a43f3ee11"}, + {file = "hiredis-3.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e38d7a56b1a79ed0bbb9e6fe376d82e3f4dcc646ae47472f2c858e19a597c112"}, + {file = "hiredis-3.1.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4ef5ad8b91530e4d10a68562b0a380ea22705a60e88cecee086d7c63a38564ce"}, + {file = "hiredis-3.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cf3d2299b054e57a9f97ca08704c2843e44f29b57dc69b76a2592ecd212efe1a"}, + {file = "hiredis-3.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93811d60b0f73d0f049c86f4373a3833b4a38fce374ab151074d929553eb4304"}, + {file = "hiredis-3.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18e703ff860c1d83abbcf57012b309ead02b56b60e85150c6c3bfb37cbb16ebf"}, + {file = "hiredis-3.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:f9ea0678806c53d96758e74c6a898f9d506a2e3367a344757f768bef9e069366"}, + {file = "hiredis-3.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:cf6844035abf47d52a1c3f4257255af3bf3b0f14d559b08eaa45885418c6c55d"}, + {file = "hiredis-3.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:7acf35cfa7ec9e1e7559c04e7095628f7d06049b5f24dcb58c1a55ef6dc689f8"}, + {file = "hiredis-3.1.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:b885695dce7a39b1fd9a609ed9c4cf312e53df2ec028d5a78af7a891b5fbea4d"}, + {file = "hiredis-3.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1c22fa74ddd063396b19fe8445a1ae8b4190eff755d5750dda48e860a45b2ee7"}, + {file = "hiredis-3.1.0-cp39-cp39-win32.whl", hash = "sha256:0614e16339f1784df3bbd2800322e20b4127d3f3a3509f00a5562efddb2521aa"}, + {file = "hiredis-3.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:c2bc713ee73ab9de4a0d68b0ab0f29612342b63173714742437b977584adb2d8"}, + {file = "hiredis-3.1.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:07ab990d0835f36bf358dbb84db4541ac0a8f533128ec09af8f80a576eef2e88"}, + {file = "hiredis-3.1.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:5c54a88eb9d8ebc4e5eefaadbe2102a4f7499f9e413654172f40aefd25350959"}, + {file = "hiredis-3.1.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8095ef159896e5999a795b0f80e4d64281301a109e442a8d29cd750ca6bd8303"}, + {file = "hiredis-3.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f8ca13e2476ffd6d5be4763f5868133506ddcfa5ce54b4dac231ebdc19be6c6"}, + {file = "hiredis-3.1.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:34d25aa25c10f966d5415795ed271da84605044dbf436c054966cea5442451b3"}, + {file = "hiredis-3.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:4180dc5f646b426e5fa1212e1348c167ee2a864b3a70d56579163d64a847dd1e"}, + {file = "hiredis-3.1.0-pp38-pypy38_pp73-macosx_10_15_x86_64.whl", hash = "sha256:d92144e0cd6e6e841a6ad343e9d58631626eeb4ac96b0322649379b5d4527447"}, + {file = "hiredis-3.1.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:fcb91ba42903de637b94a1b64477f381f94ad82c0742c264f9245be76a7a3cbc"}, + {file = "hiredis-3.1.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ce71a797b5bc02c51da082428c00251ed6a7a67a03acbda5fbf9e8d028725f6"}, + {file = "hiredis-3.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e04c7feb9467e3170cd4d5bee381775783d81bbc45d6147c1c0ce3b50dc04f9"}, + {file = "hiredis-3.1.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a31806306a60f3565c04c964d6bee0e9d4a5120e1da589e41976b53972edf635"}, + {file = "hiredis-3.1.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:bc51f594c2c0863ded6501642dc96701ca8bbea9ced4fa3af0a1aeda8aa634cb"}, + {file = "hiredis-3.1.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:4663a319ab7d22c597b9421e5ea384fd583e044f2f1ca9a1b98d4fef8a0fea2f"}, + {file = "hiredis-3.1.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:8060fa256862b0c3de64a73ab45bc1ccf381caca464f2647af9075b200828948"}, + {file = "hiredis-3.1.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e9445b7f117a9c8c8ccad97cb44daa55ddccff3cbc9079984eac56d982ba01f"}, + {file = "hiredis-3.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:732cf1c5cf1324f7bf3b6086976fe62a2ca98f0bf6316f31063c2c67be8797bc"}, + {file = "hiredis-3.1.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2102a94063d878c40df92f55199637a74f535e3a0b79ceba4a00538853a21be3"}, + {file = "hiredis-3.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:d968dde69e3fe903bf9ef00667669dcf04a3e096e33aaf138775106ead138bc8"}, + {file = "hiredis-3.1.0.tar.gz", hash = "sha256:51d40ac3611091020d7dea6b05ed62cb152bff595fa4f931e7b6479d777acf7c"}, ] [[package]] @@ -4316,86 +4332,87 @@ i18n = ["Babel (>=2.7)"] [[package]] name = "jiter" -version = "0.8.0" +version = "0.8.2" description = "Fast iterable JSON parser." optional = false python-versions = ">=3.8" files = [ - {file = "jiter-0.8.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:dee4eeb293ffcd2c3b31ebab684dbf7f7b71fe198f8eddcdf3a042cc6e10205a"}, - {file = "jiter-0.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:aad1e6e9b01cf0304dcee14db03e92e0073287a6297caf5caf2e9dbfea16a924"}, - {file = "jiter-0.8.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:504099fb7acdbe763e10690d560a25d4aee03d918d6a063f3a761d8a09fb833f"}, - {file = "jiter-0.8.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2373487caad7fe39581f588ab5c9262fc1ade078d448626fec93f4ffba528858"}, - {file = "jiter-0.8.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c341ecc3f9bccde952898b0c97c24f75b84b56a7e2f8bbc7c8e38cab0875a027"}, - {file = "jiter-0.8.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0e48e7a336529b9419d299b70c358d4ebf99b8f4b847ed3f1000ec9f320e8c0c"}, - {file = "jiter-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5ee157a8afd2943be690db679f82fafb8d347a8342e8b9c34863de30c538d55"}, - {file = "jiter-0.8.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d7dceae3549b80087f913aad4acc2a7c1e0ab7cb983effd78bdc9c41cabdcf18"}, - {file = "jiter-0.8.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e29e9ecce53d396772590438214cac4ab89776f5e60bd30601f1050b34464019"}, - {file = "jiter-0.8.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fa1782f22d5f92c620153133f35a9a395d3f3823374bceddd3e7032e2fdfa0b1"}, - {file = "jiter-0.8.0-cp310-none-win32.whl", hash = "sha256:f754ef13b4e4f67a3bf59fe974ef4342523801c48bf422f720bd37a02a360584"}, - {file = "jiter-0.8.0-cp310-none-win_amd64.whl", hash = "sha256:796f750b65f5d605f5e7acaccc6b051675e60c41d7ac3eab40dbd7b5b81a290f"}, - {file = "jiter-0.8.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f6f4e645efd96b4690b9b6091dbd4e0fa2885ba5c57a0305c1916b75b4f30ff6"}, - {file = "jiter-0.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f61cf6d93c1ade9b8245c9f14b7900feadb0b7899dbe4aa8de268b705647df81"}, - {file = "jiter-0.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0396bc5cb1309c6dab085e70bb3913cdd92218315e47b44afe9eace68ee8adaa"}, - {file = "jiter-0.8.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:62d0e42ec5dc772bd8554a304358220be5d97d721c4648b23f3a9c01ccc2cb26"}, - {file = "jiter-0.8.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ec4b711989860705733fc59fb8c41b2def97041cea656b37cf6c8ea8dee1c3f4"}, - {file = "jiter-0.8.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:859cc35bf304ab066d88f10a44a3251a9cd057fb11ec23e00be22206db878f4f"}, - {file = "jiter-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5000195921aa293b39b9b5bc959d7fa658e7f18f938c0e52732da8e3cc70a278"}, - {file = "jiter-0.8.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:36050284c0abde57aba34964d3920f3d6228211b65df7187059bb7c7f143759a"}, - {file = "jiter-0.8.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a88f608e050cfe45c48d771e86ecdbf5258314c883c986d4217cc79e1fb5f689"}, - {file = "jiter-0.8.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:646cf4237665b2e13b4159d8f26d53f59bc9f2e6e135e3a508a2e5dd26d978c6"}, - {file = "jiter-0.8.0-cp311-none-win32.whl", hash = "sha256:21fe5b8345db1b3023052b2ade9bb4d369417827242892051244af8fae8ba231"}, - {file = "jiter-0.8.0-cp311-none-win_amd64.whl", hash = "sha256:30c2161c5493acf6b6c3c909973fb64ae863747def01cc7574f3954e0a15042c"}, - {file = "jiter-0.8.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:d91a52d8f49ada2672a4b808a0c5c25d28f320a2c9ca690e30ebd561eb5a1002"}, - {file = "jiter-0.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c38cf25cf7862f61410b7a49684d34eb3b5bcbd7ddaf4773eea40e0bd43de706"}, - {file = "jiter-0.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6189beb5c4b3117624be6b2e84545cff7611f5855d02de2d06ff68e316182be"}, - {file = "jiter-0.8.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e13fa849c0e30643554add089983caa82f027d69fad8f50acadcb21c462244ab"}, - {file = "jiter-0.8.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d7765ca159d0a58e8e0f8ca972cd6d26a33bc97b4480d0d2309856763807cd28"}, - {file = "jiter-0.8.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1b0befe7c6e9fc867d5bed21bab0131dfe27d1fa5cd52ba2bced67da33730b7d"}, - {file = "jiter-0.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7d6363d4c6f1052b1d8b494eb9a72667c3ef5f80ebacfe18712728e85327000"}, - {file = "jiter-0.8.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a873e57009863eeac3e3969e4653f07031d6270d037d6224415074ac17e5505c"}, - {file = "jiter-0.8.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:2582912473c0d9940791479fe1bf2976a34f212eb8e0a82ee9e645ac275c5d16"}, - {file = "jiter-0.8.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:646163201af42f55393ee6e8f6136b8df488253a6533f4230a64242ecbfe6048"}, - {file = "jiter-0.8.0-cp312-none-win32.whl", hash = "sha256:96e75c9abfbf7387cba89a324d2356d86d8897ac58c956017d062ad510832dae"}, - {file = "jiter-0.8.0-cp312-none-win_amd64.whl", hash = "sha256:ed6074552b4a32e047b52dad5ab497223721efbd0e9efe68c67749f094a092f7"}, - {file = "jiter-0.8.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:dd5e351cb9b3e676ec3360a85ea96def515ad2b83c8ae3a251ce84985a2c9a6f"}, - {file = "jiter-0.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ba9f12b0f801ecd5ed0cec29041dc425d1050922b434314c592fc30d51022467"}, - {file = "jiter-0.8.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a7ba461c3681728d556392e8ae56fb44a550155a24905f01982317b367c21dd4"}, - {file = "jiter-0.8.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3a15ed47ab09576db560dbc5c2c5a64477535beb056cd7d997d5dd0f2798770e"}, - {file = "jiter-0.8.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cef55042816d0737142b0ec056c0356a5f681fb8d6aa8499b158e87098f4c6f8"}, - {file = "jiter-0.8.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:549f170215adeb5e866f10617c3d019d8eb4e6d4e3c6b724b3b8c056514a3487"}, - {file = "jiter-0.8.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f867edeb279d22020877640d2ea728de5817378c60a51be8af731a8a8f525306"}, - {file = "jiter-0.8.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:aef8845f463093799db4464cee2aa59d61aa8edcb3762aaa4aacbec3f478c929"}, - {file = "jiter-0.8.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:d0d6e22e4062c3d3c1bf3594baa2f67fc9dcdda8275abad99e468e0c6540bc54"}, - {file = "jiter-0.8.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:079e62e64696241ac3f408e337aaac09137ed760ccf2b72b1094b48745c13641"}, - {file = "jiter-0.8.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:74d2b56ed3da5760544df53b5f5c39782e68efb64dc3aa0bba4cc08815e6fae8"}, - {file = "jiter-0.8.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:798dafe108cba58a7bb0a50d4d5971f98bb7f3c974e1373e750de6eb21c1a329"}, - {file = "jiter-0.8.0-cp313-none-win32.whl", hash = "sha256:ca6d3064dfc743eb0d3d7539d89d4ba886957c717567adc72744341c1e3573c9"}, - {file = "jiter-0.8.0-cp313-none-win_amd64.whl", hash = "sha256:38caedda64fe1f04b06d7011fc15e86b3b837ed5088657bf778656551e3cd8f9"}, - {file = "jiter-0.8.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:bb5c8a0a8d081c338db22e5b8d53a89a121790569cbb85f7d3cfb1fe0fbe9836"}, - {file = "jiter-0.8.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:202dbe8970bfb166fab950eaab8f829c505730a0b33cc5e1cfb0a1c9dd56b2f9"}, - {file = "jiter-0.8.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9046812e5671fdcfb9ae02881fff1f6a14d484b7e8b3316179a372cdfa1e8026"}, - {file = "jiter-0.8.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e6ac56425023e52d65150918ae25480d0a1ce2a6bf5ea2097f66a2cc50f6d692"}, - {file = "jiter-0.8.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7dfcf97210c6eab9d2a1c6af15dd39e1d5154b96a7145d0a97fa1df865b7b834"}, - {file = "jiter-0.8.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d4e3c8444d418686f78c9a547b9b90031faf72a0a1a46bfec7fb31edbd889c0d"}, - {file = "jiter-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6507011a299b7f578559084256405a8428875540d8d13530e00b688e41b09493"}, - {file = "jiter-0.8.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0aae4738eafdd34f0f25c2d3668ce9e8fa0d7cb75a2efae543c9a69aebc37323"}, - {file = "jiter-0.8.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:7f5d782e790396b13f2a7b36bdcaa3736a33293bdda80a4bf1a3ce0cd5ef9f15"}, - {file = "jiter-0.8.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cc7f993bc2c4e03015445adbb16790c303282fce2e8d9dc3a3905b1d40e50564"}, - {file = "jiter-0.8.0-cp38-none-win32.whl", hash = "sha256:d4a8a6eda018a991fa58ef707dd51524055d11f5acb2f516d70b1be1d15ab39c"}, - {file = "jiter-0.8.0-cp38-none-win_amd64.whl", hash = "sha256:4cca948a3eda8ea24ed98acb0ee19dc755b6ad2e570ec85e1527d5167f91ff67"}, - {file = "jiter-0.8.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:ef89663678d8257063ce7c00d94638e05bd72f662c5e1eb0e07a172e6c1a9a9f"}, - {file = "jiter-0.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c402ddcba90b4cc71db3216e8330f4db36e0da2c78cf1d8a9c3ed8f272602a94"}, - {file = "jiter-0.8.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a6dfe795b7a173a9f8ba7421cdd92193d60c1c973bbc50dc3758a9ad0fa5eb6"}, - {file = "jiter-0.8.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8ec29a31b9abd6be39453a2c45da067138a3005d65d2c0507c530e0f1fdcd9a4"}, - {file = "jiter-0.8.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2a488f8c54bddc3ddefaf3bfd6de4a52c97fc265d77bc2dcc6ee540c17e8c342"}, - {file = "jiter-0.8.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aeb5561adf4d26ca0d01b5811b4d7b56a8986699a473d700757b4758ef787883"}, - {file = "jiter-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ab961858d7ad13132328517d29f121ae1b2d94502191d6bcf96bddcc8bb5d1c"}, - {file = "jiter-0.8.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a207e718d114d23acf0850a2174d290f42763d955030d9924ffa4227dbd0018f"}, - {file = "jiter-0.8.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:733bc9dc8ff718a0ae4695239e9268eb93e88b73b367dfac3ec227d8ce2f1e77"}, - {file = "jiter-0.8.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1ec27299e22d05e13a06e460bf7f75f26f9aaa0e0fb7d060f40e88df1d81faa"}, - {file = "jiter-0.8.0-cp39-none-win32.whl", hash = "sha256:e8dbfcb46553e6661d3fc1f33831598fcddf73d0f67834bce9fc3e9ebfe5c439"}, - {file = "jiter-0.8.0-cp39-none-win_amd64.whl", hash = "sha256:af2ce2487b3a93747e2cb5150081d4ae1e5874fce5924fc1a12e9e768e489ad8"}, - {file = "jiter-0.8.0.tar.gz", hash = "sha256:86fee98b569d4cc511ff2e3ec131354fafebd9348a487549c31ad371ae730310"}, + {file = "jiter-0.8.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:ca8577f6a413abe29b079bc30f907894d7eb07a865c4df69475e868d73e71c7b"}, + {file = "jiter-0.8.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b25bd626bde7fb51534190c7e3cb97cee89ee76b76d7585580e22f34f5e3f393"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5c826a221851a8dc028eb6d7d6429ba03184fa3c7e83ae01cd6d3bd1d4bd17d"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d35c864c2dff13dfd79fb070fc4fc6235d7b9b359efe340e1261deb21b9fcb66"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f557c55bc2b7676e74d39d19bcb8775ca295c7a028246175d6a8b431e70835e5"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:580ccf358539153db147e40751a0b41688a5ceb275e6f3e93d91c9467f42b2e3"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af102d3372e917cffce49b521e4c32c497515119dc7bd8a75665e90a718bbf08"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cadcc978f82397d515bb2683fc0d50103acff2a180552654bb92d6045dec2c49"}, + {file = "jiter-0.8.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ba5bdf56969cad2019d4e8ffd3f879b5fdc792624129741d3d83fc832fef8c7d"}, + {file = "jiter-0.8.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3b94a33a241bee9e34b8481cdcaa3d5c2116f575e0226e421bed3f7a6ea71cff"}, + {file = "jiter-0.8.2-cp310-cp310-win32.whl", hash = "sha256:6e5337bf454abddd91bd048ce0dca5134056fc99ca0205258766db35d0a2ea43"}, + {file = "jiter-0.8.2-cp310-cp310-win_amd64.whl", hash = "sha256:4a9220497ca0cb1fe94e3f334f65b9b5102a0b8147646118f020d8ce1de70105"}, + {file = "jiter-0.8.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:2dd61c5afc88a4fda7d8b2cf03ae5947c6ac7516d32b7a15bf4b49569a5c076b"}, + {file = "jiter-0.8.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a6c710d657c8d1d2adbbb5c0b0c6bfcec28fd35bd6b5f016395f9ac43e878a15"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9584de0cd306072635fe4b89742bf26feae858a0683b399ad0c2509011b9dc0"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5a90a923338531b7970abb063cfc087eebae6ef8ec8139762007188f6bc69a9f"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d21974d246ed0181558087cd9f76e84e8321091ebfb3a93d4c341479a736f099"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:32475a42b2ea7b344069dc1e81445cfc00b9d0e3ca837f0523072432332e9f74"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b9931fd36ee513c26b5bf08c940b0ac875de175341cbdd4fa3be109f0492586"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ce0820f4a3a59ddced7fce696d86a096d5cc48d32a4183483a17671a61edfddc"}, + {file = "jiter-0.8.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8ffc86ae5e3e6a93765d49d1ab47b6075a9c978a2b3b80f0f32628f39caa0c88"}, + {file = "jiter-0.8.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5127dc1abd809431172bc3fbe8168d6b90556a30bb10acd5ded41c3cfd6f43b6"}, + {file = "jiter-0.8.2-cp311-cp311-win32.whl", hash = "sha256:66227a2c7b575720c1871c8800d3a0122bb8ee94edb43a5685aa9aceb2782d44"}, + {file = "jiter-0.8.2-cp311-cp311-win_amd64.whl", hash = "sha256:cde031d8413842a1e7501e9129b8e676e62a657f8ec8166e18a70d94d4682855"}, + {file = "jiter-0.8.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:e6ec2be506e7d6f9527dae9ff4b7f54e68ea44a0ef6b098256ddf895218a2f8f"}, + {file = "jiter-0.8.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:76e324da7b5da060287c54f2fabd3db5f76468006c811831f051942bf68c9d44"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:180a8aea058f7535d1c84183c0362c710f4750bef66630c05f40c93c2b152a0f"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:025337859077b41548bdcbabe38698bcd93cfe10b06ff66617a48ff92c9aec60"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecff0dc14f409599bbcafa7e470c00b80f17abc14d1405d38ab02e4b42e55b57"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ffd9fee7d0775ebaba131f7ca2e2d83839a62ad65e8e02fe2bd8fc975cedeb9e"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14601dcac4889e0a1c75ccf6a0e4baf70dbc75041e51bcf8d0e9274519df6887"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:92249669925bc1c54fcd2ec73f70f2c1d6a817928480ee1c65af5f6b81cdf12d"}, + {file = "jiter-0.8.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e725edd0929fa79f8349ab4ec7f81c714df51dc4e991539a578e5018fa4a7152"}, + {file = "jiter-0.8.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bf55846c7b7a680eebaf9c3c48d630e1bf51bdf76c68a5f654b8524335b0ad29"}, + {file = "jiter-0.8.2-cp312-cp312-win32.whl", hash = "sha256:7efe4853ecd3d6110301665a5178b9856be7e2a9485f49d91aa4d737ad2ae49e"}, + {file = "jiter-0.8.2-cp312-cp312-win_amd64.whl", hash = "sha256:83c0efd80b29695058d0fd2fa8a556490dbce9804eac3e281f373bbc99045f6c"}, + {file = "jiter-0.8.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:ca1f08b8e43dc3bd0594c992fb1fd2f7ce87f7bf0d44358198d6da8034afdf84"}, + {file = "jiter-0.8.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5672a86d55416ccd214c778efccf3266b84f87b89063b582167d803246354be4"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58dc9bc9767a1101f4e5e22db1b652161a225874d66f0e5cb8e2c7d1c438b587"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:37b2998606d6dadbb5ccda959a33d6a5e853252d921fec1792fc902351bb4e2c"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4ab9a87f3784eb0e098f84a32670cfe4a79cb6512fd8f42ae3d0709f06405d18"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:79aec8172b9e3c6d05fd4b219d5de1ac616bd8da934107325a6c0d0e866a21b6"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:711e408732d4e9a0208008e5892c2966b485c783cd2d9a681f3eb147cf36c7ef"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:653cf462db4e8c41995e33d865965e79641ef45369d8a11f54cd30888b7e6ff1"}, + {file = "jiter-0.8.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:9c63eaef32b7bebac8ebebf4dabebdbc6769a09c127294db6babee38e9f405b9"}, + {file = "jiter-0.8.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:eb21aaa9a200d0a80dacc7a81038d2e476ffe473ffdd9c91eb745d623561de05"}, + {file = "jiter-0.8.2-cp313-cp313-win32.whl", hash = "sha256:789361ed945d8d42850f919342a8665d2dc79e7e44ca1c97cc786966a21f627a"}, + {file = "jiter-0.8.2-cp313-cp313-win_amd64.whl", hash = "sha256:ab7f43235d71e03b941c1630f4b6e3055d46b6cb8728a17663eaac9d8e83a865"}, + {file = "jiter-0.8.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b426f72cd77da3fec300ed3bc990895e2dd6b49e3bfe6c438592a3ba660e41ca"}, + {file = "jiter-0.8.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2dd880785088ff2ad21ffee205e58a8c1ddabc63612444ae41e5e4b321b39c0"}, + {file = "jiter-0.8.2-cp313-cp313t-win_amd64.whl", hash = "sha256:3ac9f578c46f22405ff7f8b1f5848fb753cc4b8377fbec8470a7dc3997ca7566"}, + {file = "jiter-0.8.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:9e1fa156ee9454642adb7e7234a383884452532bc9d53d5af2d18d98ada1d79c"}, + {file = "jiter-0.8.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0cf5dfa9956d96ff2efb0f8e9c7d055904012c952539a774305aaaf3abdf3d6c"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e52bf98c7e727dd44f7c4acb980cb988448faeafed8433c867888268899b298b"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a2ecaa3c23e7a7cf86d00eda3390c232f4d533cd9ddea4b04f5d0644faf642c5"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:08d4c92bf480e19fc3f2717c9ce2aa31dceaa9163839a311424b6862252c943e"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:99d9a1eded738299ba8e106c6779ce5c3893cffa0e32e4485d680588adae6db8"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d20be8b7f606df096e08b0b1b4a3c6f0515e8dac296881fe7461dfa0fb5ec817"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d33f94615fcaf872f7fd8cd98ac3b429e435c77619777e8a449d9d27e01134d1"}, + {file = "jiter-0.8.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:317b25e98a35ffec5c67efe56a4e9970852632c810d35b34ecdd70cc0e47b3b6"}, + {file = "jiter-0.8.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fc9043259ee430ecd71d178fccabd8c332a3bf1e81e50cae43cc2b28d19e4cb7"}, + {file = "jiter-0.8.2-cp38-cp38-win32.whl", hash = "sha256:fc5adda618205bd4678b146612ce44c3cbfdee9697951f2c0ffdef1f26d72b63"}, + {file = "jiter-0.8.2-cp38-cp38-win_amd64.whl", hash = "sha256:cd646c827b4f85ef4a78e4e58f4f5854fae0caf3db91b59f0d73731448a970c6"}, + {file = "jiter-0.8.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:e41e75344acef3fc59ba4765df29f107f309ca9e8eace5baacabd9217e52a5ee"}, + {file = "jiter-0.8.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f22b16b35d5c1df9dfd58843ab2cd25e6bf15191f5a236bed177afade507bfc"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7200b8f7619d36aa51c803fd52020a2dfbea36ffec1b5e22cab11fd34d95a6d"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:70bf4c43652cc294040dbb62256c83c8718370c8b93dd93d934b9a7bf6c4f53c"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f9d471356dc16f84ed48768b8ee79f29514295c7295cb41e1133ec0b2b8d637d"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:859e8eb3507894093d01929e12e267f83b1d5f6221099d3ec976f0c995cb6bd9"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaa58399c01db555346647a907b4ef6d4f584b123943be6ed5588c3f2359c9f4"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8f2d5ed877f089862f4c7aacf3a542627c1496f972a34d0474ce85ee7d939c27"}, + {file = "jiter-0.8.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:03c9df035d4f8d647f8c210ddc2ae0728387275340668fb30d2421e17d9a0841"}, + {file = "jiter-0.8.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8bd2a824d08d8977bb2794ea2682f898ad3d8837932e3a74937e93d62ecbb637"}, + {file = "jiter-0.8.2-cp39-cp39-win32.whl", hash = "sha256:ca29b6371ebc40e496995c94b988a101b9fbbed48a51190a4461fcb0a68b4a36"}, + {file = "jiter-0.8.2-cp39-cp39-win_amd64.whl", hash = "sha256:1c0dfbd1be3cbefc7510102370d86e35d1d53e5a93d48519688b1bf0f761160a"}, + {file = "jiter-0.8.2.tar.gz", hash = "sha256:cd73d3e740666d0e639f678adb176fad25c1bcbdae88d8d7b857e1783bb4212d"}, ] [[package]] @@ -4787,13 +4804,13 @@ files = [ [[package]] name = "loguru" -version = "0.7.2" +version = "0.7.3" description = "Python logging made (stupidly) simple" optional = false -python-versions = ">=3.5" +python-versions = "<4.0,>=3.5" files = [ - {file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"}, - {file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"}, + {file = "loguru-0.7.3-py3-none-any.whl", hash = "sha256:31a33c10c8e1e10422bfd431aeb5d351c7cf7fa671e3c4df004162264b28220c"}, + {file = "loguru-0.7.3.tar.gz", hash = "sha256:19480589e77d47b8d85b2c827ad95d49bf31b0dcde16593892eb51dd18706eb6"}, ] [package.dependencies] @@ -4801,7 +4818,7 @@ colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} [package.extras] -dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"] +dev = ["Sphinx (==8.1.3)", "build (==1.2.2)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.5.0)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.13.0)", "mypy (==v1.4.1)", "myst-parser (==4.0.0)", "pre-commit (==4.0.1)", "pytest (==6.1.2)", "pytest (==8.3.2)", "pytest-cov (==2.12.1)", "pytest-cov (==5.0.0)", "pytest-cov (==6.0.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.1.0)", "sphinx-rtd-theme (==3.0.2)", "tox (==3.27.1)", "tox (==4.23.2)", "twine (==6.0.1)"] [[package]] name = "lxml" @@ -5026,13 +5043,13 @@ urllib3 = ">=1.23" [[package]] name = "mako" -version = "1.3.7" +version = "1.3.8" description = "A super-fast templating language that borrows the best ideas from the existing templating languages." optional = false python-versions = ">=3.8" files = [ - {file = "Mako-1.3.7-py3-none-any.whl", hash = "sha256:d18f990ad57f800ce8e76cbfb0b74afe471c293517e9f5003ace6dad5aa72c36"}, - {file = "mako-1.3.7.tar.gz", hash = "sha256:20405b1232e0759f0e7d87b01f6bb94fce0761747f1cb876ecf90bd512d0b639"}, + {file = "Mako-1.3.8-py3-none-any.whl", hash = "sha256:42f48953c7eb91332040ff567eb7eea69b22e7a4affbc5ba8e845e8f730f6627"}, + {file = "mako-1.3.8.tar.gz", hash = "sha256:577b97e414580d3e088d47c2dbbe9594aa7a5146ed2875d4dfa9075af2dd3cc8"}, ] [package.dependencies] @@ -5626,6 +5643,58 @@ files = [ {file = "multitasking-0.0.11.tar.gz", hash = "sha256:4d6bc3cc65f9b2dca72fb5a787850a88dae8f620c2b36ae9b55248e51bcd6026"}, ] +[[package]] +name = "mypy" +version = "1.13.0" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a"}, + {file = "mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80"}, + {file = "mypy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b2353a44d2179846a096e25691d54d59904559f4232519d420d64da6828a3a7"}, + {file = "mypy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0730d1c6a2739d4511dc4253f8274cdd140c55c32dfb0a4cf8b7a43f40abfa6f"}, + {file = "mypy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:c5fc54dbb712ff5e5a0fca797e6e0aa25726c7e72c6a5850cfd2adbc1eb0a372"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:581665e6f3a8a9078f28d5502f4c334c0c8d802ef55ea0e7276a6e409bc0d82d"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ddb5b9bf82e05cc9a627e84707b528e5c7caaa1c55c69e175abb15a761cec2d"}, + {file = "mypy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20c7ee0bc0d5a9595c46f38beb04201f2620065a93755704e141fcac9f59db2b"}, + {file = "mypy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3790ded76f0b34bc9c8ba4def8f919dd6a46db0f5a6610fb994fe8efdd447f73"}, + {file = "mypy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51f869f4b6b538229c1d1bcc1dd7d119817206e2bc54e8e374b3dfa202defcca"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5c7051a3461ae84dfb5dd15eff5094640c61c5f22257c8b766794e6dd85e72d5"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39bb21c69a5d6342f4ce526e4584bc5c197fd20a60d14a8624d8743fffb9472e"}, + {file = "mypy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:164f28cb9d6367439031f4c81e84d3ccaa1e19232d9d05d37cb0bd880d3f93c2"}, + {file = "mypy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a4c1bfcdbce96ff5d96fc9b08e3831acb30dc44ab02671eca5953eadad07d6d0"}, + {file = "mypy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0affb3a79a256b4183ba09811e3577c5163ed06685e4d4b46429a271ba174d2"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a7b44178c9760ce1a43f544e595d35ed61ac2c3de306599fa59b38a6048e1aa7"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d5092efb8516d08440e36626f0153b5006d4088c1d663d88bf79625af3d1d62"}, + {file = "mypy-1.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de2904956dac40ced10931ac967ae63c5089bd498542194b436eb097a9f77bc8"}, + {file = "mypy-1.13.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7bfd8836970d33c2105562650656b6846149374dc8ed77d98424b40b09340ba7"}, + {file = "mypy-1.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:9f73dba9ec77acb86457a8fc04b5239822df0c14a082564737833d2963677dbc"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:100fac22ce82925f676a734af0db922ecfea991e1d7ec0ceb1e115ebe501301a"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bcb0bb7f42a978bb323a7c88f1081d1b5dee77ca86f4100735a6f541299d8fb"}, + {file = "mypy-1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bde31fc887c213e223bbfc34328070996061b0833b0a4cfec53745ed61f3519b"}, + {file = "mypy-1.13.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:07de989f89786f62b937851295ed62e51774722e5444a27cecca993fc3f9cd74"}, + {file = "mypy-1.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:4bde84334fbe19bad704b3f5b78c4abd35ff1026f8ba72b29de70dda0916beb6"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0246bcb1b5de7f08f2826451abd947bf656945209b140d16ed317f65a17dc7dc"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f5b7deae912cf8b77e990b9280f170381fdfbddf61b4ef80927edd813163732"}, + {file = "mypy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7029881ec6ffb8bc233a4fa364736789582c738217b133f1b55967115288a2bc"}, + {file = "mypy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3e38b980e5681f28f033f3be86b099a247b13c491f14bb8b1e1e134d23bb599d"}, + {file = "mypy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:a6789be98a2017c912ae6ccb77ea553bbaf13d27605d2ca20a76dfbced631b24"}, + {file = "mypy-1.13.0-py3-none-any.whl", hash = "sha256:9c250883f9fd81d212e0952c92dbfcc96fc237f4b7c92f56ac81fd48460b3e5a"}, + {file = "mypy-1.13.0.tar.gz", hash = "sha256:0291a61b6fbf3e6673e3405cfcc0e7650bebc7939659fdca2702958038bd835e"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +typing-extensions = ">=4.6.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +faster-cache = ["orjson"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -6048,6 +6117,27 @@ files = [ [package.dependencies] opencensus = ">=0.8.0,<1.0.0" +[[package]] +name = "opendal" +version = "0.45.12" +description = "Apache OpenDAL™ Python Binding" +optional = false +python-versions = ">=3.11" +files = [ + {file = "opendal-0.45.12-cp311-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:fd6551780194870867ed205135d5e7e2d411145d3cc4faa63830f54bbf48acdb"}, + {file = "opendal-0.45.12-cp311-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec6fb9dc5021c5a62785fdf81a2d6ab97b65b8ef86ccded119fe242a10655263"}, + {file = "opendal-0.45.12-cp311-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6803edda7c0936722ecc5c2cf01fd84dcb520f11e1643f285605451df6b7c20b"}, + {file = "opendal-0.45.12-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:389908b68845991f5bf4e75fbf1b415f14b02fab3201d01f5b3a9ae0030ee164"}, + {file = "opendal-0.45.12-cp311-abi3-win_amd64.whl", hash = "sha256:c87f488c547e17174d53f98da1c25135595bf93c29bab731d732f55d993534e0"}, + {file = "opendal-0.45.12.tar.gz", hash = "sha256:5b35a1abf6a30a6dc82e343a5c8403f245c89125cc037c2b89ed7803409c717c"}, +] + +[package.extras] +benchmark = ["boto3", "boto3-stubs[essential]", "gevent", "greenify", "greenlet", "pydantic"] +docs = ["pdoc"] +lint = ["ruff"] +test = ["pytest", "pytest-asyncio", "python-dotenv"] + [[package]] name = "openpyxl" version = "3.1.5" @@ -6499,6 +6589,21 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "pandas-stubs" +version = "2.2.3.241126" +description = "Type annotations for pandas" +optional = false +python-versions = ">=3.10" +files = [ + {file = "pandas_stubs-2.2.3.241126-py3-none-any.whl", hash = "sha256:74aa79c167af374fe97068acc90776c0ebec5266a6e5c69fe11e9c2cf51f2267"}, + {file = "pandas_stubs-2.2.3.241126.tar.gz", hash = "sha256:cf819383c6d9ae7d4dabf34cd47e1e45525bb2f312e6ad2939c2c204cb708acd"}, +] + +[package.dependencies] +numpy = ">=1.23.5" +types-pytz = ">=2022.1.1" + [[package]] name = "pathos" version = "0.3.3" @@ -6795,20 +6900,20 @@ dill = ["dill (>=0.3.9)"] [[package]] name = "primp" -version = "0.8.1" +version = "0.8.2" description = "HTTP client that can impersonate web browsers, mimicking their headers and `TLS/JA3/JA4/HTTP2` fingerprints" optional = false python-versions = ">=3.8" files = [ - {file = "primp-0.8.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:8294db817701ad76b6a186c16e22cc49d36fac5986647a83657ad4a58ddeee42"}, - {file = "primp-0.8.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:e8117531dcdb0dbcf9855fdbac73febdde5967ca0332a2c05b5961d2fbcfe749"}, - {file = "primp-0.8.1-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:993cc4284e8c5c858254748f078e872ba250c9339d64398dc000a8f9cffadda3"}, - {file = "primp-0.8.1-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:4a27ac642be5c616fc5f139a5ad391dcd0c5964ace56fe6cf31cbffb972a7480"}, - {file = "primp-0.8.1-cp38-abi3-manylinux_2_34_armv7l.whl", hash = "sha256:e8483b8d9eec9fc43d77bb448555466030f29cdd99d9375eb75155e9f832e5bd"}, - {file = "primp-0.8.1-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:92f5f8267216252cfb27f2149811e14682bb64f0c5d37f00d218d1592e02f0b9"}, - {file = "primp-0.8.1-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:98f7f3a9481c55c56e7eff9024f29e16379a87d5b0a1b683e145dd8fcbdcc46b"}, - {file = "primp-0.8.1-cp38-abi3-win_amd64.whl", hash = "sha256:6f0018a26be787431504e32548b296a278abbe85da43bcbaf2d4982ac3dcd332"}, - {file = "primp-0.8.1.tar.gz", hash = "sha256:ddf05754a7b70d59df8a014a8585e418f9c04e0b69065bab6633f4a9b92bad93"}, + {file = "primp-0.8.2-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:20c4988c6538dfcac804e804f286493696e53498d5705e745a36d9fe436c787c"}, + {file = "primp-0.8.2-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:dde74d6bf5534a60fd075e81b5828a6591753a647c5bfe69e664883e5c7a28bb"}, + {file = "primp-0.8.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f988d7e47d7f63b63f851885d51abd86ba3a2a1981d047466c1e63827753a168"}, + {file = "primp-0.8.2-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:965cf0c19986d074d4e20ce18f1b81e5c31818324718814af6317a291a3aba65"}, + {file = "primp-0.8.2-cp38-abi3-manylinux_2_34_armv7l.whl", hash = "sha256:afc56989ae09bed76105bf045e666ea2da5f32e2e93dfb967795a4da4fc777e5"}, + {file = "primp-0.8.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:64e8b9b216ee0f52d2885ac23303000339f798a59eb9b4b3b747dcbbf9187beb"}, + {file = "primp-0.8.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b65de6d8fe4c7ef9d5d508e2a9cee3da77455e3a44c9282bdebb2134c55087c9"}, + {file = "primp-0.8.2-cp38-abi3-win_amd64.whl", hash = "sha256:d686cf4ce21c318bafe2f0574aec9f7f9526d18a4b0c017f507bd007f323e519"}, + {file = "primp-0.8.2.tar.gz", hash = "sha256:572ecd34b77021a89a0574b66b07e1da100afd6ec490d3b519a6763fac6ae6c5"}, ] [package.extras] @@ -7407,13 +7512,13 @@ rsa = ["cryptography"] [[package]] name = "pyobvector" -version = "0.1.16" +version = "0.1.17" description = "A python SDK for OceanBase Vector Store, based on SQLAlchemy, compatible with Milvus API." optional = false python-versions = "<4.0,>=3.9" files = [ - {file = "pyobvector-0.1.16-py3-none-any.whl", hash = "sha256:d2ec2f974f0a32b65fa1558a39e7cb36d8e14b2192a7d603990f183c5cae79d7"}, - {file = "pyobvector-0.1.16.tar.gz", hash = "sha256:2d0fdd90d85cdfc8dc1d7b6950cd4fbb160a0696065c7d6ebdf70d09745896c5"}, + {file = "pyobvector-0.1.17-py3-none-any.whl", hash = "sha256:faf73d14ded736f21f2ce9d92d0964de9d477afeacfbf6d41db0b5b18495aadd"}, + {file = "pyobvector-0.1.17.tar.gz", hash = "sha256:bfc89f8de88806b63d64d7dfc15e5f9890243387ba53cc69247de52a46045d5a"}, ] [package.dependencies] @@ -7486,23 +7591,24 @@ image = ["Pillow (>=8.0.0)"] [[package]] name = "pypdfium2" -version = "4.17.0" +version = "4.30.0" description = "Python bindings to PDFium" optional = false python-versions = ">=3.6" files = [ - {file = "pypdfium2-4.17.0-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:e9ed42d5a5065ae41ae3ead3cd642e1f21b6039e69ccc204e260e218e91cd7e1"}, - {file = "pypdfium2-4.17.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:0a3b5a8eca53a1e68434969821b70bd2bc9ac2b70e58daf516c6ff0b6b5779e7"}, - {file = "pypdfium2-4.17.0-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:854e04b51205466ec415b86588fe5dc593e9ca3e8e15b5aa05978c5352bd57d2"}, - {file = "pypdfium2-4.17.0-py3-none-manylinux_2_17_armv7l.whl", hash = "sha256:9ff8707b28568e9585bdf9a96b7a8a9f91c0b5ad05af119b49381dad89983364"}, - {file = "pypdfium2-4.17.0-py3-none-manylinux_2_17_i686.whl", hash = "sha256:09ecbef6212993db0b5460cfd46d6b157a921ff45c97b0764e6fe8ea2e8cdebf"}, - {file = "pypdfium2-4.17.0-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:f680e469b79c71c3fb086d7ced8361fbd66f4cd7b0ad08ff888289fe6743ab32"}, - {file = "pypdfium2-4.17.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1ba7a7da48fbf0f1aaa903dac7d0e62186d6e8ae9a78b7b7b836d3f1b3d1be5d"}, - {file = "pypdfium2-4.17.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:451752170caf59d4b4572b527c2858dfff96eb1da35f2822c66cdce006dd4eae"}, - {file = "pypdfium2-4.17.0-py3-none-win32.whl", hash = "sha256:4930cfa793298214fa644c6986f6466e21f98eba3f338b4577614ebd8aa34af5"}, - {file = "pypdfium2-4.17.0-py3-none-win_amd64.whl", hash = "sha256:99de7f336e967dea4d324484f581fff55db1eb3c8e90baa845567dd9a3cc84f3"}, - {file = "pypdfium2-4.17.0-py3-none-win_arm64.whl", hash = "sha256:9381677b489c13d64ea4f8cbf6ebfc858216b052883e01e40fa993c2818a078e"}, - {file = "pypdfium2-4.17.0.tar.gz", hash = "sha256:2a2b3273c4614ee2004df60ace5f387645f843418ae29f379408ee11560241c0"}, + {file = "pypdfium2-4.30.0-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:b33ceded0b6ff5b2b93bc1fe0ad4b71aa6b7e7bd5875f1ca0cdfb6ba6ac01aab"}, + {file = "pypdfium2-4.30.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:4e55689f4b06e2d2406203e771f78789bd4f190731b5d57383d05cf611d829de"}, + {file = "pypdfium2-4.30.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e6e50f5ce7f65a40a33d7c9edc39f23140c57e37144c2d6d9e9262a2a854854"}, + {file = "pypdfium2-4.30.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3d0dd3ecaffd0b6dbda3da663220e705cb563918249bda26058c6036752ba3a2"}, + {file = "pypdfium2-4.30.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cc3bf29b0db8c76cdfaac1ec1cde8edf211a7de7390fbf8934ad2aa9b4d6dfad"}, + {file = "pypdfium2-4.30.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1f78d2189e0ddf9ac2b7a9b9bd4f0c66f54d1389ff6c17e9fd9dc034d06eb3f"}, + {file = "pypdfium2-4.30.0-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:5eda3641a2da7a7a0b2f4dbd71d706401a656fea521b6b6faa0675b15d31a163"}, + {file = "pypdfium2-4.30.0-py3-none-musllinux_1_1_i686.whl", hash = "sha256:0dfa61421b5eb68e1188b0b2231e7ba35735aef2d867d86e48ee6cab6975195e"}, + {file = "pypdfium2-4.30.0-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:f33bd79e7a09d5f7acca3b0b69ff6c8a488869a7fab48fdf400fec6e20b9c8be"}, + {file = "pypdfium2-4.30.0-py3-none-win32.whl", hash = "sha256:ee2410f15d576d976c2ab2558c93d392a25fb9f6635e8dd0a8a3a5241b275e0e"}, + {file = "pypdfium2-4.30.0-py3-none-win_amd64.whl", hash = "sha256:90dbb2ac07be53219f56be09961eb95cf2473f834d01a42d901d13ccfad64b4c"}, + {file = "pypdfium2-4.30.0-py3-none-win_arm64.whl", hash = "sha256:119b2969a6d6b1e8d55e99caaf05290294f2d0fe49c12a3f17102d01c441bd29"}, + {file = "pypdfium2-4.30.0.tar.gz", hash = "sha256:48b5b7e5566665bc1015b9d69c1ebabe21f6aee468b509531c3c8318eeee2e16"}, ] [[package]] @@ -8427,114 +8533,114 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] [[package]] name = "rpds-py" -version = "0.22.1" +version = "0.22.3" description = "Python bindings to Rust's persistent data structures (rpds)" optional = false python-versions = ">=3.9" files = [ - {file = "rpds_py-0.22.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:ab27dd4edd84b13309f268ffcdfc07aef8339135ffab7b6d43f16884307a2a48"}, - {file = "rpds_py-0.22.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9d5b925156a746dc1f5f52376fdd1fbdd3f6ffe1fcd6f5e06f77ca79abb940a3"}, - {file = "rpds_py-0.22.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:201650b309c419143775c15209c620627de3c09a27c7fb58375325aec5cce260"}, - {file = "rpds_py-0.22.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:31264187fc934ff1024a4f56775f33c9252d3f4f3e27ec07d1995a26b52702c3"}, - {file = "rpds_py-0.22.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:97c5ffe47ccf92d8b17e10f8a5ce28d015aa1196edc3359684cf31504eae6a14"}, - {file = "rpds_py-0.22.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9ac7280bd045f472b50306d7efeee051b69e3a2dd1b90f46bd7e86e63b1efa2"}, - {file = "rpds_py-0.22.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f941fb86195f97be7f6efe04a21b223f05dfe4d1dfb159999e2f8d101e44cc4"}, - {file = "rpds_py-0.22.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f91bfc39f7a64168e08ab831fa497ec5438c1d6c6e2f9e12848d95ad11ac8523"}, - {file = "rpds_py-0.22.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:effcae2152afe7937a28376dbabb25c770ef99ed4e16a4ffeb8e6a4f7c4f06aa"}, - {file = "rpds_py-0.22.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:2177e59c033bf0d1bf7de1ced561205963583caf3242c6c700a723034bfb5f8e"}, - {file = "rpds_py-0.22.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:66f4f48a89cdd30ab3a47335df81c76e9a63799d0d84b29c0618371c66fa37b0"}, - {file = "rpds_py-0.22.1-cp310-cp310-win32.whl", hash = "sha256:b07fa9e634234e84096adfa4be3828c8f26e238679c122824b2b3d7131bec578"}, - {file = "rpds_py-0.22.1-cp310-cp310-win_amd64.whl", hash = "sha256:ca4657e9fd0b1b5376942d403d634ce188f79064f0873aa853ab05b10185ceec"}, - {file = "rpds_py-0.22.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:608c84699b2db09c6a8743845b1a3dad36fae53eaaecb241d45b13dff74405fb"}, - {file = "rpds_py-0.22.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9dae4eb9b5534e09ba6c6ab496a757e5e394b7e7b08767d25ca37e8d36491114"}, - {file = "rpds_py-0.22.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09a1f000c5f6e08b298275bae00921e9fbbf2a35dae0a86db2821c058c2201a9"}, - {file = "rpds_py-0.22.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:580ccbf11f02f948add4cb641843030a89f1463d7c0740cbfc9aca91e9dc34b3"}, - {file = "rpds_py-0.22.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:96559e05bdf938b2048353e10a7920b98f853cefe4482c2064a718d7d0a50bd7"}, - {file = "rpds_py-0.22.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:128cbaed7ba26116820bcb992405d6a13ea18c8fca1b8c4f59906d858e91e979"}, - {file = "rpds_py-0.22.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:734783dd7da58f76222f458346ddebdb3621686a1a2a667db5049caf0c9956b9"}, - {file = "rpds_py-0.22.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c9ce6b83597d45bec44a2690857ede62fc98223772135f8a7fa90884eb726501"}, - {file = "rpds_py-0.22.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:bca4428c4a957b78ded3e6e62884ab03f029dce8fa8d34818da0f80f61332b49"}, - {file = "rpds_py-0.22.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1ded65691a1d3fd7d2aa89d2c91aa51f941601bb2ce099739909034d957fef4b"}, - {file = "rpds_py-0.22.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:72407065ad459db9f3d052ea8c51e02534f02533fc61e51cbab3bd94166f086c"}, - {file = "rpds_py-0.22.1-cp311-cp311-win32.whl", hash = "sha256:eb013aa01b404219f28dc973d9e6310fd4db216d7299253dd355629952e0564e"}, - {file = "rpds_py-0.22.1-cp311-cp311-win_amd64.whl", hash = "sha256:8bd9ec1db79a664f4cbb12878693b73416f4d2cb425d3e27eccc1bdfbdc826ef"}, - {file = "rpds_py-0.22.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:8ec41049c90d204a6561238a9ad6c7263ebb7009d9759c98b58078d9d2fec9ba"}, - {file = "rpds_py-0.22.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:102be79c4cc47a4aeb5912401185c404cd2601c15a7163bbecff7f1bfe20b669"}, - {file = "rpds_py-0.22.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a603155db408f773637f9e3a712c6e3cbc521aaa8fa2b99f9ba6106c59a2496"}, - {file = "rpds_py-0.22.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5dbff9402c2bdf00bf0df9905694b3c292a3847c725651938a72f554351a5fcb"}, - {file = "rpds_py-0.22.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:96b3759d8ab2323324e0a92b2f44834f9d88089b8d1ab6f533b61f4be3411cef"}, - {file = "rpds_py-0.22.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c3029f481b31f329b1fdb4ec4b56935d82210ddd9c6f86ea5a87c06f1e97b161"}, - {file = "rpds_py-0.22.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d280b4bf09f719b89fd9aab3b71067acc0d0449b7d1eba99a2ade4939cef8296"}, - {file = "rpds_py-0.22.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6c8e97e19aa7b0b0d801a159f932ce4435f1049c8c38e2bb372bb5bee559ce50"}, - {file = "rpds_py-0.22.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:50e4b5d291105f7063259fe0125b1af902fb34499444d7c5c521dd8328b00939"}, - {file = "rpds_py-0.22.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:d3777c446bb1c5fcd82dc3f8776e1a146cd91e80cc1892f8634575ace438d22f"}, - {file = "rpds_py-0.22.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:447ae1104fb32197b9262f772d565d38e834cc2e9edd89350b37b88fed636e70"}, - {file = "rpds_py-0.22.1-cp312-cp312-win32.whl", hash = "sha256:55d371b9d8b0c2a68a50413a8cb01c3c3ce1ea4f768bf77b66669a9a486e101e"}, - {file = "rpds_py-0.22.1-cp312-cp312-win_amd64.whl", hash = "sha256:413a30a99d8683dace3765885920ed27ab662efbb6c98d81db76c397ad1ffd71"}, - {file = "rpds_py-0.22.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:aa2ba0176037c915d8660a4e46581d645e2c22b5373e466bc8640a794d45861a"}, - {file = "rpds_py-0.22.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4ba6c66fbc6015b2f99e7176fec41793cecb00c4cc357cad038dff85e6ac42ab"}, - {file = "rpds_py-0.22.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15fa4ca658f8ad22645d3531682b17e5580832efbfa87304c3e62214c79c1e8a"}, - {file = "rpds_py-0.22.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d7833ef6f5d6cb634f296abfd93452fb3eb44c4e9a6ae95c1021eab704c1cee2"}, - {file = "rpds_py-0.22.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c0467838c90435b80793cde486a318fc916ee57f2af54e4b10c72b20cbdcbaa9"}, - {file = "rpds_py-0.22.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d962e2e89b3a95e3597a34b8c93ced1e98958502c5b8096c9fd69deff279f561"}, - {file = "rpds_py-0.22.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ce729f1dc8a4a190c34b69f75377bddc004079b2963ab722ab91fafe040be6d"}, - {file = "rpds_py-0.22.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8080467df22feca0fc9c46567001777c6fbc2b4a2683a7137420896051874ca1"}, - {file = "rpds_py-0.22.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0f9eb37d3a60b262a98ab51ee899cac039de9ca0ce68dcf1a6518a09719020b0"}, - {file = "rpds_py-0.22.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:153248f48d6f90a295a502f53ec544a3ffbd21b0bb32f5dca39c4b93a764d6a2"}, - {file = "rpds_py-0.22.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:0a53592cdf98cec3dfcdb24ffec8a4797e7656b65700099af43ec7df023b6de4"}, - {file = "rpds_py-0.22.1-cp313-cp313-win32.whl", hash = "sha256:e8056adcefa2dcb67e8bc91ea5eee26df66e8b297a8cd6ff0903f85c70908fa0"}, - {file = "rpds_py-0.22.1-cp313-cp313-win_amd64.whl", hash = "sha256:a451dba533be77454ebcffc85189108fc05f279100835ac76e7989edacb89156"}, - {file = "rpds_py-0.22.1-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:2ea23f1525d4f64286dbe0947c929d45c3ffe963b2dbed1d3844a2e4938bda42"}, - {file = "rpds_py-0.22.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3aaa22487477de9618ce3b37f99fbe81219ba96f3c2ca84f576f0ab451b83aba"}, - {file = "rpds_py-0.22.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8954b9ffe60f479a0c0ba40987db2546c735ab02a725ea7fd89342152d4d821d"}, - {file = "rpds_py-0.22.1-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c8502a02ae3ae67084f5a0bf5a8253b19fa7a887f824e41e016cdb0ac532a06f"}, - {file = "rpds_py-0.22.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a083221b6a4ecdef38a60c95d8d3223d99449cb4da2544e9644958dc16664eb9"}, - {file = "rpds_py-0.22.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:542eb246d5be31b5e0a9c8ddb9539416f9b31f58f75bd4ee328bff2b5c58d6fd"}, - {file = "rpds_py-0.22.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffae97d28ea4f2c613a751d087b75a97fb78311b38cc2e9a2f4587e473ace167"}, - {file = "rpds_py-0.22.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d0ff8d5b13ce2357fa8b33a0a2e3775aa71df5bf7c8ba060634c9d15ab12f357"}, - {file = "rpds_py-0.22.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:0f057a0c546c42964836b209d8de9ea1a4f4b0432006c6343cbe633d8ca14571"}, - {file = "rpds_py-0.22.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:48ee97c7c6027fd423058675b5a39d0b5f7a1648250b671563d5c9f74ff13ff0"}, - {file = "rpds_py-0.22.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:babec324e8654a59122aaa66936a9a483faa03276db9792f51332475c2dddc4a"}, - {file = "rpds_py-0.22.1-cp313-cp313t-win32.whl", hash = "sha256:e69acdbc132c9592c8dc393af85e38e206ca847c7019a953ff625191c3a12312"}, - {file = "rpds_py-0.22.1-cp313-cp313t-win_amd64.whl", hash = "sha256:c783e4ed68200f4e03c125690d23158b1c49c4b186d458a18debc109bbdc3c2e"}, - {file = "rpds_py-0.22.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:2143c3aed85992604d758bbe67da839fb4aab3dd2e1c6dddab5b3ca7162b34a2"}, - {file = "rpds_py-0.22.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f57e2d0f8022783426121b586d7c842ea40ea832a29e28ca36c881b54c74fb28"}, - {file = "rpds_py-0.22.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c0c324879d483504b07f7b18eb1b50567c434263bbe4866ecce33056162668a"}, - {file = "rpds_py-0.22.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1c40e02cc4f3e18fd39344edb10eebe04bd11cfd13119606b5771e5ea51630d3"}, - {file = "rpds_py-0.22.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f76c6f319e57007ad52e671ec741d801324760a377e3d4992c9bb8200333ebac"}, - {file = "rpds_py-0.22.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f5cae9b415ea8a6a563566dbf46650222eccc5971c7daa16fbee63aef92ae543"}, - {file = "rpds_py-0.22.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b09209cdfcacf5eba9cf80367130532e6c02e695252e1f64d3cfcc2356e6e19f"}, - {file = "rpds_py-0.22.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:dbe428d0ac6eacaf05402adbaf137f59ad6063848182d1ff294f95ce0f24005b"}, - {file = "rpds_py-0.22.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:626b9feb01bff049a5aec4804f0c58db12585778b4902e5376a95b01f80a7a16"}, - {file = "rpds_py-0.22.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ec1ccc2a9f764cd632fb8ab28fdde166250df54fc8d97315a4a6948dc5367639"}, - {file = "rpds_py-0.22.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:ef92b1fbe6aa2e7885eb90853cc016b1fc95439a8cc8da6d526880e9e2148695"}, - {file = "rpds_py-0.22.1-cp39-cp39-win32.whl", hash = "sha256:c88535f83f7391cf3a45af990237e3939a6fdfbedaed2571633bfdd0bceb36b0"}, - {file = "rpds_py-0.22.1-cp39-cp39-win_amd64.whl", hash = "sha256:7839b7528faa4d134c183b1f2dd1ee4dc2ca2f899f4f0cfdf00fc04c255262a7"}, - {file = "rpds_py-0.22.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a0ed14a4162c2c2b21a162c9fcf90057e3e7da18cd171ab344c1e1664f75090e"}, - {file = "rpds_py-0.22.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:05fdeae9010533e47715c37df83264df0122584e40d691d50cf3607c060952a3"}, - {file = "rpds_py-0.22.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4659b2e4a5008715099e216050f5c6976e5a4329482664411789968b82e3f17d"}, - {file = "rpds_py-0.22.1-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a18aedc032d6468b73ebbe4437129cb30d54fe543cde2f23671ecad76c3aea24"}, - {file = "rpds_py-0.22.1-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:149b4d875ef9b12a8f5e303e86a32a58f8ef627e57ec97a7d0e4be819069d141"}, - {file = "rpds_py-0.22.1-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fdaee3947eaaa52dae3ceb9d9f66329e13d8bae35682b1e5dd54612938693934"}, - {file = "rpds_py-0.22.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36ce951800ed2acc6772fd9f42150f29d567f0423989748052fdb39d9e2b5795"}, - {file = "rpds_py-0.22.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ab784621d3e2a41916e21f13a483602cc989fd45fff637634b9231ba43d4383b"}, - {file = "rpds_py-0.22.1-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:c2a214bf5b79bd39a9de1c991353aaaacafda83ba1374178309e92be8e67d411"}, - {file = "rpds_py-0.22.1-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:85060e96953647871957d41707adb8d7bff4e977042fd0deb4fc1881b98dd2fe"}, - {file = "rpds_py-0.22.1-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:c6f3fd617db422c9d4e12cb8d84c984fe07d6d9cb0950cbf117f3bccc6268d05"}, - {file = "rpds_py-0.22.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:f2d1b58a0c3a73f0361759642e80260a6d28eee6501b40fe25b82af33ef83f21"}, - {file = "rpds_py-0.22.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:76eaa4c087a061a2c8a0a92536405069878a8f530c00e84a9eaf332e70f5561f"}, - {file = "rpds_py-0.22.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:959ae04ed30cde606f3a0320f0a1f4167a107e685ef5209cce28c5080590bd31"}, - {file = "rpds_py-0.22.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:198067aa6f3d942ff5d0d655bb1e91b59ae85279d47590682cba2834ac1b97d2"}, - {file = "rpds_py-0.22.1-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3e7e99e2af59c56c59b6c964d612511b8203480d39d1ef83edc56f2cb42a3f5d"}, - {file = "rpds_py-0.22.1-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0545928bdf53dfdfcab284468212efefb8a6608ca3b6910c7fb2e5ed8bdc2dc0"}, - {file = "rpds_py-0.22.1-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ef7282d8a14b60dd515e47060638687710b1d518f4b5e961caad43fb3a3606f9"}, - {file = "rpds_py-0.22.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe3f245c2f39a5692d9123c174bc48f6f9fe3e96407e67c6d04541a767d99e72"}, - {file = "rpds_py-0.22.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:efb2ad60ca8637d5f9f653f9a9a8d73964059972b6b95036be77e028bffc68a3"}, - {file = "rpds_py-0.22.1-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:d8306f27418361b788e3fca9f47dec125457f80122e7e31ba7ff5cdba98343f8"}, - {file = "rpds_py-0.22.1-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:4c8dc7331e8cbb1c0ea2bcb550adb1777365944ffd125c69aa1117fdef4887f5"}, - {file = "rpds_py-0.22.1-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:776a06cb5720556a549829896a49acebb5bdd96c7bba100191a994053546975a"}, - {file = "rpds_py-0.22.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:e4f91d702b9ce1388660b3d4a28aa552614a1399e93f718ed0dacd68f23b3d32"}, - {file = "rpds_py-0.22.1.tar.gz", hash = "sha256:157a023bded0618a1eea54979fe2e0f9309e9ddc818ef4b8fc3b884ff38fedd5"}, + {file = "rpds_py-0.22.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6c7b99ca52c2c1752b544e310101b98a659b720b21db00e65edca34483259967"}, + {file = "rpds_py-0.22.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:be2eb3f2495ba669d2a985f9b426c1797b7d48d6963899276d22f23e33d47e37"}, + {file = "rpds_py-0.22.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70eb60b3ae9245ddea20f8a4190bd79c705a22f8028aaf8bbdebe4716c3fab24"}, + {file = "rpds_py-0.22.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4041711832360a9b75cfb11b25a6a97c8fb49c07b8bd43d0d02b45d0b499a4ff"}, + {file = "rpds_py-0.22.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:64607d4cbf1b7e3c3c8a14948b99345eda0e161b852e122c6bb71aab6d1d798c"}, + {file = "rpds_py-0.22.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e69b0a0e2537f26d73b4e43ad7bc8c8efb39621639b4434b76a3de50c6966e"}, + {file = "rpds_py-0.22.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc27863442d388870c1809a87507727b799c8460573cfbb6dc0eeaef5a11b5ec"}, + {file = "rpds_py-0.22.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e79dd39f1e8c3504be0607e5fc6e86bb60fe3584bec8b782578c3b0fde8d932c"}, + {file = "rpds_py-0.22.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e0fa2d4ec53dc51cf7d3bb22e0aa0143966119f42a0c3e4998293a3dd2856b09"}, + {file = "rpds_py-0.22.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fda7cb070f442bf80b642cd56483b5548e43d366fe3f39b98e67cce780cded00"}, + {file = "rpds_py-0.22.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cff63a0272fcd259dcc3be1657b07c929c466b067ceb1c20060e8d10af56f5bf"}, + {file = "rpds_py-0.22.3-cp310-cp310-win32.whl", hash = "sha256:9bd7228827ec7bb817089e2eb301d907c0d9827a9e558f22f762bb690b131652"}, + {file = "rpds_py-0.22.3-cp310-cp310-win_amd64.whl", hash = "sha256:9beeb01d8c190d7581a4d59522cd3d4b6887040dcfc744af99aa59fef3e041a8"}, + {file = "rpds_py-0.22.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d20cfb4e099748ea39e6f7b16c91ab057989712d31761d3300d43134e26e165f"}, + {file = "rpds_py-0.22.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:68049202f67380ff9aa52f12e92b1c30115f32e6895cd7198fa2a7961621fc5a"}, + {file = "rpds_py-0.22.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb4f868f712b2dd4bcc538b0a0c1f63a2b1d584c925e69a224d759e7070a12d5"}, + {file = "rpds_py-0.22.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bc51abd01f08117283c5ebf64844a35144a0843ff7b2983e0648e4d3d9f10dbb"}, + {file = "rpds_py-0.22.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0f3cec041684de9a4684b1572fe28c7267410e02450f4561700ca5a3bc6695a2"}, + {file = "rpds_py-0.22.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7ef9d9da710be50ff6809fed8f1963fecdfecc8b86656cadfca3bc24289414b0"}, + {file = "rpds_py-0.22.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59f4a79c19232a5774aee369a0c296712ad0e77f24e62cad53160312b1c1eaa1"}, + {file = "rpds_py-0.22.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1a60bce91f81ddaac922a40bbb571a12c1070cb20ebd6d49c48e0b101d87300d"}, + {file = "rpds_py-0.22.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:e89391e6d60251560f0a8f4bd32137b077a80d9b7dbe6d5cab1cd80d2746f648"}, + {file = "rpds_py-0.22.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e3fb866d9932a3d7d0c82da76d816996d1667c44891bd861a0f97ba27e84fc74"}, + {file = "rpds_py-0.22.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1352ae4f7c717ae8cba93421a63373e582d19d55d2ee2cbb184344c82d2ae55a"}, + {file = "rpds_py-0.22.3-cp311-cp311-win32.whl", hash = "sha256:b0b4136a252cadfa1adb705bb81524eee47d9f6aab4f2ee4fa1e9d3cd4581f64"}, + {file = "rpds_py-0.22.3-cp311-cp311-win_amd64.whl", hash = "sha256:8bd7c8cfc0b8247c8799080fbff54e0b9619e17cdfeb0478ba7295d43f635d7c"}, + {file = "rpds_py-0.22.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:27e98004595899949bd7a7b34e91fa7c44d7a97c40fcaf1d874168bb652ec67e"}, + {file = "rpds_py-0.22.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1978d0021e943aae58b9b0b196fb4895a25cc53d3956b8e35e0b7682eefb6d56"}, + {file = "rpds_py-0.22.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:655ca44a831ecb238d124e0402d98f6212ac527a0ba6c55ca26f616604e60a45"}, + {file = "rpds_py-0.22.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:feea821ee2a9273771bae61194004ee2fc33f8ec7db08117ef9147d4bbcbca8e"}, + {file = "rpds_py-0.22.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:22bebe05a9ffc70ebfa127efbc429bc26ec9e9b4ee4d15a740033efda515cf3d"}, + {file = "rpds_py-0.22.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3af6e48651c4e0d2d166dc1b033b7042ea3f871504b6805ba5f4fe31581d8d38"}, + {file = "rpds_py-0.22.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e67ba3c290821343c192f7eae1d8fd5999ca2dc99994114643e2f2d3e6138b15"}, + {file = "rpds_py-0.22.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:02fbb9c288ae08bcb34fb41d516d5eeb0455ac35b5512d03181d755d80810059"}, + {file = "rpds_py-0.22.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f56a6b404f74ab372da986d240e2e002769a7d7102cc73eb238a4f72eec5284e"}, + {file = "rpds_py-0.22.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0a0461200769ab3b9ab7e513f6013b7a97fdeee41c29b9db343f3c5a8e2b9e61"}, + {file = "rpds_py-0.22.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8633e471c6207a039eff6aa116e35f69f3156b3989ea3e2d755f7bc41754a4a7"}, + {file = "rpds_py-0.22.3-cp312-cp312-win32.whl", hash = "sha256:593eba61ba0c3baae5bc9be2f5232430453fb4432048de28399ca7376de9c627"}, + {file = "rpds_py-0.22.3-cp312-cp312-win_amd64.whl", hash = "sha256:d115bffdd417c6d806ea9069237a4ae02f513b778e3789a359bc5856e0404cc4"}, + {file = "rpds_py-0.22.3-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:ea7433ce7e4bfc3a85654aeb6747babe3f66eaf9a1d0c1e7a4435bbdf27fea84"}, + {file = "rpds_py-0.22.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6dd9412824c4ce1aca56c47b0991e65bebb7ac3f4edccfd3f156150c96a7bf25"}, + {file = "rpds_py-0.22.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20070c65396f7373f5df4005862fa162db5d25d56150bddd0b3e8214e8ef45b4"}, + {file = "rpds_py-0.22.3-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0b09865a9abc0ddff4e50b5ef65467cd94176bf1e0004184eb915cbc10fc05c5"}, + {file = "rpds_py-0.22.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3453e8d41fe5f17d1f8e9c383a7473cd46a63661628ec58e07777c2fff7196dc"}, + {file = "rpds_py-0.22.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f5d36399a1b96e1a5fdc91e0522544580dbebeb1f77f27b2b0ab25559e103b8b"}, + {file = "rpds_py-0.22.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:009de23c9c9ee54bf11303a966edf4d9087cd43a6003672e6aa7def643d06518"}, + {file = "rpds_py-0.22.3-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1aef18820ef3e4587ebe8b3bc9ba6e55892a6d7b93bac6d29d9f631a3b4befbd"}, + {file = "rpds_py-0.22.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f60bd8423be1d9d833f230fdbccf8f57af322d96bcad6599e5a771b151398eb2"}, + {file = "rpds_py-0.22.3-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:62d9cfcf4948683a18a9aff0ab7e1474d407b7bab2ca03116109f8464698ab16"}, + {file = "rpds_py-0.22.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9253fc214112405f0afa7db88739294295f0e08466987f1d70e29930262b4c8f"}, + {file = "rpds_py-0.22.3-cp313-cp313-win32.whl", hash = "sha256:fb0ba113b4983beac1a2eb16faffd76cb41e176bf58c4afe3e14b9c681f702de"}, + {file = "rpds_py-0.22.3-cp313-cp313-win_amd64.whl", hash = "sha256:c58e2339def52ef6b71b8f36d13c3688ea23fa093353f3a4fee2556e62086ec9"}, + {file = "rpds_py-0.22.3-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:f82a116a1d03628a8ace4859556fb39fd1424c933341a08ea3ed6de1edb0283b"}, + {file = "rpds_py-0.22.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3dfcbc95bd7992b16f3f7ba05af8a64ca694331bd24f9157b49dadeeb287493b"}, + {file = "rpds_py-0.22.3-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59259dc58e57b10e7e18ce02c311804c10c5a793e6568f8af4dead03264584d1"}, + {file = "rpds_py-0.22.3-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5725dd9cc02068996d4438d397e255dcb1df776b7ceea3b9cb972bdb11260a83"}, + {file = "rpds_py-0.22.3-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:99b37292234e61325e7a5bb9689e55e48c3f5f603af88b1642666277a81f1fbd"}, + {file = "rpds_py-0.22.3-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:27b1d3b3915a99208fee9ab092b8184c420f2905b7d7feb4aeb5e4a9c509b8a1"}, + {file = "rpds_py-0.22.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f612463ac081803f243ff13cccc648578e2279295048f2a8d5eb430af2bae6e3"}, + {file = "rpds_py-0.22.3-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f73d3fef726b3243a811121de45193c0ca75f6407fe66f3f4e183c983573e130"}, + {file = "rpds_py-0.22.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:3f21f0495edea7fdbaaa87e633a8689cd285f8f4af5c869f27bc8074638ad69c"}, + {file = "rpds_py-0.22.3-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:1e9663daaf7a63ceccbbb8e3808fe90415b0757e2abddbfc2e06c857bf8c5e2b"}, + {file = "rpds_py-0.22.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:a76e42402542b1fae59798fab64432b2d015ab9d0c8c47ba7addddbaf7952333"}, + {file = "rpds_py-0.22.3-cp313-cp313t-win32.whl", hash = "sha256:69803198097467ee7282750acb507fba35ca22cc3b85f16cf45fb01cb9097730"}, + {file = "rpds_py-0.22.3-cp313-cp313t-win_amd64.whl", hash = "sha256:f5cf2a0c2bdadf3791b5c205d55a37a54025c6e18a71c71f82bb536cf9a454bf"}, + {file = "rpds_py-0.22.3-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:378753b4a4de2a7b34063d6f95ae81bfa7b15f2c1a04a9518e8644e81807ebea"}, + {file = "rpds_py-0.22.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3445e07bf2e8ecfeef6ef67ac83de670358abf2996916039b16a218e3d95e97e"}, + {file = "rpds_py-0.22.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b2513ba235829860b13faa931f3b6846548021846ac808455301c23a101689d"}, + {file = "rpds_py-0.22.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eaf16ae9ae519a0e237a0f528fd9f0197b9bb70f40263ee57ae53c2b8d48aeb3"}, + {file = "rpds_py-0.22.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:583f6a1993ca3369e0f80ba99d796d8e6b1a3a2a442dd4e1a79e652116413091"}, + {file = "rpds_py-0.22.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4617e1915a539a0d9a9567795023de41a87106522ff83fbfaf1f6baf8e85437e"}, + {file = "rpds_py-0.22.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c150c7a61ed4a4f4955a96626574e9baf1adf772c2fb61ef6a5027e52803543"}, + {file = "rpds_py-0.22.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2fa4331c200c2521512595253f5bb70858b90f750d39b8cbfd67465f8d1b596d"}, + {file = "rpds_py-0.22.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:214b7a953d73b5e87f0ebece4a32a5bd83c60a3ecc9d4ec8f1dca968a2d91e99"}, + {file = "rpds_py-0.22.3-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:f47ad3d5f3258bd7058d2d506852217865afefe6153a36eb4b6928758041d831"}, + {file = "rpds_py-0.22.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:f276b245347e6e36526cbd4a266a417796fc531ddf391e43574cf6466c492520"}, + {file = "rpds_py-0.22.3-cp39-cp39-win32.whl", hash = "sha256:bbb232860e3d03d544bc03ac57855cd82ddf19c7a07651a7c0fdb95e9efea8b9"}, + {file = "rpds_py-0.22.3-cp39-cp39-win_amd64.whl", hash = "sha256:cfbc454a2880389dbb9b5b398e50d439e2e58669160f27b60e5eca11f68ae17c"}, + {file = "rpds_py-0.22.3-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:d48424e39c2611ee1b84ad0f44fb3b2b53d473e65de061e3f460fc0be5f1939d"}, + {file = "rpds_py-0.22.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:24e8abb5878e250f2eb0d7859a8e561846f98910326d06c0d51381fed59357bd"}, + {file = "rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b232061ca880db21fa14defe219840ad9b74b6158adb52ddf0e87bead9e8493"}, + {file = "rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac0a03221cdb5058ce0167ecc92a8c89e8d0decdc9e99a2ec23380793c4dcb96"}, + {file = "rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb0c341fa71df5a4595f9501df4ac5abfb5a09580081dffbd1ddd4654e6e9123"}, + {file = "rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bf9db5488121b596dbfc6718c76092fda77b703c1f7533a226a5a9f65248f8ad"}, + {file = "rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b8db6b5b2d4491ad5b6bdc2bc7c017eec108acbf4e6785f42a9eb0ba234f4c9"}, + {file = "rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b3d504047aba448d70cf6fa22e06cb09f7cbd761939fdd47604f5e007675c24e"}, + {file = "rpds_py-0.22.3-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:e61b02c3f7a1e0b75e20c3978f7135fd13cb6cf551bf4a6d29b999a88830a338"}, + {file = "rpds_py-0.22.3-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:e35ba67d65d49080e8e5a1dd40101fccdd9798adb9b050ff670b7d74fa41c566"}, + {file = "rpds_py-0.22.3-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:26fd7cac7dd51011a245f29a2cc6489c4608b5a8ce8d75661bb4a1066c52dfbe"}, + {file = "rpds_py-0.22.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:177c7c0fce2855833819c98e43c262007f42ce86651ffbb84f37883308cb0e7d"}, + {file = "rpds_py-0.22.3-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:bb47271f60660803ad11f4c61b42242b8c1312a31c98c578f79ef9387bbde21c"}, + {file = "rpds_py-0.22.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:70fb28128acbfd264eda9bf47015537ba3fe86e40d046eb2963d75024be4d055"}, + {file = "rpds_py-0.22.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:44d61b4b7d0c2c9ac019c314e52d7cbda0ae31078aabd0f22e583af3e0d79723"}, + {file = "rpds_py-0.22.3-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f0e260eaf54380380ac3808aa4ebe2d8ca28b9087cf411649f96bad6900c728"}, + {file = "rpds_py-0.22.3-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b25bc607423935079e05619d7de556c91fb6adeae9d5f80868dde3468657994b"}, + {file = "rpds_py-0.22.3-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fb6116dfb8d1925cbdb52595560584db42a7f664617a1f7d7f6e32f138cdf37d"}, + {file = "rpds_py-0.22.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a63cbdd98acef6570c62b92a1e43266f9e8b21e699c363c0fef13bd530799c11"}, + {file = "rpds_py-0.22.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2b8f60e1b739a74bab7e01fcbe3dddd4657ec685caa04681df9d562ef15b625f"}, + {file = "rpds_py-0.22.3-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:2e8b55d8517a2fda8d95cb45d62a5a8bbf9dd0ad39c5b25c8833efea07b880ca"}, + {file = "rpds_py-0.22.3-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:2de29005e11637e7a2361fa151f780ff8eb2543a0da1413bb951e9f14b699ef3"}, + {file = "rpds_py-0.22.3-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:666ecce376999bf619756a24ce15bb14c5bfaf04bf00abc7e663ce17c3f34fe7"}, + {file = "rpds_py-0.22.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:5246b14ca64a8675e0a7161f7af68fe3e910e6b90542b4bfb5439ba752191df6"}, + {file = "rpds_py-0.22.3.tar.gz", hash = "sha256:e32fee8ab45d3c2db6da19a5323bc3362237c8b653c70194414b892fd06a080d"}, ] [[package]] @@ -8553,29 +8659,29 @@ pyasn1 = ">=0.1.3" [[package]] name = "ruff" -version = "0.8.1" +version = "0.8.2" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.8.1-py3-none-linux_armv6l.whl", hash = "sha256:fae0805bd514066f20309f6742f6ee7904a773eb9e6c17c45d6b1600ca65c9b5"}, - {file = "ruff-0.8.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b8a4f7385c2285c30f34b200ca5511fcc865f17578383db154e098150ce0a087"}, - {file = "ruff-0.8.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:cd054486da0c53e41e0086e1730eb77d1f698154f910e0cd9e0d64274979a209"}, - {file = "ruff-0.8.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2029b8c22da147c50ae577e621a5bfbc5d1fed75d86af53643d7a7aee1d23871"}, - {file = "ruff-0.8.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2666520828dee7dfc7e47ee4ea0d928f40de72056d929a7c5292d95071d881d1"}, - {file = "ruff-0.8.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:333c57013ef8c97a53892aa56042831c372e0bb1785ab7026187b7abd0135ad5"}, - {file = "ruff-0.8.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:288326162804f34088ac007139488dcb43de590a5ccfec3166396530b58fb89d"}, - {file = "ruff-0.8.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b12c39b9448632284561cbf4191aa1b005882acbc81900ffa9f9f471c8ff7e26"}, - {file = "ruff-0.8.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:364e6674450cbac8e998f7b30639040c99d81dfb5bbc6dfad69bc7a8f916b3d1"}, - {file = "ruff-0.8.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b22346f845fec132aa39cd29acb94451d030c10874408dbf776af3aaeb53284c"}, - {file = "ruff-0.8.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b2f2f7a7e7648a2bfe6ead4e0a16745db956da0e3a231ad443d2a66a105c04fa"}, - {file = "ruff-0.8.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:adf314fc458374c25c5c4a4a9270c3e8a6a807b1bec018cfa2813d6546215540"}, - {file = "ruff-0.8.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a885d68342a231b5ba4d30b8c6e1b1ee3a65cf37e3d29b3c74069cdf1ee1e3c9"}, - {file = "ruff-0.8.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d2c16e3508c8cc73e96aa5127d0df8913d2290098f776416a4b157657bee44c5"}, - {file = "ruff-0.8.1-py3-none-win32.whl", hash = "sha256:93335cd7c0eaedb44882d75a7acb7df4b77cd7cd0d2255c93b28791716e81790"}, - {file = "ruff-0.8.1-py3-none-win_amd64.whl", hash = "sha256:2954cdbe8dfd8ab359d4a30cd971b589d335a44d444b6ca2cb3d1da21b75e4b6"}, - {file = "ruff-0.8.1-py3-none-win_arm64.whl", hash = "sha256:55873cc1a473e5ac129d15eccb3c008c096b94809d693fc7053f588b67822737"}, - {file = "ruff-0.8.1.tar.gz", hash = "sha256:3583db9a6450364ed5ca3f3b4225958b24f78178908d5c4bc0f46251ccca898f"}, + {file = "ruff-0.8.2-py3-none-linux_armv6l.whl", hash = "sha256:c49ab4da37e7c457105aadfd2725e24305ff9bc908487a9bf8d548c6dad8bb3d"}, + {file = "ruff-0.8.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ec016beb69ac16be416c435828be702ee694c0d722505f9c1f35e1b9c0cc1bf5"}, + {file = "ruff-0.8.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f05cdf8d050b30e2ba55c9b09330b51f9f97d36d4673213679b965d25a785f3c"}, + {file = "ruff-0.8.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:60f578c11feb1d3d257b2fb043ddb47501ab4816e7e221fbb0077f0d5d4e7b6f"}, + {file = "ruff-0.8.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cbd5cf9b0ae8f30eebc7b360171bd50f59ab29d39f06a670b3e4501a36ba5897"}, + {file = "ruff-0.8.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b402ddee3d777683de60ff76da801fa7e5e8a71038f57ee53e903afbcefdaa58"}, + {file = "ruff-0.8.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:705832cd7d85605cb7858d8a13d75993c8f3ef1397b0831289109e953d833d29"}, + {file = "ruff-0.8.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:32096b41aaf7a5cc095fa45b4167b890e4c8d3fd217603f3634c92a541de7248"}, + {file = "ruff-0.8.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e769083da9439508833cfc7c23e351e1809e67f47c50248250ce1ac52c21fb93"}, + {file = "ruff-0.8.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fe716592ae8a376c2673fdfc1f5c0c193a6d0411f90a496863c99cd9e2ae25d"}, + {file = "ruff-0.8.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:81c148825277e737493242b44c5388a300584d73d5774defa9245aaef55448b0"}, + {file = "ruff-0.8.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d261d7850c8367704874847d95febc698a950bf061c9475d4a8b7689adc4f7fa"}, + {file = "ruff-0.8.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1ca4e3a87496dc07d2427b7dd7ffa88a1e597c28dad65ae6433ecb9f2e4f022f"}, + {file = "ruff-0.8.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:729850feed82ef2440aa27946ab39c18cb4a8889c1128a6d589ffa028ddcfc22"}, + {file = "ruff-0.8.2-py3-none-win32.whl", hash = "sha256:ac42caaa0411d6a7d9594363294416e0e48fc1279e1b0e948391695db2b3d5b1"}, + {file = "ruff-0.8.2-py3-none-win_amd64.whl", hash = "sha256:2aae99ec70abf43372612a838d97bfe77d45146254568d94926e8ed5bbb409ea"}, + {file = "ruff-0.8.2-py3-none-win_arm64.whl", hash = "sha256:fb88e2a506b70cfbc2de6fae6681c4f944f7dd5f2fe87233a7233d888bad73e8"}, + {file = "ruff-0.8.2.tar.gz", hash = "sha256:b84f4f414dda8ac7f75075c1fa0b905ac0ff25361f42e6d5da681a465e0f78e5"}, ] [[package]] @@ -9074,13 +9180,13 @@ docs = ["sphinx"] [[package]] name = "six" -version = "1.16.0" +version = "1.17.0" description = "Python 2 and 3 compatibility utilities" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ - {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, - {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, + {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, + {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, ] [[package]] @@ -9216,13 +9322,13 @@ sqlcipher = ["sqlcipher3_binary"] [[package]] name = "sqlparse" -version = "0.5.2" +version = "0.5.3" description = "A non-validating SQL parser." optional = false python-versions = ">=3.8" files = [ - {file = "sqlparse-0.5.2-py3-none-any.whl", hash = "sha256:e99bc85c78160918c3e1d9230834ab8d80fc06c59d03f8db2618f65f65dda55e"}, - {file = "sqlparse-0.5.2.tar.gz", hash = "sha256:9e37b35e16d1cc652a2545f0997c1deb23ea28fa1f3eefe609eee3063c3b105f"}, + {file = "sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca"}, + {file = "sqlparse-0.5.3.tar.gz", hash = "sha256:09f67787f56a0b16ecdbde1bfc7f5d9c3371ca683cfeaa8e6ff60b4807ec9272"}, ] [package.extras] @@ -9384,13 +9490,13 @@ test = ["pytest", "tornado (>=4.5)", "typeguard"] [[package]] name = "tencentcloud-sdk-python-common" -version = "3.0.1275" +version = "3.0.1277" description = "Tencent Cloud Common SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-common-3.0.1275.tar.gz", hash = "sha256:81ad21abfe142973f25b9601af812587fd7f028f25ea5aea19d13d397e0d1469"}, - {file = "tencentcloud_sdk_python_common-3.0.1275-py2.py3-none-any.whl", hash = "sha256:3bcf5ea373cf17efe2c312717afffe3dd2fb070d21bf0b289609a0e21fd45889"}, + {file = "tencentcloud-sdk-python-common-3.0.1277.tar.gz", hash = "sha256:6cbdd664a7e764588b7ce609b95f9842d695d4adf7bc41062d2c44b96635e05d"}, + {file = "tencentcloud_sdk_python_common-3.0.1277-py2.py3-none-any.whl", hash = "sha256:14a7c7da997f8a565fae23ad3e94416fa7f63613b052070135f6bea3e3a3bc95"}, ] [package.dependencies] @@ -9398,17 +9504,17 @@ requests = ">=2.16.0" [[package]] name = "tencentcloud-sdk-python-hunyuan" -version = "3.0.1275" +version = "3.0.1277" description = "Tencent Cloud Hunyuan SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-hunyuan-3.0.1275.tar.gz", hash = "sha256:15804b6f0e686e516ffbb39fd87200559189feddd12e7f1866cdd59c616294f2"}, - {file = "tencentcloud_sdk_python_hunyuan-3.0.1275-py2.py3-none-any.whl", hash = "sha256:97aa7b3af42fdbab001ecbc87f69b7215c67983d9fbac40a0bcc06a762a01132"}, + {file = "tencentcloud-sdk-python-hunyuan-3.0.1277.tar.gz", hash = "sha256:0df70b21f9affa1d6139f006abb4cd56ced07083e4306d7d8272080566715db3"}, + {file = "tencentcloud_sdk_python_hunyuan-3.0.1277-py2.py3-none-any.whl", hash = "sha256:2fef7233327fbf7bd2da987184d9dd731968aae0b7b6f2b9f177b0730b4e181f"}, ] [package.dependencies] -tencentcloud-sdk-python-common = "3.0.1275" +tencentcloud-sdk-python-common = "3.0.1277" [[package]] name = "termcolor" @@ -9793,13 +9899,13 @@ requests = ">=2.0.0" [[package]] name = "typer" -version = "0.15.0" +version = "0.15.1" description = "Typer, build great CLIs. Easy to code. Based on Python type hints." optional = false python-versions = ">=3.7" files = [ - {file = "typer-0.15.0-py3-none-any.whl", hash = "sha256:bd16241db7e0f989ce1a0d8faa5aa1e43b9b9ac3fd1d4b8bcff91503d6717e38"}, - {file = "typer-0.15.0.tar.gz", hash = "sha256:8995452a598922ed8d8ad8c06ca63a218881ab601f5fa6fb0c511f7776497c7e"}, + {file = "typer-0.15.1-py3-none-any.whl", hash = "sha256:7994fb7b8155b64d3402518560648446072864beefd44aa2dc36972a5972e847"}, + {file = "typer-0.15.1.tar.gz", hash = "sha256:a0588c0a7fa68a1978a069818657778f86abe6ff5ea6abf472f940a08bfe4f0a"}, ] [package.dependencies] @@ -9808,6 +9914,17 @@ rich = ">=10.11.0" shellingham = ">=1.3.0" typing-extensions = ">=3.7.4.3" +[[package]] +name = "types-pytz" +version = "2024.2.0.20241003" +description = "Typing stubs for pytz" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-pytz-2024.2.0.20241003.tar.gz", hash = "sha256:575dc38f385a922a212bac00a7d6d2e16e141132a3c955078f4a4fd13ed6cb44"}, + {file = "types_pytz-2024.2.0.20241003-py3-none-any.whl", hash = "sha256:3e22df1336c0c6ad1d29163c8fda82736909eb977281cb823c57f8bae07118b7"}, +] + [[package]] name = "types-requests" version = "2.32.0.20241016" @@ -9948,13 +10065,13 @@ files = [ [[package]] name = "unstructured" -version = "0.16.9" +version = "0.16.10" description = "A library that prepares raw documents for downstream ML tasks." optional = false python-versions = "<3.13,>=3.9.0" files = [ - {file = "unstructured-0.16.9-py3-none-any.whl", hash = "sha256:246e44dc99e7913677b9bb274782a7d61f2e2682581106c346b6daf969bbaaa0"}, - {file = "unstructured-0.16.9.tar.gz", hash = "sha256:30b47d5baf2a4eaa993c75812fa947c9fea870000eb82473a216829aa1d407d5"}, + {file = "unstructured-0.16.10-py3-none-any.whl", hash = "sha256:738fc020fb4d9dfd1a3e54fee255221f7f916afafa16ff4e1a7a14495ba5b5ce"}, + {file = "unstructured-0.16.10.tar.gz", hash = "sha256:61c4a447514ab5d6f8629fde2da03833cf29e0bee26a1d3b901ac57d3b5d523a"}, ] [package.dependencies] @@ -10274,82 +10391,82 @@ ark = ["anyio (>=3.5.0,<5)", "cached-property", "httpx (>=0.23.0,<1)", "pydantic [[package]] name = "watchfiles" -version = "1.0.0" +version = "1.0.3" description = "Simple, modern and high performance file watching and code reload in python." optional = false python-versions = ">=3.9" files = [ - {file = "watchfiles-1.0.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:1d19df28f99d6a81730658fbeb3ade8565ff687f95acb59665f11502b441be5f"}, - {file = "watchfiles-1.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:28babb38cf2da8e170b706c4b84aa7e4528a6fa4f3ee55d7a0866456a1662041"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12ab123135b2f42517f04e720526d41448667ae8249e651385afb5cda31fedc0"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:13a4f9ee0cd25682679eea5c14fc629e2eaa79aab74d963bc4e21f43b8ea1877"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e1d9284cc84de7855fcf83472e51d32daf6f6cecd094160192628bc3fee1b78"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ee5edc939f53466b329bbf2e58333a5461e6c7b50c980fa6117439e2c18b42d"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5dccfc70480087567720e4e36ec381bba1ed68d7e5f368fe40c93b3b1eba0105"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c83a6d33a9eda0af6a7470240d1af487807adc269704fe76a4972dd982d16236"}, - {file = "watchfiles-1.0.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:905f69aad276639eff3893759a07d44ea99560e67a1cf46ff389cd62f88872a2"}, - {file = "watchfiles-1.0.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:09551237645d6bff3972592f2aa5424df9290e7a2e15d63c5f47c48cde585935"}, - {file = "watchfiles-1.0.0-cp310-none-win32.whl", hash = "sha256:d2b39aa8edd9e5f56f99a2a2740a251dc58515398e9ed5a4b3e5ff2827060755"}, - {file = "watchfiles-1.0.0-cp310-none-win_amd64.whl", hash = "sha256:2de52b499e1ab037f1a87cb8ebcb04a819bf087b1015a4cf6dcf8af3c2a2613e"}, - {file = "watchfiles-1.0.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:fbd0ab7a9943bbddb87cbc2bf2f09317e74c77dc55b1f5657f81d04666c25269"}, - {file = "watchfiles-1.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:774ef36b16b7198669ce655d4f75b4c3d370e7f1cbdfb997fb10ee98717e2058"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b4fb98100267e6a5ebaff6aaa5d20aea20240584647470be39fe4823012ac96"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0fc3bf0effa2d8075b70badfdd7fb839d7aa9cea650d17886982840d71fdeabf"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:648e2b6db53eca6ef31245805cd528a16f56fa4cc15aeec97795eaf713c11435"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa13d604fcb9417ae5f2e3de676e66aa97427d888e83662ad205bed35a313176"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:936f362e7ff28311b16f0b97ec51e8f2cc451763a3264640c6ed40fb252d1ee4"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:245fab124b9faf58430da547512d91734858df13f2ddd48ecfa5e493455ffccb"}, - {file = "watchfiles-1.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4ff9c7e84e8b644a8f985c42bcc81457240316f900fc72769aaedec9d088055a"}, - {file = "watchfiles-1.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9c9a8d8fd97defe935ef8dd53d562e68942ad65067cd1c54d6ed8a088b1d931d"}, - {file = "watchfiles-1.0.0-cp311-none-win32.whl", hash = "sha256:a0abf173975eb9dd17bb14c191ee79999e650997cc644562f91df06060610e62"}, - {file = "watchfiles-1.0.0-cp311-none-win_amd64.whl", hash = "sha256:2a825ba4b32c214e3855b536eb1a1f7b006511d8e64b8215aac06eb680642d84"}, - {file = "watchfiles-1.0.0-cp311-none-win_arm64.whl", hash = "sha256:a5a7a06cfc65e34fd0a765a7623c5ba14707a0870703888e51d3d67107589817"}, - {file = "watchfiles-1.0.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:28fb64b5843d94e2c2483f7b024a1280662a44409bedee8f2f51439767e2d107"}, - {file = "watchfiles-1.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e3750434c83b61abb3163b49c64b04180b85b4dabb29a294513faec57f2ffdb7"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bedf84835069f51c7b026b3ca04e2e747ea8ed0a77c72006172c72d28c9f69fc"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:90004553be36427c3d06ec75b804233f8f816374165d5225b93abd94ba6e7234"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b46e15c34d4e401e976d6949ad3a74d244600d5c4b88c827a3fdf18691a46359"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:487d15927f1b0bd24e7df921913399bb1ab94424c386bea8b267754d698f8f0e"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1ff236d7a3f4b0a42f699a22fc374ba526bc55048a70cbb299661158e1bb5e1f"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c01446626574561756067f00b37e6b09c8622b0fc1e9fdbc7cbcea328d4e514"}, - {file = "watchfiles-1.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b551c465a59596f3d08170bd7e1c532c7260dd90ed8135778038e13c5d48aa81"}, - {file = "watchfiles-1.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e1ed613ee107269f66c2df631ec0fc8efddacface85314d392a4131abe299f00"}, - {file = "watchfiles-1.0.0-cp312-none-win32.whl", hash = "sha256:5f75cd42e7e2254117cf37ff0e68c5b3f36c14543756b2da621408349bd9ca7c"}, - {file = "watchfiles-1.0.0-cp312-none-win_amd64.whl", hash = "sha256:cf517701a4a872417f4e02a136e929537743461f9ec6cdb8184d9a04f4843545"}, - {file = "watchfiles-1.0.0-cp312-none-win_arm64.whl", hash = "sha256:8a2127cd68950787ee36753e6d401c8ea368f73beaeb8e54df5516a06d1ecd82"}, - {file = "watchfiles-1.0.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:95de85c254f7fe8cbdf104731f7f87f7f73ae229493bebca3722583160e6b152"}, - {file = "watchfiles-1.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:533a7cbfe700e09780bb31c06189e39c65f06c7f447326fee707fd02f9a6e945"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2218e78e2c6c07b1634a550095ac2a429026b2d5cbcd49a594f893f2bb8c936"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9122b8fdadc5b341315d255ab51d04893f417df4e6c1743b0aac8bf34e96e025"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9272fdbc0e9870dac3b505bce1466d386b4d8d6d2bacf405e603108d50446940"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4a3b33c3aefe9067ebd87846806cd5fc0b017ab70d628aaff077ab9abf4d06b3"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bc338ce9f8846543d428260fa0f9a716626963148edc937d71055d01d81e1525"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ac778a460ea22d63c7e6fb0bc0f5b16780ff0b128f7f06e57aaec63bd339285"}, - {file = "watchfiles-1.0.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:53ae447f06f8f29f5ab40140f19abdab822387a7c426a369eb42184b021e97eb"}, - {file = "watchfiles-1.0.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1f73c2147a453315d672c1ad907abe6d40324e34a185b51e15624bc793f93cc6"}, - {file = "watchfiles-1.0.0-cp313-none-win32.whl", hash = "sha256:eba98901a2eab909dbd79681190b9049acc650f6111fde1845484a4450761e98"}, - {file = "watchfiles-1.0.0-cp313-none-win_amd64.whl", hash = "sha256:d562a6114ddafb09c33246c6ace7effa71ca4b6a2324a47f4b09b6445ea78941"}, - {file = "watchfiles-1.0.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3d94fd83ed54266d789f287472269c0def9120a2022674990bd24ad989ebd7a0"}, - {file = "watchfiles-1.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:48051d1c504448b2fcda71c5e6e3610ae45de6a0b8f5a43b961f250be4bdf5a8"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29cf884ad4285d23453c702ed03d689f9c0e865e3c85d20846d800d4787de00f"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d3572d4c34c4e9c33d25b3da47d9570d5122f8433b9ac6519dca49c2740d23cd"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c2696611182c85eb0e755b62b456f48debff484b7306b56f05478b843ca8ece"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:550109001920a993a4383b57229c717fa73627d2a4e8fcb7ed33c7f1cddb0c85"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b555a93c15bd2c71081922be746291d776d47521a00703163e5fbe6d2a402399"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:947ccba18a38b85c366dafeac8df2f6176342d5992ca240a9d62588b214d731f"}, - {file = "watchfiles-1.0.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ffd98a299b0a74d1b704ef0ed959efb753e656a4e0425c14e46ae4c3cbdd2919"}, - {file = "watchfiles-1.0.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f8c4f3a1210ed099a99e6a710df4ff2f8069411059ffe30fa5f9467ebed1256b"}, - {file = "watchfiles-1.0.0-cp39-none-win32.whl", hash = "sha256:1e176b6b4119b3f369b2b4e003d53a226295ee862c0962e3afd5a1c15680b4e3"}, - {file = "watchfiles-1.0.0-cp39-none-win_amd64.whl", hash = "sha256:2d9c0518fabf4a3f373b0a94bb9e4ea7a1df18dec45e26a4d182aa8918dee855"}, - {file = "watchfiles-1.0.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f159ac795785cde4899e0afa539f4c723fb5dd336ce5605bc909d34edd00b79b"}, - {file = "watchfiles-1.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:c3d258d78341d5d54c0c804a5b7faa66cd30ba50b2756a7161db07ce15363b8d"}, - {file = "watchfiles-1.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bbd0311588c2de7f9ea5cf3922ccacfd0ec0c1922870a2be503cc7df1ca8be7"}, - {file = "watchfiles-1.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9a13ac46b545a7d0d50f7641eefe47d1597e7d1783a5d89e09d080e6dff44b0"}, - {file = "watchfiles-1.0.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b2bca898c1dc073912d3db7fa6926cc08be9575add9e84872de2c99c688bac4e"}, - {file = "watchfiles-1.0.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:06d828fe2adc4ac8a64b875ca908b892a3603d596d43e18f7948f3fef5fc671c"}, - {file = "watchfiles-1.0.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:074c7618cd6c807dc4eaa0982b4a9d3f8051cd0b72793511848fd64630174b17"}, - {file = "watchfiles-1.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95dc785bc284552d044e561b8f4fe26d01ab5ca40d35852a6572d542adfeb4bc"}, - {file = "watchfiles-1.0.0.tar.gz", hash = "sha256:37566c844c9ce3b5deb964fe1a23378e575e74b114618d211fbda8f59d7b5dab"}, + {file = "watchfiles-1.0.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:1da46bb1eefb5a37a8fb6fd52ad5d14822d67c498d99bda8754222396164ae42"}, + {file = "watchfiles-1.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2b961b86cd3973f5822826017cad7f5a75795168cb645c3a6b30c349094e02e3"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34e87c7b3464d02af87f1059fedda5484e43b153ef519e4085fe1a03dd94801e"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d9dd2b89a16cf7ab9c1170b5863e68de6bf83db51544875b25a5f05a7269e678"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b4691234d31686dca133c920f94e478b548a8e7c750f28dbbc2e4333e0d3da9"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:90b0fe1fcea9bd6e3084b44875e179b4adcc4057a3b81402658d0eb58c98edf8"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0b90651b4cf9e158d01faa0833b073e2e37719264bcee3eac49fc3c74e7d304b"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2e9fe695ff151b42ab06501820f40d01310fbd58ba24da8923ace79cf6d702d"}, + {file = "watchfiles-1.0.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62691f1c0894b001c7cde1195c03b7801aaa794a837bd6eef24da87d1542838d"}, + {file = "watchfiles-1.0.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:275c1b0e942d335fccb6014d79267d1b9fa45b5ac0639c297f1e856f2f532552"}, + {file = "watchfiles-1.0.3-cp310-cp310-win32.whl", hash = "sha256:06ce08549e49ba69ccc36fc5659a3d0ff4e3a07d542b895b8a9013fcab46c2dc"}, + {file = "watchfiles-1.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:f280b02827adc9d87f764972fbeb701cf5611f80b619c20568e1982a277d6146"}, + {file = "watchfiles-1.0.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ffe709b1d0bc2e9921257569675674cafb3a5f8af689ab9f3f2b3f88775b960f"}, + {file = "watchfiles-1.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:418c5ce332f74939ff60691e5293e27c206c8164ce2b8ce0d9abf013003fb7fe"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f492d2907263d6d0d52f897a68647195bc093dafed14508a8d6817973586b6b"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:48c9f3bc90c556a854f4cab6a79c16974099ccfa3e3e150673d82d47a4bc92c9"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75d3bcfa90454dba8df12adc86b13b6d85fda97d90e708efc036c2760cc6ba44"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5691340f259b8f76b45fb31b98e594d46c36d1dc8285efa7975f7f50230c9093"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1e263cc718545b7f897baeac1f00299ab6fabe3e18caaacacb0edf6d5f35513c"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c6cf7709ed3e55704cc06f6e835bf43c03bc8e3cb8ff946bf69a2e0a78d9d77"}, + {file = "watchfiles-1.0.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:703aa5e50e465be901e0e0f9d5739add15e696d8c26c53bc6fc00eb65d7b9469"}, + {file = "watchfiles-1.0.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:bfcae6aecd9e0cb425f5145afee871465b98b75862e038d42fe91fd753ddd780"}, + {file = "watchfiles-1.0.3-cp311-cp311-win32.whl", hash = "sha256:6a76494d2c5311584f22416c5a87c1e2cb954ff9b5f0988027bc4ef2a8a67181"}, + {file = "watchfiles-1.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:cf745cbfad6389c0e331786e5fe9ae3f06e9d9c2ce2432378e1267954793975c"}, + {file = "watchfiles-1.0.3-cp311-cp311-win_arm64.whl", hash = "sha256:2dcc3f60c445f8ce14156854a072ceb36b83807ed803d37fdea2a50e898635d6"}, + {file = "watchfiles-1.0.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:93436ed550e429da007fbafb723e0769f25bae178fbb287a94cb4ccdf42d3af3"}, + {file = "watchfiles-1.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c18f3502ad0737813c7dad70e3e1cc966cc147fbaeef47a09463bbffe70b0a00"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a5bc3ca468bb58a2ef50441f953e1f77b9a61bd1b8c347c8223403dc9b4ac9a"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0d1ec043f02ca04bf21b1b32cab155ce90c651aaf5540db8eb8ad7f7e645cba8"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f58d3bfafecf3d81c15d99fc0ecf4319e80ac712c77cf0ce2661c8cf8bf84066"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1df924ba82ae9e77340101c28d56cbaff2c991bd6fe8444a545d24075abb0a87"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:632a52dcaee44792d0965c17bdfe5dc0edad5b86d6a29e53d6ad4bf92dc0ff49"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bf4b459d94a0387617a1b499f314aa04d8a64b7a0747d15d425b8c8b151da0"}, + {file = "watchfiles-1.0.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ca94c85911601b097d53caeeec30201736ad69a93f30d15672b967558df02885"}, + {file = "watchfiles-1.0.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:65ab1fb635476f6170b07e8e21db0424de94877e4b76b7feabfe11f9a5fc12b5"}, + {file = "watchfiles-1.0.3-cp312-cp312-win32.whl", hash = "sha256:49bc1bc26abf4f32e132652f4b3bfeec77d8f8f62f57652703ef127e85a3e38d"}, + {file = "watchfiles-1.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:48681c86f2cb08348631fed788a116c89c787fdf1e6381c5febafd782f6c3b44"}, + {file = "watchfiles-1.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:9e080cf917b35b20c889225a13f290f2716748362f6071b859b60b8847a6aa43"}, + {file = "watchfiles-1.0.3-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:e153a690b7255c5ced17895394b4f109d5dcc2a4f35cb809374da50f0e5c456a"}, + {file = "watchfiles-1.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ac1be85fe43b4bf9a251978ce5c3bb30e1ada9784290441f5423a28633a958a7"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2ec98e31e1844eac860e70d9247db9d75440fc8f5f679c37d01914568d18721"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0179252846be03fa97d4d5f8233d1c620ef004855f0717712ae1c558f1974a16"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:995c374e86fa82126c03c5b4630c4e312327ecfe27761accb25b5e1d7ab50ec8"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29b9cb35b7f290db1c31fb2fdf8fc6d3730cfa4bca4b49761083307f441cac5a"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f8dc09ae69af50bead60783180f656ad96bd33ffbf6e7a6fce900f6d53b08f1"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:489b80812f52a8d8c7b0d10f0d956db0efed25df2821c7a934f6143f76938bd6"}, + {file = "watchfiles-1.0.3-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:228e2247de583475d4cebf6b9af5dc9918abb99d1ef5ee737155bb39fb33f3c0"}, + {file = "watchfiles-1.0.3-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1550be1a5cb3be08a3fb84636eaafa9b7119b70c71b0bed48726fd1d5aa9b868"}, + {file = "watchfiles-1.0.3-cp313-cp313-win32.whl", hash = "sha256:16db2d7e12f94818cbf16d4c8938e4d8aaecee23826344addfaaa671a1527b07"}, + {file = "watchfiles-1.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:160eff7d1267d7b025e983ca8460e8cc67b328284967cbe29c05f3c3163711a3"}, + {file = "watchfiles-1.0.3-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:c05b021f7b5aa333124f2a64d56e4cb9963b6efdf44e8d819152237bbd93ba15"}, + {file = "watchfiles-1.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:310505ad305e30cb6c5f55945858cdbe0eb297fc57378f29bacceb534ac34199"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ddff3f8b9fa24a60527c137c852d0d9a7da2a02cf2151650029fdc97c852c974"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:46e86ed457c3486080a72bc837300dd200e18d08183f12b6ca63475ab64ed651"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f79fe7993e230a12172ce7d7c7db061f046f672f2b946431c81aff8f60b2758b"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea2b51c5f38bad812da2ec0cd7eec09d25f521a8b6b6843cbccedd9a1d8a5c15"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fe4e740ea94978b2b2ab308cbf9270a246bcbb44401f77cc8740348cbaeac3d"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9af037d3df7188ae21dc1c7624501f2f90d81be6550904e07869d8d0e6766655"}, + {file = "watchfiles-1.0.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:52bb50a4c4ca2a689fdba84ba8ecc6a4e6210f03b6af93181bb61c4ec3abaf86"}, + {file = "watchfiles-1.0.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c14a07bdb475eb696f85c715dbd0f037918ccbb5248290448488a0b4ef201aad"}, + {file = "watchfiles-1.0.3-cp39-cp39-win32.whl", hash = "sha256:be37f9b1f8934cd9e7eccfcb5612af9fb728fecbe16248b082b709a9d1b348bf"}, + {file = "watchfiles-1.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:ef9ec8068cf23458dbf36a08e0c16f0a2df04b42a8827619646637be1769300a"}, + {file = "watchfiles-1.0.3-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:84fac88278f42d61c519a6c75fb5296fd56710b05bbdcc74bdf85db409a03780"}, + {file = "watchfiles-1.0.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:c68be72b1666d93b266714f2d4092d78dc53bd11cf91ed5a3c16527587a52e29"}, + {file = "watchfiles-1.0.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:889a37e2acf43c377b5124166bece139b4c731b61492ab22e64d371cce0e6e80"}, + {file = "watchfiles-1.0.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ca05cacf2e5c4a97d02a2878a24020daca21dbb8823b023b978210a75c79098"}, + {file = "watchfiles-1.0.3-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:8af4b582d5fc1b8465d1d2483e5e7b880cc1a4e99f6ff65c23d64d070867ac58"}, + {file = "watchfiles-1.0.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:127de3883bdb29dbd3b21f63126bb8fa6e773b74eaef46521025a9ce390e1073"}, + {file = "watchfiles-1.0.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:713f67132346bdcb4c12df185c30cf04bdf4bf6ea3acbc3ace0912cab6b7cb8c"}, + {file = "watchfiles-1.0.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abd85de513eb83f5ec153a802348e7a5baa4588b818043848247e3e8986094e8"}, + {file = "watchfiles-1.0.3.tar.gz", hash = "sha256:f3ff7da165c99a5412fe5dd2304dd2dbaaaa5da718aad942dcb3a178eaa70c56"}, ] [package.dependencies] @@ -10527,13 +10644,13 @@ requests = ">=2.0.0,<3.0.0" [[package]] name = "win32-setctime" -version = "1.1.0" +version = "1.2.0" description = "A small Python utility to set file creation time on Windows" optional = false python-versions = ">=3.5" files = [ - {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"}, - {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"}, + {file = "win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390"}, + {file = "win32_setctime-1.2.0.tar.gz", hash = "sha256:ae1fdf948f5640aae05c511ade119313fb6a30d7eabe25fef9764dca5873c4c0"}, ] [package.extras] @@ -10834,13 +10951,13 @@ requests = "*" [[package]] name = "zhipuai" -version = "2.1.5.20241203" +version = "2.1.5.20241204" description = "A SDK library for accessing big model apis from ZhipuAI" optional = false python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" files = [ - {file = "zhipuai-2.1.5.20241203-py3-none-any.whl", hash = "sha256:77267aebbb7dabbff1d0706c4fc1d529feb17959613d1b130ba58a733548c21c"}, - {file = "zhipuai-2.1.5.20241203.tar.gz", hash = "sha256:4096a467cb3f43c4eb63e6e19564d2347624ceaf89a529b9e849fff0935f3da2"}, + {file = "zhipuai-2.1.5.20241204-py3-none-any.whl", hash = "sha256:063c7527d6741ced82eedb19d53fd24ce61cf43ab835ee3c0262843f59503a7c"}, + {file = "zhipuai-2.1.5.20241204.tar.gz", hash = "sha256:888b42a83c8f1daf07375b84e560219eedab96b9f9e59542f0329928291db635"}, ] [package.dependencies] @@ -11056,4 +11173,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "03d62501ae48efc47f3f35dbea7e66ccd1fbcebe69e263f6396d00e3803f2114" +content-hash = "f4accd01805cbf080c4c5295f97a06c8e4faec7365d2c43d0435e56b46461732" diff --git a/api/pyproject.toml b/api/pyproject.toml index e3820ecf9afcb1..28e0305406a18b 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -45,7 +45,7 @@ google-auth-httplib2 = "0.2.0" google-cloud-aiplatform = "1.49.0" google-generativeai = "0.8.1" googleapis-common-protos = "1.63.0" -gunicorn = "~22.0.0" +gunicorn = "~23.0.0" httpx = { version = "~0.27.0", extras = ["socks"] } huggingface-hub = "~0.16.4" jieba = "0.42.1" @@ -60,13 +60,14 @@ oci = "~2.135.1" openai = "~1.52.0" openpyxl = "~3.1.5" pandas = { version = "~2.2.2", extras = ["performance", "excel"] } +pandas-stubs = "~2.2.3.241009" psycopg2-binary = "~2.9.6" pycryptodome = "3.19.1" pydantic = "~2.9.2" pydantic-settings = "~2.6.0" pydantic_extra_types = "~2.9.0" pyjwt = "~2.8.0" -pypdfium2 = "~4.17.0" +pypdfium2 = "~4.30.0" python = ">=3.11,<3.13" python-docx = "~1.1.0" python-dotenv = "1.0.0" @@ -84,6 +85,7 @@ tencentcloud-sdk-python-hunyuan = "~3.0.1158" tiktoken = "~0.8.0" tokenizers = "~0.15.0" transformers = "~4.35.0" +types-pytz = "~2024.2.0.20241003" unstructured = { version = "~0.16.1", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] } validators = "0.21.0" volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"} @@ -134,6 +136,7 @@ bce-python-sdk = "~0.9.23" cos-python-sdk-v5 = "1.9.30" esdk-obs-python = "3.24.6.1" google-cloud-storage = "2.16.0" +opendal = "~0.45.12" oss2 = "2.18.5" supabase = "~2.8.1" tos = "~2.7.1" @@ -172,6 +175,7 @@ optional = true [tool.poetry.group.dev.dependencies] coverage = "~7.2.4" faker = "~32.1.0" +mypy = "~1.13.0" pytest = "~8.3.2" pytest-benchmark = "~4.0.0" pytest-env = "~1.1.3" diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index 72ee2a8901af79..48bdc872f41e5c 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -32,18 +32,21 @@ def clean_messages(): while True: try: # Main query with join and filter + # FIXME:for mypy no paginate method error messages = ( - db.session.query(Message) + db.session.query(Message) # type: ignore .filter(Message.created_at < plan_sandbox_clean_message_day) .order_by(Message.created_at.desc()) - .paginate(page=page, per_page=100) + .limit(100) + .all() ) except NotFound: break - if messages.items is None or len(messages.items) == 0: + if not messages: break - for message in messages.items: + for message in messages: + plan_sandbox_clean_message_day = message.created_at app = App.query.filter_by(id=message.app_id).first() features_cache_key = f"features:{app.tenant_id}" plan_cache = redis_client.get(features_cache_key) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index e12be649e4d02d..f66b3c47979435 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -52,8 +52,7 @@ def clean_unused_datasets_task(): # Main query with join and filter datasets = ( - db.session.query(Dataset) - .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .filter( Dataset.created_at < plan_sandbox_clean_day, @@ -120,8 +119,7 @@ def clean_unused_datasets_task(): # Main query with join and filter datasets = ( - db.session.query(Dataset) - .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .filter( Dataset.created_at < plan_pro_clean_day, diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index a20b500308a4d6..1c985461c6aa2e 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -36,14 +36,15 @@ def create_tidb_serverless_task(): def create_clusters(batch_size): try: + # TODO: maybe we can set the default value for the following parameters in the config file new_clusters = TidbService.batch_create_tidb_serverless_cluster( - batch_size, - dify_config.TIDB_PROJECT_ID, - dify_config.TIDB_API_URL, - dify_config.TIDB_IAM_API_URL, - dify_config.TIDB_PUBLIC_KEY, - dify_config.TIDB_PRIVATE_KEY, - dify_config.TIDB_REGION, + batch_size=batch_size, + project_id=dify_config.TIDB_PROJECT_ID or "", + api_url=dify_config.TIDB_API_URL or "", + iam_url=dify_config.TIDB_IAM_API_URL or "", + public_key=dify_config.TIDB_PUBLIC_KEY or "", + private_key=dify_config.TIDB_PRIVATE_KEY or "", + region=dify_config.TIDB_REGION or "", ) for new_cluster in new_clusters: tidb_auth_binding = TidbAuthBinding( diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index b2d8746f9ca8f4..11a39e60ee4ce5 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -36,13 +36,14 @@ def update_clusters(tidb_serverless_list: list[TidbAuthBinding]): # batch 20 for i in range(0, len(tidb_serverless_list), 20): items = tidb_serverless_list[i : i + 20] + # TODO: maybe we can set the default value for the following parameters in the config file TidbService.batch_update_tidb_serverless_cluster_status( - items, - dify_config.TIDB_PROJECT_ID, - dify_config.TIDB_API_URL, - dify_config.TIDB_IAM_API_URL, - dify_config.TIDB_PUBLIC_KEY, - dify_config.TIDB_PRIVATE_KEY, + tidb_serverless_list=items, + project_id=dify_config.TIDB_PROJECT_ID or "", + api_url=dify_config.TIDB_API_URL or "", + iam_url=dify_config.TIDB_IAM_API_URL or "", + public_key=dify_config.TIDB_PUBLIC_KEY or "", + private_key=dify_config.TIDB_PRIVATE_KEY or "", ) except Exception as e: click.echo(click.style(f"Error: {e}", fg="red")) diff --git a/api/services/account_service.py b/api/services/account_service.py index f0c6ac7ebd622b..2d37db391c899c 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -6,7 +6,7 @@ import uuid from datetime import UTC, datetime, timedelta from hashlib import sha256 -from typing import Any, Optional +from typing import Any, Optional, cast from pydantic import BaseModel from sqlalchemy import func @@ -119,7 +119,7 @@ def load_user(user_id: str) -> None | Account: account.last_active_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() - return account + return cast(Account, account) @staticmethod def get_account_jwt_token(account: Account) -> str: @@ -132,7 +132,7 @@ def get_account_jwt_token(account: Account) -> str: "sub": "Console API Passport", } - token = PassportService().issue(payload) + token: str = PassportService().issue(payload) return token @staticmethod @@ -164,7 +164,7 @@ def authenticate(email: str, password: str, invite_token: Optional[str] = None) db.session.commit() - return account + return cast(Account, account) @staticmethod def update_account_password(account, password, new_password): @@ -347,6 +347,8 @@ def send_reset_password_email( language: Optional[str] = "en-US", ): account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") if cls.reset_password_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import PasswordResetRateLimitExceededError @@ -377,6 +379,8 @@ def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: def send_email_code_login_email( cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" ): + if email is None: + raise ValueError("Email must be provided.") if cls.email_code_login_rate_limiter.is_rate_limited(email): from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError @@ -420,7 +424,7 @@ def add_login_error_rate_limit(email: str) -> None: if count is None: count = 0 count = int(count) + 1 - redis_client.setex(key, 60 * 60 * 24, count) + redis_client.setex(key, dify_config.LOGIN_LOCKOUT_DURATION, count) @staticmethod def is_login_error_rate_limit(email: str) -> bool: @@ -669,7 +673,7 @@ def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoi @staticmethod def get_tenant_count() -> int: """Get tenant count""" - return db.session.query(func.count(Tenant.id)).scalar() + return cast(int, db.session.query(func.count(Tenant.id)).scalar()) @staticmethod def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None: @@ -733,10 +737,10 @@ def dissolve_tenant(tenant: Tenant, operator: Account) -> None: db.session.commit() @staticmethod - def get_custom_config(tenant_id: str) -> None: - tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).one_or_404() + def get_custom_config(tenant_id: str) -> dict: + tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404() - return tenant.custom_config_dict + return cast(dict, tenant.custom_config_dict) class RegisterService: @@ -793,6 +797,7 @@ def register( language: Optional[str] = None, status: Optional[AccountStatus] = None, is_setup: Optional[bool] = False, + create_workspace_required: Optional[bool] = True, ) -> Account: db.session.begin_nested() """Register account""" @@ -807,10 +812,10 @@ def register( account.status = AccountStatus.ACTIVE.value if not status else status.value account.initialized_at = datetime.now(UTC).replace(tzinfo=None) - if open_id is not None or provider is not None: + if open_id is not None and provider is not None: AccountService.link_account_integrate(provider, open_id, account) - if FeatureService.get_system_features().is_allow_create_workspace: + if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required: tenant = TenantService.create_tenant(f"{account.name}'s Workspace") TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant @@ -828,10 +833,11 @@ def register( @classmethod def invite_new_member( - cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account = None + cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Optional[Account] = None ) -> str: """Invite new member""" account = Account.query.filter_by(email=email).first() + assert inviter is not None, "Inviter must be provided." if not account: TenantService.check_member_permission(tenant, inviter, None, "add") @@ -894,7 +900,9 @@ def revoke_token(cls, workspace_id: str, email: str, token: str): redis_client.delete(cls._get_invitation_token_key(token)) @classmethod - def get_invitation_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[dict[str, Any]]: + def get_invitation_if_token_valid( + cls, workspace_id: Optional[str], email: str, token: str + ) -> Optional[dict[str, Any]]: invitation_data = cls._get_invitation_by_token(token, workspace_id, email) if not invitation_data: return None @@ -953,7 +961,7 @@ def _get_invitation_by_token( if not data: return None - invitation = json.loads(data) + invitation: dict = json.loads(data) return invitation diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index d2cd7bea67c5b6..6dc1affa11d036 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -48,6 +48,8 @@ def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> return cls.get_chat_prompt( copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt ) + # default return empty dict + return {} @classmethod def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: @@ -91,3 +93,5 @@ def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) - return cls.get_chat_prompt( copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt ) + # default return empty dict + return {} diff --git a/api/services/agent_service.py b/api/services/agent_service.py index c8819535f11a39..b02f762ad267b8 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -1,5 +1,7 @@ +from typing import Optional + import pytz -from flask_login import current_user +from flask_login import current_user # type: ignore from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager from core.tools.tool_manager import ToolManager @@ -14,7 +16,7 @@ def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) - """ Service to get agent logs """ - conversation: Conversation = ( + conversation: Optional[Conversation] = ( db.session.query(Conversation) .filter( Conversation.id == conversation_id, @@ -26,7 +28,7 @@ def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) - if not conversation: raise ValueError(f"Conversation not found: {conversation_id}") - message: Message = ( + message: Optional[Message] = ( db.session.query(Message) .filter( Message.id == message_id, @@ -72,7 +74,10 @@ def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) - } agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict()) - agent_tools = agent_config.tools + if not agent_config: + return result + + agent_tools = agent_config.tools or [] def find_agent_tool(tool_name: str): for agent_tool in agent_tools: diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index f45c21cb18f5e3..a946405c955cec 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,8 +1,9 @@ import datetime import uuid +from typing import cast import pandas as pd -from flask_login import current_user +from flask_login import current_user # type: ignore from sqlalchemy import or_ from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound @@ -71,7 +72,7 @@ def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> Messa app_id, annotation_setting.collection_binding_id, ) - return annotation + return cast(MessageAnnotation, annotation) @classmethod def enable_app_annotation(cls, args: dict, app_id: str) -> dict: @@ -124,8 +125,7 @@ def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keywo raise NotFound("App not found") if keyword: annotations = ( - db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) + MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) .filter( or_( MessageAnnotation.question.ilike("%{}%".format(keyword)), @@ -137,8 +137,7 @@ def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keywo ) else: annotations = ( - db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) + MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) ) @@ -327,8 +326,7 @@ def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, lim raise NotFound("Annotation not found") annotation_hit_histories = ( - db.session.query(AppAnnotationHitHistory) - .filter( + AppAnnotationHitHistory.query.filter( AppAnnotationHitHistory.app_id == app_id, AppAnnotationHitHistory.annotation_id == annotation_id, ) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 8180c3b400719a..b191fa2397fa9e 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -1,7 +1,7 @@ import logging import uuid from enum import StrEnum -from typing import Optional +from typing import Optional, cast from uuid import uuid4 import yaml @@ -22,7 +22,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" IMPORT_INFO_REDIS_EXPIRY = 180 # 3 minutes -CURRENT_DSL_VERSION = "0.1.4" +CURRENT_DSL_VERSION = "0.1.5" class ImportMode(StrEnum): @@ -103,7 +103,7 @@ def import_app( raise ValueError(f"Invalid import_mode: {import_mode}") # Get YAML content - content = "" + content: bytes | str = b"" if mode == ImportMode.YAML_URL: if not yaml_url: return Import( @@ -136,7 +136,7 @@ def import_app( ) try: - content = content.decode("utf-8") + content = cast(bytes, content).decode("utf-8") except UnicodeDecodeError as e: return Import( id=import_id, @@ -340,7 +340,10 @@ def _create_or_update_app( ) -> App: """Create a new app or update an existing one.""" app_data = data.get("app", {}) - app_mode = AppMode(app_data["mode"]) + app_mode = app_data.get("mode") + if not app_mode: + raise ValueError("loss app mode") + app_mode = AppMode(app_mode) # Set icon type icon_type_value = icon_type or app_data.get("icon_type") @@ -359,6 +362,9 @@ def _create_or_update_app( app.icon_background = icon_background or app_data.get("icon_background", app.icon_background) app.updated_by = account.id else: + if account.current_tenant_id is None: + raise ValueError("Current tenant is not set") + # Create new app app = App() app.id = str(uuid4()) diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 545de8190cb51e..51aef7ccab9a0c 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -108,6 +108,9 @@ def generate( raise ValueError(f"Invalid app mode {app_model.mode}") except RateLimitError as e: raise InvokeRateLimitError(str(e)) + except Exception: + rate_limit.exit(request_id) + raise finally: if not streaming: rate_limit.exit(request_id) @@ -115,7 +118,7 @@ def generate( @staticmethod def _get_max_active_requests(app_model: App) -> int: max_active_requests = app_model.max_active_requests - if app_model.max_active_requests is None: + if max_active_requests is None: max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS) return max_active_requests @@ -147,7 +150,7 @@ def generate_more_like_this( message_id: str, invoke_from: InvokeFrom, streaming: bool = True, - ) -> Union[dict, Generator]: + ) -> Union[Mapping, Generator]: """ Generate more like this :param app_model: app model diff --git a/api/services/app_service.py b/api/services/app_service.py index 8d8ba735ecfa71..41c15bbf0a330b 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,9 +1,9 @@ import json import logging from datetime import UTC, datetime -from typing import cast +from typing import Optional, cast -from flask_login import current_user +from flask_login import current_user # type: ignore from flask_sqlalchemy.pagination import Pagination from configs import dify_config @@ -83,7 +83,7 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: # get default model instance try: model_instance = model_manager.get_default_model_instance( - tenant_id=account.current_tenant_id, model_type=ModelType.LLM + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM ) except (ProviderTokenNotInitError, LLMBadRequestError): model_instance = None @@ -100,6 +100,8 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: else: llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + if model_schema is None: + raise ValueError(f"model schema not found for model {model_instance.model}") default_model_dict = { "provider": model_instance.provider, @@ -109,7 +111,7 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: } else: provider, model = model_manager.get_default_provider_model_name( - tenant_id=account.current_tenant_id, model_type=ModelType.LLM + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM ) default_model_config["model"]["provider"] = provider default_model_config["model"]["name"] = model @@ -314,7 +316,7 @@ def get_app_meta(self, app_model: App) -> dict: """ app_mode = AppMode.value_of(app_model.mode) - meta = {"tool_icons": {}} + meta: dict = {"tool_icons": {}} if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow @@ -336,7 +338,7 @@ def get_app_meta(self, app_model: App) -> dict: } ) else: - app_model_config: AppModelConfig = app_model.app_model_config + app_model_config: Optional[AppModelConfig] = app_model.app_model_config if not app_model_config: return meta @@ -352,16 +354,18 @@ def get_app_meta(self, app_model: App) -> dict: keys = list(tool.keys()) if len(keys) >= 4: # current tool standard - provider_type = tool.get("provider_type") - provider_id = tool.get("provider_id") - tool_name = tool.get("tool_name") + provider_type = tool.get("provider_type", "") + provider_id = tool.get("provider_id", "") + tool_name = tool.get("tool_name", "") if provider_type == "builtin": meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon" elif provider_type == "api": try: - provider: ApiToolProvider = ( + provider: Optional[ApiToolProvider] = ( db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first() ) + if provider is None: + raise ValueError(f"provider not found for tool {tool_name}") meta["tool_icons"][tool_name] = json.loads(provider.icon) except: meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"} diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 7a0cd5725b2a96..973110f5156523 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -110,6 +110,8 @@ def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): voices = model_instance.get_tts_voices() if voices: voice = voices[0].get("value") + if not voice: + raise ValueError("Sorry, no voice available.") else: raise ValueError("Sorry, no voice available.") @@ -121,6 +123,8 @@ def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): if message_id: message = db.session.query(Message).filter(Message.id == message_id).first() + if message is None: + return None if message.answer == "" and message.status == "normal": return None @@ -130,6 +134,8 @@ def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): return Response(stream_with_context(response), content_type="audio/mpeg") return response else: + if not text: + raise ValueError("Text is required") response = invoke_tts(text, app_model, voice) if isinstance(response, Generator): return Response(stream_with_context(response), content_type="audio/mpeg") diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py index afc491398f25f3..50e4edff140346 100644 --- a/api/services/auth/firecrawl/firecrawl.py +++ b/api/services/auth/firecrawl/firecrawl.py @@ -11,8 +11,8 @@ def __init__(self, credentials: dict): auth_type = credentials.get("auth_type") if auth_type != "bearer": raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer") - self.api_key = credentials.get("config").get("api_key", None) - self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev") + self.api_key = credentials.get("config", {}).get("api_key", None) + self.base_url = credentials.get("config", {}).get("base_url", "https://api.firecrawl.dev") if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py index de898a1f94b763..6100e9afc8f278 100644 --- a/api/services/auth/jina.py +++ b/api/services/auth/jina.py @@ -11,7 +11,7 @@ def __init__(self, credentials: dict): auth_type = credentials.get("auth_type") if auth_type != "bearer": raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") - self.api_key = credentials.get("config").get("api_key", None) + self.api_key = credentials.get("config", {}).get("api_key", None) if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py index de898a1f94b763..6100e9afc8f278 100644 --- a/api/services/auth/jina/jina.py +++ b/api/services/auth/jina/jina.py @@ -11,7 +11,7 @@ def __init__(self, credentials: dict): auth_type = credentials.get("auth_type") if auth_type != "bearer": raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") - self.api_key = credentials.get("config").get("api_key", None) + self.api_key = credentials.get("config", {}).get("api_key", None) if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 911d2346415ce5..d98018648839a9 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,6 +1,8 @@ import os +from typing import Optional -import requests +import httpx +from tenacity import retry, retry_if_not_exception_type, stop_before_delay, wait_fixed from extensions.ext_database import db from models.account import TenantAccountJoin, TenantAccountRole @@ -39,11 +41,17 @@ def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""): return cls._send_request("GET", "/invoices", params=params) @classmethod + @retry( + wait=wait_fixed(2), + stop=stop_before_delay(10), + retry=retry_if_not_exception_type(httpx.RequestError), + reraise=True, + ) def _send_request(cls, method, endpoint, json=None, params=None): headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" - response = requests.request(method, url, json=json, params=params, headers=headers) + response = httpx.request(method, url, json=json, params=params, headers=headers) return response.json() @@ -51,11 +59,14 @@ def _send_request(cls, method, endpoint, json=None, params=None): def is_tenant_owner_or_admin(current_user): tenant_id = current_user.current_tenant_id - join = ( + join: Optional[TenantAccountJoin] = ( db.session.query(TenantAccountJoin) .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) .first() ) + if not join: + raise ValueError("Tenant account join not found") + if not TenantAccountRole.is_privileged_role(join.role): raise ValueError("Only team owner or team admin can perform this action") diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 8642972710fd1f..6485cbf37d5b7f 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,8 +1,9 @@ -from collections.abc import Callable +from collections.abc import Callable, Sequence from datetime import UTC, datetime from typing import Optional, Union -from sqlalchemy import asc, desc, or_ +from sqlalchemy import asc, desc, func, or_, select +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator @@ -18,19 +19,21 @@ class ConversationService: @classmethod def pagination_by_last_id( cls, + *, + session: Session, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int, invoke_from: InvokeFrom, - include_ids: Optional[list] = None, - exclude_ids: Optional[list] = None, + include_ids: Optional[Sequence[str]] = None, + exclude_ids: Optional[Sequence[str]] = None, sort_by: str = "-updated_at", ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) - base_query = db.session.query(Conversation).filter( + stmt = select(Conversation).where( Conversation.is_deleted == False, Conversation.app_id == app_model.id, Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), @@ -38,37 +41,39 @@ def pagination_by_last_id( Conversation.from_account_id == (user.id if isinstance(user, Account) else None), or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), ) - if include_ids is not None: - base_query = base_query.filter(Conversation.id.in_(include_ids)) - + stmt = stmt.where(Conversation.id.in_(include_ids)) if exclude_ids is not None: - base_query = base_query.filter(~Conversation.id.in_(exclude_ids)) + stmt = stmt.where(~Conversation.id.in_(exclude_ids)) # define sort fields and directions sort_field, sort_direction = cls._get_sort_params(sort_by) if last_id: - last_conversation = base_query.filter(Conversation.id == last_id).first() + last_conversation = session.scalar(stmt.where(Conversation.id == last_id)) if not last_conversation: raise LastConversationNotExistsError() # build filters based on sorting - filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation) - base_query = base_query.filter(filter_condition) - - base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field))) - - conversations = base_query.limit(limit).all() + filter_condition = cls._build_filter_condition( + sort_field=sort_field, + sort_direction=sort_direction, + reference_conversation=last_conversation, + ) + stmt = stmt.where(filter_condition) + query_stmt = stmt.order_by(sort_direction(getattr(Conversation, sort_field))).limit(limit) + conversations = session.scalars(query_stmt).all() has_more = False if len(conversations) == limit: current_page_last_conversation = conversations[-1] rest_filter_condition = cls._build_filter_condition( - sort_field, sort_direction, current_page_last_conversation, is_next_page=True + sort_field=sort_field, + sort_direction=sort_direction, + reference_conversation=current_page_last_conversation, ) - rest_count = base_query.filter(rest_filter_condition).count() - + count_stmt = select(func.count()).select_from(stmt.where(rest_filter_condition).subquery()) + rest_count = session.scalar(count_stmt) or 0 if rest_count > 0: has_more = True @@ -81,11 +86,9 @@ def _get_sort_params(cls, sort_by: str): return sort_by, asc @classmethod - def _build_filter_condition( - cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation, is_next_page: bool = False - ): + def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation): field_value = getattr(reference_conversation, sort_field) - if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page): + if sort_direction == desc: return getattr(Conversation, sort_field) < field_value else: return getattr(Conversation, sort_field) > field_value diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index a1014e8e0ad908..d2d8a718d55c8a 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -6,7 +6,7 @@ import uuid from typing import Any, Optional -from flask_login import current_user +from flask_login import current_user # type: ignore from sqlalchemy import func from werkzeug.exceptions import NotFound @@ -186,8 +186,9 @@ def create_empty_dataset( return dataset @staticmethod - def get_dataset(dataset_id) -> Dataset: - return Dataset.query.filter_by(id=dataset_id).first() + def get_dataset(dataset_id) -> Optional[Dataset]: + dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first() + return dataset @staticmethod def check_dataset_model_setting(dataset): @@ -228,14 +229,20 @@ def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, @staticmethod def update_dataset(dataset_id, data, user): dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise ValueError("Dataset not found") DatasetService.check_dataset_permission(dataset, user) if dataset.provider == "external": - dataset.retrieval_model = data.get("external_retrieval_model", None) + external_retrieval_model = data.get("external_retrieval_model", None) + if external_retrieval_model: + dataset.retrieval_model = external_retrieval_model dataset.name = data.get("name", dataset.name) dataset.description = data.get("description", "") + permission = data.get("permission") + if permission: + dataset.permission = permission external_knowledge_id = data.get("external_knowledge_id", None) - dataset.permission = data.get("permission") db.session.add(dataset) if not external_knowledge_id: raise ValueError("External knowledge id is required.") @@ -367,7 +374,13 @@ def check_dataset_permission(dataset, user): raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod - def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None): + def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None): + if not dataset: + raise ValueError("Dataset not found") + + if not user: + raise ValueError("User not found") + if dataset.permission == DatasetPermissionEnum.ONLY_ME: if dataset.created_by != user.id: raise NoPermissionError("You do not have permission to access this dataset.") @@ -761,6 +774,11 @@ def save_document_with_dataset_id( rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) + else: + logging.warn( + f"Invalid process rule mode: {process_rule['mode']}, can not find dataset process rule" + ) + return db.session.add(dataset_process_rule) db.session.commit() lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) @@ -1005,9 +1023,10 @@ def update_document_with_dataset_id( rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) - db.session.add(dataset_process_rule) - db.session.commit() - document.dataset_process_rule_id = dataset_process_rule.id + if dataset_process_rule is not None: + db.session.add(dataset_process_rule) + db.session.commit() + document.dataset_process_rule_id = dataset_process_rule.id # update document data source if document_data.get("data_source"): file_name = "" @@ -1550,7 +1569,7 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document segment.word_count = len(content) if document.doc_form == "qa_model": segment.answer = segment_update_entity.answer - segment.word_count += len(segment_update_entity.answer) + segment.word_count += len(segment_update_entity.answer or "") word_count_change = segment.word_count - word_count_change if segment_update_entity.keywords: segment.keywords = segment_update_entity.keywords @@ -1565,7 +1584,8 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document db.session.add(document) # update segment index task if segment_update_entity.enabled: - VectorService.create_segments_vector([segment_update_entity.keywords], [segment], dataset) + keywords = segment_update_entity.keywords or [] + VectorService.create_segments_vector([keywords], [segment], dataset) else: segment_hash = helper.generate_text_hash(content) tokens = 0 @@ -1597,7 +1617,7 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document segment.disabled_by = None if document.doc_form == "qa_model": segment.answer = segment_update_entity.answer - segment.word_count += len(segment_update_entity.answer) + segment.word_count += len(segment_update_entity.answer or "") word_count_change = segment.word_count - word_count_change # update document word count if word_count_change != 0: @@ -1615,8 +1635,8 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document segment.status = "error" segment.error = str(e) db.session.commit() - segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() - return segment + new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() + return new_segment @classmethod def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset): @@ -1676,6 +1696,8 @@ def get_dataset_collection_binding_by_id_and_type( .order_by(DatasetCollectionBinding.created_at) .first() ) + if not dataset_collection_binding: + raise ValueError("Dataset collection binding not found") return dataset_collection_binding diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index 92098f06cca538..3c3f9704440342 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -8,8 +8,8 @@ class EnterpriseRequest: secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") proxies = { - "http": None, - "https": None, + "http": "", + "https": "", } @classmethod diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index c519f0b0e51b68..f1417c6cb94b80 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -4,7 +4,10 @@ from pydantic import BaseModel, ConfigDict from configs import dify_config -from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity +from core.entities.model_entities import ( + ModelWithProviderEntity, + ProviderModelWithStatusEntity, +) from core.entities.provider_entities import QuotaConfiguration from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ModelType @@ -148,7 +151,8 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity): Model with provider entity. """ - provider: SimpleProviderEntityResponse + # FIXME type error ignore here + provider: SimpleProviderEntityResponse # type: ignore def __init__(self, model: ModelWithProviderEntity) -> None: super().__init__(**model.model_dump()) diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 7e3cd87f1eec2b..898624066bef7e 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,7 +1,7 @@ import json from copy import deepcopy from datetime import UTC, datetime -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast import httpx import validators @@ -45,7 +45,10 @@ def validate_api_list(cls, api_settings: dict): @staticmethod def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis: - ExternalDatasetService.check_endpoint_and_api_key(args.get("settings")) + settings = args.get("settings") + if settings is None: + raise ValueError("settings is required") + ExternalDatasetService.check_endpoint_and_api_key(settings) external_knowledge_api = ExternalKnowledgeApis( tenant_id=tenant_id, created_by=user_id, @@ -69,7 +72,10 @@ def check_endpoint_and_api_key(settings: dict): endpoint = f"{settings['endpoint']}/retrieval" api_key = settings["api_key"] if not validators.url(endpoint, simple_host=True): - raise ValueError(f"invalid endpoint: {endpoint}") + if not endpoint.startswith("http://") and not endpoint.startswith("https://"): + raise ValueError(f"invalid endpoint: {endpoint} must start with http:// or https://") + else: + raise ValueError(f"invalid endpoint: {endpoint}") try: response = httpx.post(endpoint, headers={"Authorization": f"Bearer {api_key}"}) except Exception as e: @@ -83,11 +89,16 @@ def check_endpoint_and_api_key(settings: dict): @staticmethod def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis: - return ExternalKnowledgeApis.query.filter_by(id=external_knowledge_api_id).first() + external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( + id=external_knowledge_api_id + ).first() + if external_knowledge_api is None: + raise ValueError("api template not found") + return external_knowledge_api @staticmethod def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis: - external_knowledge_api = ExternalKnowledgeApis.query.filter_by( + external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( id=external_knowledge_api_id, tenant_id=tenant_id ).first() if external_knowledge_api is None: @@ -124,7 +135,7 @@ def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bo @staticmethod def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: - external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( + external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by( dataset_id=dataset_id, tenant_id=tenant_id ).first() if not external_knowledge_binding: @@ -160,8 +171,9 @@ def process_external_api( "follow_redirects": True, } - response = getattr(ssrf_proxy, settings.request_method)(data=json.dumps(settings.params), files=files, **kwargs) - + response: httpx.Response = getattr(ssrf_proxy, settings.request_method)( + data=json.dumps(settings.params), files=files, **kwargs + ) return response @staticmethod @@ -262,15 +274,15 @@ def fetch_external_knowledge_retrieval( "knowledge_id": external_knowledge_binding.external_knowledge_id, } - external_knowledge_api_setting = { - "url": f"{settings.get('endpoint')}/retrieval", - "request_method": "post", - "headers": headers, - "params": request_params, - } response = ExternalDatasetService.process_external_api( - ExternalKnowledgeApiSetting(**external_knowledge_api_setting), None + ExternalKnowledgeApiSetting( + url=f"{settings.get('endpoint')}/retrieval", + request_method="post", + headers=headers, + params=request_params, + ), + None, ) if response.status_code == 200: - return response.json().get("records", []) + return cast(list[Any], response.json().get("records", [])) return [] diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 6bd82a27574274..0386c6aceaa2e7 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -63,6 +63,7 @@ class SystemFeatureModel(BaseModel): enable_social_oauth_login: bool = False is_allow_register: bool = False is_allow_create_workspace: bool = False + is_email_setup: bool = False license: LicenseModel = LicenseModel() @@ -98,6 +99,7 @@ def _fulfill_system_params_from_env(cls, system_features: SystemFeatureModel): system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN system_features.is_allow_register = dify_config.ALLOW_REGISTER system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE + system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != "" @classmethod def _fulfill_params_from_env(cls, features: FeatureModel): diff --git a/api/services/file_service.py b/api/services/file_service.py index b12b95ca13558c..d417e81734c8af 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -3,7 +3,7 @@ import uuid from typing import Any, Literal, Union -from flask_login import current_user +from flask_login import current_user # type: ignore from werkzeug.exceptions import NotFound from configs import dify_config @@ -61,14 +61,14 @@ def upload_file( # end_user current_tenant_id = user.tenant_id - file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." + extension + file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension # save file to storage storage.save(file_key, content) # save file to db upload_file = UploadFile( - tenant_id=current_tenant_id, + tenant_id=current_tenant_id or "", storage_type=dify_config.STORAGE_TYPE, key=file_key, name=filename, diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 7957b4dc82dfd4..41b4e1ec46374a 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,5 +1,6 @@ import logging import time +from typing import Any from core.rag.datasource.retrieval_service import RetrievalService from core.rag.models.document import Document @@ -24,7 +25,7 @@ def retrieve( dataset: Dataset, query: str, account: Account, - retrieval_model: dict, + retrieval_model: Any, # FIXME drop this any external_retrieval_model: dict, limit: int = 10, ) -> dict: @@ -68,7 +69,7 @@ def retrieve( db.session.add(dataset_query) db.session.commit() - return cls.compact_retrieve_response(dataset, query, all_documents) + return dict(cls.compact_retrieve_response(dataset, query, all_documents)) @classmethod def external_retrieve( @@ -102,13 +103,16 @@ def external_retrieve( db.session.add(dataset_query) db.session.commit() - return cls.compact_external_retrieve_response(dataset, query, all_documents) + return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) @classmethod def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]): records = [] for document in documents: + if document.metadata is None: + continue + index_node_id = document.metadata["doc_id"] segment = ( @@ -140,7 +144,7 @@ def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list } @classmethod - def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list): + def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]: records = [] if dataset.provider == "external": for document in documents: @@ -152,11 +156,10 @@ def compact_external_retrieve_response(cls, dataset: Dataset, query: str, docume } records.append(record) return { - "query": { - "content": query, - }, + "query": {"content": query}, "records": records, } + return {"query": {"content": query}, "records": []} @classmethod def hit_testing_args_check(cls, args): diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py index 02fe1d19bc42be..8df1a6ba144d4e 100644 --- a/api/services/knowledge_service.py +++ b/api/services/knowledge_service.py @@ -1,4 +1,4 @@ -import boto3 +import boto3 # type: ignore from configs import dify_config diff --git a/api/services/message_service.py b/api/services/message_service.py index f432a77c80e511..c4447a84da5e09 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -151,8 +151,13 @@ def pagination_by_last_id( @classmethod def create_feedback( - cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str] - ) -> MessageFeedback: + cls, + app_model: App, + message_id: str, + user: Optional[Union[Account, EndUser]], + rating: Optional[str], + content: Optional[str], + ): if not user: raise ValueError("user cannot be None") @@ -164,6 +169,7 @@ def create_feedback( db.session.delete(feedback) elif rating and feedback: feedback.rating = rating + feedback.content = content elif not rating and not feedback: raise ValueError("rating cannot be None when feedback not exists") else: @@ -172,6 +178,7 @@ def create_feedback( conversation_id=message.conversation_id, message_id=message.id, rating=rating, + content=content, from_source=("user" if isinstance(user, EndUser) else "admin"), from_end_user_id=(user.id if isinstance(user, EndUser) else None), from_account_id=(user.id if isinstance(user, Account) else None), @@ -257,6 +264,8 @@ def get_suggested_questions_after_answer( ) app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs) + if not app_model_config: + raise ValueError("did not find app model config") suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict if suggested_questions_after_answer.get("enabled", False) is False: @@ -278,7 +287,7 @@ def get_suggested_questions_after_answer( ) with measure_time() as timer: - questions = LLMGenerator.generate_suggested_questions_after_answer( + questions: list[Message] = LLMGenerator.generate_suggested_questions_after_answer( tenant_id=app_model.tenant_id, histories=histories ) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index b20bda87551ca9..bacd3a8ec3d04f 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -2,7 +2,7 @@ import json import logging from json import JSONDecodeError -from typing import Optional +from typing import Optional, Union from constants import HIDDEN_VALUE from core.entities.provider_configuration import ProviderConfiguration @@ -88,11 +88,11 @@ def get_load_balancing_configs( raise ValueError(f"Provider {provider} does not exist.") # Convert model type to ModelType - model_type = ModelType.value_of(model_type) + model_type_enum = ModelType.value_of(model_type) # Get provider model setting provider_model_setting = provider_configuration.get_provider_model_setting( - model_type=model_type, + model_type=model_type_enum, model=model, ) @@ -106,7 +106,7 @@ def get_load_balancing_configs( .filter( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) .order_by(LoadBalancingModelConfig.created_at) @@ -124,7 +124,7 @@ def get_load_balancing_configs( if not inherit_config_exists: # Initialize the inherit configuration - inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type) + inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type_enum) # prepend the inherit configuration load_balancing_configs.insert(0, inherit_config) @@ -148,7 +148,7 @@ def get_load_balancing_configs( tenant_id=tenant_id, provider=provider, model=model, - model_type=model_type, + model_type=model_type_enum, config_id=load_balancing_config.id, ) @@ -214,7 +214,7 @@ def get_load_balancing_config( raise ValueError(f"Provider {provider} does not exist.") # Convert model type to ModelType - model_type = ModelType.value_of(model_type) + model_type_enum = ModelType.value_of(model_type) # Get load balancing configurations load_balancing_model_config = ( @@ -222,7 +222,7 @@ def get_load_balancing_config( .filter( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, LoadBalancingModelConfig.id == config_id, ) @@ -300,7 +300,7 @@ def update_load_balancing_configs( raise ValueError(f"Provider {provider} does not exist.") # Convert model type to ModelType - model_type = ModelType.value_of(model_type) + model_type_enum = ModelType.value_of(model_type) if not isinstance(configs, list): raise ValueError("Invalid load balancing configs") @@ -310,7 +310,7 @@ def update_load_balancing_configs( .filter( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) .all() @@ -359,7 +359,7 @@ def update_load_balancing_configs( credentials = self._custom_credentials_validate( tenant_id=tenant_id, provider_configuration=provider_configuration, - model_type=model_type, + model_type=model_type_enum, model=model, credentials=credentials, load_balancing_model_config=load_balancing_config, @@ -395,7 +395,7 @@ def update_load_balancing_configs( credentials = self._custom_credentials_validate( tenant_id=tenant_id, provider_configuration=provider_configuration, - model_type=model_type, + model_type=model_type_enum, model=model, credentials=credentials, validate=False, @@ -405,7 +405,7 @@ def update_load_balancing_configs( load_balancing_model_config = LoadBalancingModelConfig( tenant_id=tenant_id, provider_name=provider_configuration.provider.provider, - model_type=model_type.to_origin_model_type(), + model_type=model_type_enum.to_origin_model_type(), model_name=model, name=name, encrypted_config=json.dumps(credentials), @@ -450,7 +450,7 @@ def validate_load_balancing_credentials( raise ValueError(f"Provider {provider} does not exist.") # Convert model type to ModelType - model_type = ModelType.value_of(model_type) + model_type_enum = ModelType.value_of(model_type) load_balancing_model_config = None if config_id: @@ -460,7 +460,7 @@ def validate_load_balancing_credentials( .filter( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, LoadBalancingModelConfig.id == config_id, ) @@ -474,7 +474,7 @@ def validate_load_balancing_credentials( self._custom_credentials_validate( tenant_id=tenant_id, provider_configuration=provider_configuration, - model_type=model_type, + model_type=model_type_enum, model=model, credentials=credentials, load_balancing_model_config=load_balancing_model_config, @@ -547,19 +547,14 @@ def _custom_credentials_validate( def _get_credential_schema( self, provider_configuration: ProviderConfiguration - ) -> ModelCredentialSchema | ProviderCredentialSchema: - """ - Get form schemas. - :param provider_configuration: provider configuration - :return: - """ - # Get credential form schemas from model credential schema or provider credential schema + ) -> Union[ModelCredentialSchema, ProviderCredentialSchema]: + """Get form schemas.""" if provider_configuration.provider.model_credential_schema: - credential_schema = provider_configuration.provider.model_credential_schema + return provider_configuration.provider.model_credential_schema + elif provider_configuration.provider.provider_credential_schema: + return provider_configuration.provider.provider_credential_schema else: - credential_schema = provider_configuration.provider.provider_credential_schema - - return credential_schema + raise ValueError("No credential schema found") def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None: """ diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 384a072b371fdd..b10c5ad2d616e9 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -7,7 +7,7 @@ import requests from flask import current_app -from core.entities.model_entities import ModelStatus, ProviderModelWithStatusEntity +from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity from core.model_runtime.entities.model_entities import ModelType, ParameterRule from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -100,23 +100,15 @@ def get_models_by_provider(self, tenant_id: str, provider: str) -> list[ModelWit ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider) ] - def get_provider_credentials(self, tenant_id: str, provider: str) -> dict: + def get_provider_credentials(self, tenant_id: str, provider: str): """ get provider credentials. - - :param tenant_id: - :param provider: - :return: """ - # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Get provider custom credentials from workspace return provider_configuration.get_custom_credentials(obfuscated=True) def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None: @@ -176,7 +168,7 @@ def remove_provider_credentials(self, tenant_id: str, provider: str) -> None: # Remove custom provider credentials. provider_configuration.delete_custom_credentials() - def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> dict: + def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str): """ get model credentials. @@ -287,7 +279,7 @@ def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[Prov models = provider_configurations.get_models(model_type=ModelType.value_of(model_type)) # Group models by provider - provider_models = {} + provider_models: dict[str, list[ModelWithProviderEntity]] = {} for model in models: if model.provider.provider not in provider_models: provider_models[model.provider.provider] = [] @@ -362,7 +354,7 @@ def get_model_parameter_rules(self, tenant_id: str, provider: str, model: str) - return [] # Call get_parameter_rules method of model instance to get model parameter rules - return model_type_instance.get_parameter_rules(model=model, credentials=credentials) + return list(model_type_instance.get_parameter_rules(model=model, credentials=credentials)) def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]: """ @@ -422,6 +414,7 @@ def get_model_provider_icon( """ provider_instance = model_provider_factory.get_provider_instance(provider) provider_schema = provider_instance.get_provider_schema() + file_name: str | None = None if icon_type.lower() == "icon_small": if not provider_schema.icon_small: @@ -439,6 +432,8 @@ def get_model_provider_icon( file_name = provider_schema.icon_large.zh_Hans else: file_name = provider_schema.icon_large.en_US + if not file_name: + return None, None root_path = current_app.root_path provider_instance_path = os.path.dirname( @@ -524,7 +519,7 @@ def disable_model(self, tenant_id: str, provider: str, model: str, model_type: s def free_quota_submit(self, tenant_id: str, provider: str): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") - api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") + api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "") api_url = api_base_url + "/api/v1/providers/apply" headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} @@ -545,7 +540,7 @@ def free_quota_submit(self, tenant_id: str, provider: str): def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") - api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") + api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "") api_url = api_base_url + "/api/v1/providers/qualification-verify" headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index dfb21e767fc9b9..082afeed89a5e4 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -1,3 +1,5 @@ +from typing import Optional + from core.moderation.factory import ModerationFactory, ModerationOutputsResult from extensions.ext_database import db from models.model import App, AppModelConfig @@ -5,7 +7,7 @@ class ModerationService: def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult: - app_model_config: AppModelConfig = None + app_model_config: Optional[AppModelConfig] = None app_model_config = ( db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 1160a1f2751d74..fc1e08518b1945 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -1,3 +1,5 @@ +from typing import Optional + from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map from extensions.ext_database import db from models.model import App, TraceAppConfig @@ -12,7 +14,7 @@ def get_tracing_app_config(cls, app_id: str, tracing_provider: str): :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = ( + trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -22,7 +24,10 @@ def get_tracing_app_config(cls, app_id: str, tracing_provider: str): return None # decrypt_token and obfuscated_token - tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id + tenant = db.session.query(App).filter(App.id == app_id).first() + if not tenant: + return None + tenant_id = tenant.tenant_id decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config( tenant_id, tracing_provider, trace_config_data.tracing_config ) @@ -73,8 +78,9 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["other_keys"], ) - default_config_instance = config_class(**tracing_config) - for key in other_keys: + # FIXME: ignore type error + default_config_instance = config_class(**tracing_config) # type: ignore + for key in other_keys: # type: ignore if key in tracing_config and tracing_config[key] == "": tracing_config[key] = getattr(default_config_instance, key, None) @@ -92,7 +98,7 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c project_url = None # check if trace config already exists - trace_config_data: TraceAppConfig = ( + trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -102,7 +108,10 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c return None # get tenant id - tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id + tenant = db.session.query(App).filter(App.id == app_id).first() + if not tenant: + return None + tenant_id = tenant.tenant_id tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config) if project_url: tracing_config["project_url"] = project_url @@ -139,7 +148,10 @@ def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c return None # get tenant id - tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id + tenant = db.session.query(App).filter(App.id == app_id).first() + if not tenant: + return None + tenant_id = tenant.tenant_id tracing_config = OpsTraceManager.encrypt_tracing_config( tenant_id, tracing_provider, tracing_config, current_trace_config.tracing_config ) diff --git a/api/services/recommend_app/buildin/buildin_retrieval.py b/api/services/recommend_app/buildin/buildin_retrieval.py index 4704d533a950ed..523aebeed52a4e 100644 --- a/api/services/recommend_app/buildin/buildin_retrieval.py +++ b/api/services/recommend_app/buildin/buildin_retrieval.py @@ -41,7 +41,7 @@ def _get_builtin_data(cls) -> dict: Path(path.join(root_path, "constants", "recommended_apps.json")).read_text(encoding="utf-8") ) - return cls.builtin_data + return cls.builtin_data or {} @classmethod def fetch_recommended_apps_from_builtin(cls, language: str) -> dict: @@ -50,8 +50,8 @@ def fetch_recommended_apps_from_builtin(cls, language: str) -> dict: :param language: language :return: """ - builtin_data = cls._get_builtin_data() - return builtin_data.get("recommended_apps", {}).get(language) + builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + return builtin_data.get("recommended_apps", {}).get(language, {}) @classmethod def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]: @@ -60,5 +60,5 @@ def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict :param app_id: App ID :return: """ - builtin_data = cls._get_builtin_data() + builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() return builtin_data.get("app_details", {}).get(app_id) diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index 995d3755bb5b10..3295516cce66f3 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -57,13 +57,7 @@ def fetch_recommended_apps_from_db(cls, language: str) -> dict: recommended_app_result = { "id": recommended_app.id, - "app": { - "id": app.id, - "name": app.name, - "mode": app.mode, - "icon": app.icon, - "icon_background": app.icon_background, - }, + "app": recommended_app.app, "app_id": recommended_app.app_id, "description": site.description, "copyright": site.copyright, diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py index b0607a21323acb..80e1aefc01da85 100644 --- a/api/services/recommend_app/remote/remote_retrieval.py +++ b/api/services/recommend_app/remote/remote_retrieval.py @@ -47,8 +47,8 @@ def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optiona response = requests.get(url, timeout=(3, 10)) if response.status_code != 200: return None - - return response.json() + data: dict = response.json() + return data @classmethod def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: @@ -63,7 +63,7 @@ def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: if response.status_code != 200: raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") - result = response.json() + result: dict = response.json() if "categories" in result: result["categories"] = sorted(result["categories"]) diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 4660316fcfcf71..54c58455155c03 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -33,5 +33,5 @@ def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: """ mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)() - result = retrieval_instance.get_recommend_app_detail(app_id) + result: dict = retrieval_instance.get_recommend_app_detail(app_id) return result diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 9fe3cecce7546d..4cb8700117e6f3 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -13,6 +13,8 @@ class SavedMessageService: def pagination_by_last_id( cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int ) -> InfiniteScrollPagination: + if not user: + raise ValueError("User is required") saved_messages = ( db.session.query(SavedMessage) .filter( @@ -31,6 +33,8 @@ def pagination_by_last_id( @classmethod def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): + if not user: + return saved_message = ( db.session.query(SavedMessage) .filter( @@ -59,6 +63,8 @@ def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_i @classmethod def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): + if not user: + return saved_message = ( db.session.query(SavedMessage) .filter( diff --git a/api/services/tag_service.py b/api/services/tag_service.py index a374bdcf002bef..9600601633cddb 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -1,7 +1,7 @@ import uuid from typing import Optional -from flask_login import current_user +from flask_login import current_user # type: ignore from sqlalchemy import func from werkzeug.exceptions import NotFound @@ -21,7 +21,7 @@ def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = Non if keyword: query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%"))) query = query.group_by(Tag.id) - results = query.order_by(Tag.created_at.desc()).all() + results: list = query.order_by(Tag.created_at.desc()).all() return results @staticmethod diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 78a80f70ab6b00..0e3bd3a7b83c68 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -1,6 +1,7 @@ import json import logging -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional, cast from httpx import get @@ -28,12 +29,12 @@ class ApiToolManageService: @staticmethod - def parser_api_schema(schema: str) -> list[ApiToolBundle]: + def parser_api_schema(schema: str) -> Mapping[str, Any]: """ parse api schema to tool bundle """ try: - warnings = {} + warnings: dict[str, str] = {} try: tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings) except Exception as e: @@ -68,13 +69,16 @@ def parser_api_schema(schema: str) -> list[ApiToolBundle]: ), ] - return jsonable_encoder( - { - "schema_type": schema_type, - "parameters_schema": tool_bundles, - "credentials_schema": credentials_schema, - "warning": warnings, - } + return cast( + Mapping, + jsonable_encoder( + { + "schema_type": schema_type, + "parameters_schema": tool_bundles, + "credentials_schema": credentials_schema, + "warning": warnings, + } + ), ) except Exception as e: raise ValueError(f"invalid schema: {str(e)}") @@ -129,7 +133,7 @@ def create_api_tool_provider( raise ValueError(f"provider {provider_name} already exists") # parse openapi to tool bundle - extra_info = {} + extra_info: dict[str, str] = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) @@ -262,9 +266,8 @@ def update_api_tool_provider( if provider is None: raise ValueError(f"api provider {provider_name} does not exists") - # parse openapi to tool bundle - extra_info = {} + extra_info: dict[str, str] = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) @@ -416,7 +419,7 @@ def test_api_tool_preview( provider_controller.validate_credentials_format(credentials) # get tool tool = provider_controller.get_tool(tool_name) - tool = tool.fork_tool_runtime( + runtime_tool = tool.fork_tool_runtime( runtime={ "credentials": credentials, "tenant_id": tenant_id, @@ -454,7 +457,7 @@ def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) - for tool in tools: + for tool in tools or []: user_provider.tools.append( ToolTransformService.tool_to_user_tool( tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index e2e49d017ef167..21adbb0074724e 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -2,6 +2,9 @@ import logging from pathlib import Path +from sqlalchemy import select +from sqlalchemy.orm import Session + from configs import dify_config from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder @@ -32,7 +35,7 @@ def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str tenant_id=tenant_id, provider_controller=provider_controller ) # check if user has added the provider - builtin_provider: BuiltinToolProvider = ( + builtin_provider = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, @@ -47,8 +50,8 @@ def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str credentials = builtin_provider.credentials credentials = tool_provider_configurations.decrypt_tool_credentials(credentials) - result = [] - for tool in tools: + result: list[UserTool] = [] + for tool in tools or []: result.append( ToolTransformService.tool_to_user_tool( tool=tool, @@ -71,19 +74,18 @@ def list_builtin_provider_credentials_schema(provider_name): return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()]) @staticmethod - def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict): + def update_builtin_tool_provider( + session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict + ): """ update builtin tool provider """ # get if the provider exists - provider: BuiltinToolProvider = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ) - .first() + stmt = select(BuiltinToolProvider).where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, ) + provider = session.scalar(stmt) try: # get provider @@ -115,13 +117,10 @@ def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: st encrypted_credentials=json.dumps(credentials), ) - db.session.add(provider) - db.session.commit() + session.add(provider) else: provider.encrypted_credentials = json.dumps(credentials) - db.session.add(provider) - db.session.commit() # delete cache tool_configuration.delete_tool_credentials_cache() @@ -129,15 +128,15 @@ def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: st return {"result": "success"} @staticmethod - def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str): + def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str): """ get builtin tool provider credentials """ - provider: BuiltinToolProvider = ( + provider = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, + BuiltinToolProvider.provider == provider_name, ) .first() ) @@ -156,7 +155,7 @@ def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: st """ delete tool provider """ - provider: BuiltinToolProvider = ( + provider = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, @@ -218,6 +217,8 @@ def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: name_func=lambda x: x.identity.name, ): continue + if provider_controller.identity is None: + continue # convert provider controller to user provider user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( @@ -230,7 +231,7 @@ def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: ToolTransformService.repack_provider(user_builtin_provider) tools = provider_controller.get_tools() - for tool in tools: + for tool in tools or []: user_builtin_provider.tools.append( ToolTransformService.tool_to_user_tool( tenant_id=tenant_id, diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index a4aa870dc80352..6e3a45be0da1c4 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import Optional, Union +from typing import Optional, Union, cast from configs import dify_config from core.tools.entities.api_entities import UserTool, UserToolProvider @@ -35,7 +35,7 @@ def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str return url_prefix + "builtin/" + provider_name + "/icon" elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}: try: - return json.loads(icon) + return cast(dict, json.loads(icon)) except: return {"background": "#252525", "content": "\ud83d\ude01"} @@ -53,8 +53,11 @@ def repack_provider(provider: Union[dict, UserToolProvider]): provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"] ) elif isinstance(provider, UserToolProvider): - provider.icon = ToolTransformService.get_tool_provider_icon_url( - provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon + provider.icon = cast( + str, + ToolTransformService.get_tool_provider_icon_url( + provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon + ), ) @staticmethod @@ -66,6 +69,9 @@ def builtin_provider_to_user_provider( """ convert provider controller to user provider """ + if provider_controller.identity is None: + raise ValueError("provider identity is None") + result = UserToolProvider( id=provider_controller.identity.name, author=provider_controller.identity.author, @@ -93,7 +99,8 @@ def builtin_provider_to_user_provider( # get credentials schema schema = provider_controller.get_credentials_schema() for name, value in schema.items(): - result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type) + assert result.masked_credentials is not None, "masked credentials is None" + result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(str(value.type)) # check if the provider need credentials if not provider_controller.need_credentials: @@ -149,6 +156,9 @@ def workflow_provider_to_user_provider( """ convert provider controller to user provider """ + if provider_controller.identity is None: + raise ValueError("provider identity is None") + return UserToolProvider( id=provider_controller.provider_id, author=provider_controller.identity.author, @@ -180,6 +190,8 @@ def api_provider_to_user_provider( convert provider controller to user provider """ username = "Anonymous" + if db_provider.user is None: + raise ValueError(f"user is None for api provider {db_provider.id}") try: username = db_provider.user.name except Exception as e: @@ -256,19 +268,22 @@ def tool_to_user_tool( if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: current_parameters.append(runtime_parameter) + if tool.identity is None: + raise ValueError("tool identity is None") + return UserTool( author=tool.identity.author, name=tool.identity.name, label=tool.identity.label, - description=tool.description.human, + description=tool.description.human if tool.description else "", # type: ignore parameters=current_parameters, labels=labels, ) if isinstance(tool, ApiToolBundle): return UserTool( author=tool.author, - name=tool.operation_id, - label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id), + name=tool.operation_id or "", + label=I18nObject(en_US=tool.operation_id or "", zh_Hans=tool.operation_id or ""), description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""), parameters=tool.parameters, labels=labels, diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 833881b668b383..69430de432b143 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -6,8 +6,10 @@ from sqlalchemy import or_ from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.entities.api_entities import UserToolProvider +from core.tools.entities.api_entities import UserTool, UserToolProvider +from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController +from core.tools.tool.tool import Tool from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from extensions.ext_database import db @@ -32,7 +34,7 @@ def create_workflow_tool( label: str, icon: dict, description: str, - parameters: Mapping[str, Any], + parameters: list[Mapping[str, Any]], privacy_policy: str = "", labels: Optional[list[str]] = None, ) -> dict: @@ -81,6 +83,10 @@ def create_workflow_tool( db.session.add(workflow_tool_provider) db.session.commit() + if labels is not None: + ToolLabelManager.update_tool_labels( + ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels + ) return {"result": "success"} @classmethod @@ -93,7 +99,7 @@ def update_workflow_tool( label: str, icon: dict, description: str, - parameters: list[dict], + parameters: list[Mapping[str, Any]], privacy_policy: str = "", labels: Optional[list[str]] = None, ) -> dict: @@ -127,7 +133,7 @@ def update_workflow_tool( if existing_workflow_tool_provider is not None: raise ValueError(f"Tool with name {name} already exists") - workflow_tool_provider: WorkflowToolProvider = ( + workflow_tool_provider: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() @@ -136,14 +142,14 @@ def update_workflow_tool( if workflow_tool_provider is None: raise ValueError(f"Tool {workflow_tool_id} not found") - app: App = ( + app: Optional[App] = ( db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first() ) if app is None: raise ValueError(f"App {workflow_tool_provider.app_id} not found") - workflow: Workflow = app.workflow + workflow: Optional[Workflow] = app.workflow if workflow is None: raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") @@ -189,7 +195,7 @@ def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserTo # skip deleted tools pass - labels = ToolLabelManager.get_tools_labels(tools) + labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)]) result = [] @@ -198,10 +204,11 @@ def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserTo provider_controller=tool, labels=labels.get(tool.provider_id, []) ) ToolTransformService.repack_provider(user_tool_provider) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + continue user_tool_provider.tools = [ - ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, []) - ) + ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=labels.get(tool.provider_id, [])) ] result.append(user_tool_provider) @@ -232,7 +239,7 @@ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_too :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = ( + db_tool: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() @@ -241,13 +248,19 @@ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_too if db_tool is None: raise ValueError(f"Tool {workflow_tool_id} not found") - workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + workflow_app: Optional[App] = ( + db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + ) if workflow_app is None: raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + raise ValueError(f"Tool {workflow_tool_id} not found") + return { "name": db_tool.name, "label": db_tool.label, @@ -257,9 +270,9 @@ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_too "description": db_tool.description, "parameters": jsonable_encoder(db_tool.parameter_configurations), "tool": ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) + to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool) ), - "synced": workflow_app.workflow.version == db_tool.version, + "synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False, "privacy_policy": db_tool.privacy_policy, } @@ -272,7 +285,7 @@ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_ :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = ( + db_tool: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) .first() @@ -281,12 +294,17 @@ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_ if db_tool is None: raise ValueError(f"Tool {workflow_app_id} not found") - workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + workflow_app: Optional[App] = ( + db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + ) if workflow_app is None: raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + raise ValueError(f"Tool {workflow_app_id} not found") return { "name": db_tool.name, @@ -297,14 +315,14 @@ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_ "description": db_tool.description, "parameters": jsonable_encoder(db_tool.parameter_configurations), "tool": ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) + to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool) ), - "synced": workflow_app.workflow.version == db_tool.version, + "synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False, "privacy_policy": db_tool.privacy_policy, } @classmethod - def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]: + def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]: """ List workflow tool provider tools. :param user_id: the user id @@ -312,7 +330,7 @@ def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_ :param workflow_app_id: the workflow app id :return: the list of tools """ - db_tool: WorkflowToolProvider = ( + db_tool: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() @@ -322,9 +340,8 @@ def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_ raise ValueError(f"Tool {workflow_tool_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + raise ValueError(f"Tool {workflow_tool_id} not found") - return [ - ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) - ) - ] + return [ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool))] diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index d7ccc964cb70f8..f698ed3084bdac 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -1,5 +1,8 @@ from typing import Optional, Union +from sqlalchemy import select +from sqlalchemy.orm import Session + from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -13,6 +16,8 @@ class WebConversationService: @classmethod def pagination_by_last_id( cls, + *, + session: Session, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], @@ -21,26 +26,29 @@ def pagination_by_last_id( pinned: Optional[bool] = None, sort_by="-updated_at", ) -> InfiniteScrollPagination: + if not user: + raise ValueError("User is required") include_ids = None exclude_ids = None - if pinned is not None: - pinned_conversations = ( - db.session.query(PinnedConversation) - .filter( + if pinned is not None and user: + stmt = ( + select(PinnedConversation.conversation_id) + .where( PinnedConversation.app_id == app_model.id, PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), PinnedConversation.created_by == user.id, ) .order_by(PinnedConversation.created_at.desc()) - .all() ) - pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations] + pinned_conversation_ids = session.scalars(stmt).all() + if pinned: include_ids = pinned_conversation_ids else: exclude_ids = pinned_conversation_ids return ConversationService.pagination_by_last_id( + session=session, app_model=app_model, user=user, last_id=last_id, @@ -53,6 +61,8 @@ def pagination_by_last_id( @classmethod def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + if not user: + return pinned_conversation = ( db.session.query(PinnedConversation) .filter( @@ -83,6 +93,8 @@ def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, @classmethod def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + if not user: + return pinned_conversation = ( db.session.query(PinnedConversation) .filter( diff --git a/api/services/website_service.py b/api/services/website_service.py index 230f5d78152f39..1ad7d0399d6edf 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -1,8 +1,9 @@ import datetime import json +from typing import Any import requests -from flask_login import current_user +from flask_login import current_user # type: ignore from core.helper import encrypter from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp @@ -23,9 +24,9 @@ def document_create_args_validate(cls, args: dict): @classmethod def crawl_url(cls, args: dict) -> dict: - provider = args.get("provider") + provider = args.get("provider", "") url = args.get("url") - options = args.get("options") + options = args.get("options", "") credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) if provider == "firecrawl": # decrypt api_key @@ -164,16 +165,18 @@ def get_crawl_status(cls, job_id: str, provider: str) -> dict: return crawl_status_data @classmethod - def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None: + def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[Any, Any] | None: credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) # decrypt api_key api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + # FIXME data is redefine too many times here, use Any to ease the type checking, fix it later + data: Any if provider == "firecrawl": file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): - data = storage.load_once(file_key) - if data: - data = json.loads(data.decode("utf-8")) + d = storage.load_once(file_key) + if d: + data = json.loads(d.decode("utf-8")) else: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) result = firecrawl_app.check_crawl_status(job_id) @@ -183,22 +186,17 @@ def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str if data: for item in data: if item.get("source_url") == url: - return item + return dict(item) return None elif provider == "jinareader": - file_key = "website_files/" + job_id + ".txt" - if storage.exists(file_key): - data = storage.load_once(file_key) - if data: - data = json.loads(data.decode("utf-8")) - elif not job_id: + if not job_id: response = requests.get( f"https://r.jina.ai/{url}", headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, ) if response.json().get("code") != 200: raise ValueError("Failed to crawl") - return response.json().get("data") + return dict(response.json().get("data", {})) else: api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) response = requests.post( @@ -218,12 +216,13 @@ def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str data = response.json().get("data", {}) for item in data.get("processed", {}).values(): if item.get("data", {}).get("url") == url: - return item.get("data", {}) + return dict(item.get("data", {})) + return None else: raise ValueError("Invalid provider") @classmethod - def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None: + def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict: credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) if provider == "firecrawl": # decrypt api_key diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 90b5cc48362f3b..2b0d57bdfdeda3 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import Any, Optional from core.app.app_config.entities import ( DatasetEntity, @@ -101,7 +101,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config) # init workflow graph - graph = {"nodes": [], "edges": []} + graph: dict[str, Any] = {"nodes": [], "edges": []} # Convert list: # - variables -> start @@ -118,7 +118,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: graph["nodes"].append(start_node) # convert to http request node - external_data_variable_node_mapping = {} + external_data_variable_node_mapping: dict[str, str] = {} if app_config.external_data_variables: http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node( app_model=app_model, @@ -199,15 +199,16 @@ def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: return workflow def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: - app_mode = AppMode.value_of(app_model.mode) - if app_mode == AppMode.AGENT_CHAT or app_model.is_agent: + app_mode_enum = AppMode.value_of(app_model.mode) + app_config: EasyUIBasedAppConfig + if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent: app_model.mode = AppMode.AGENT_CHAT.value app_config = AgentChatAppConfigManager.get_app_config( app_model=app_model, app_model_config=app_model_config ) - elif app_mode == AppMode.CHAT: + elif app_mode_enum == AppMode.CHAT: app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config) - elif app_mode == AppMode.COMPLETION: + elif app_mode_enum == AppMode.COMPLETION: app_config = CompletionAppConfigManager.get_app_config( app_model=app_model, app_model_config=app_model_config ) @@ -302,7 +303,7 @@ def _convert_to_http_request_node( nodes.append(http_request_node) # append code node for response body parsing - code_node = { + code_node: dict[str, Any] = { "id": f"code_{index}", "position": None, "data": { @@ -401,6 +402,7 @@ def _convert_to_llm_node( ) role_prefix = None + prompts: Any = None # Chat Model if model_config.mode == LLMMode.CHAT.value: diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index d8ee323908a844..4343596a236f5f 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,3 +1,5 @@ +from typing import Optional + from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom @@ -92,7 +94,7 @@ def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScro return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) - def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun: + def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]: """ Get workflow run detail diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 37d7d0937cd492..ea8192edde35cc 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -2,7 +2,7 @@ import time from collections.abc import Sequence from datetime import UTC, datetime -from typing import Optional +from typing import Any, Optional, cast from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager @@ -11,6 +11,9 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.nodes import NodeType +from core.workflow.nodes.base.entities import BaseNodeData +from core.workflow.nodes.base.node import BaseNode +from core.workflow.nodes.enums import ErrorStrategy from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.workflow_entry import WorkflowEntry @@ -225,7 +228,7 @@ def run_draft_workflow_node( user_inputs=user_inputs, user_id=account.id, ) - + node_instance = cast(BaseNode[BaseNodeData], node_instance) node_run_result: NodeRunResult | None = None for event in generator: if isinstance(event, RunCompletedEvent): @@ -237,8 +240,35 @@ def run_draft_workflow_node( if not node_run_result: raise ValueError("Node run failed with no run result") - - run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False + # single step debug mode error handling return + if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: + node_error_args: dict[str, Any] = { + "status": WorkflowNodeExecutionStatus.EXCEPTION, + "error": node_run_result.error, + "inputs": node_run_result.inputs, + "metadata": {"error_strategy": node_instance.node_data.error_strategy}, + } + if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + node_run_result = NodeRunResult( + **node_error_args, + outputs={ + **node_instance.node_data.default_value_dict, + "error_message": node_run_result.error, + "error_type": node_run_result.error_type, + }, + ) + else: + node_run_result = NodeRunResult( + **node_error_args, + outputs={ + "error_message": node_run_result.error, + "error_type": node_run_result.error_type, + }, + ) + run_succeeded = node_run_result.status in ( + WorkflowNodeExecutionStatus.SUCCEEDED, + WorkflowNodeExecutionStatus.EXCEPTION, + ) error = node_run_result.error if not run_succeeded else None except WorkflowNodeRunFailedError as e: node_instance = e.node_instance @@ -260,7 +290,6 @@ def run_draft_workflow_node( workflow_node_execution.created_by = account.id workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) - if run_succeeded and node_run_result: # create workflow node execution inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None @@ -277,7 +306,11 @@ def run_draft_workflow_node( workflow_node_execution.execution_metadata = ( json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None ) - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION: + workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value + workflow_node_execution.error = node_run_result.error else: # create workflow node execution workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value @@ -305,7 +338,7 @@ def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> A raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow - new_app = workflow_converter.convert_to_workflow( + new_app: App = workflow_converter.convert_to_workflow( app_model=app_model, account=account, name=args.get("name", "Default Name"), diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 8fcb12b1cb9664..7637b31454e556 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,4 +1,4 @@ -from flask_login import current_user +from flask_login import current_user # type: ignore from configs import dify_config from extensions.ext_database import db @@ -29,6 +29,7 @@ def get_tenant_info(cls, tenant: Tenant): .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) .first() ) + assert tenant_account_join is not None, "TenantAccountJoin not found" tenant_info["role"] = tenant_account_join.role can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo diff --git a/api/tasks/__init__.py b/api/tasks/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 09be6612160471..50bb2b6e634fba 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index 25c55bcfafe11c..aab21a44109975 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index fa7e5ac9190f3c..06162b02d60f8b 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.datasource.vdb.vector_factory import Vector diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index f0f6b32b06c78c..a6a598ce4b6bca 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.datasource.vdb.vector_factory import Vector from models.dataset import Dataset diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index a2f49135139b08..26bf1c7c9fa32e 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.datasource.vdb.vector_factory import Vector diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 0bdcd0eccd7f72..b42af0c7faf67e 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.datasource.vdb.vector_factory import Vector diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index b685d84d07ad28..8c675feaa6e06f 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index dcb7009e44b938..26ae9f8736d79a 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -4,7 +4,7 @@ import uuid import click -from celery import shared_task +from celery import shared_task # type: ignore from sqlalchemy import func from core.indexing_runner import IndexingRunner @@ -58,12 +58,13 @@ def batch_create_segment_to_index_task( model=dataset.embedding_model, ) word_count_change = 0 + segments_to_insert: list[str] = [] # Explicitly type hint the list as List[str] for segment in content: - content = segment["content"] + content_str = segment["content"] doc_id = str(uuid.uuid4()) - segment_hash = helper.generate_text_hash(content) + segment_hash = helper.generate_text_hash(content_str) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) if embedding_model else 0 + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content_str]) if embedding_model else 0 max_position = ( db.session.query(func.max(DocumentSegment.position)) .filter(DocumentSegment.document_id == dataset_document.id) @@ -90,6 +91,7 @@ def batch_create_segment_to_index_task( word_count_change += segment_document.word_count db.session.add(segment_document) document_segments.append(segment_document) + segments_to_insert.append(str(segment)) # Cast to string if needed # update document word count dataset_document.word_count += word_count_change db.session.add(dataset_document) diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index a555fb28746697..d9278c03793877 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids @@ -71,6 +71,8 @@ def clean_dataset_task( image_upload_file_ids = get_image_upload_file_ids(segment.content) for upload_file_id in image_upload_file_ids: image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + if image_file is None: + continue try: storage.delete(image_file.key) except Exception: diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 4d328643bfa165..3e80dd13771802 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -3,7 +3,7 @@ from typing import Optional import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids @@ -44,6 +44,8 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i image_upload_file_ids = get_image_upload_file_ids(segment.content) for upload_file_id in image_upload_file_ids: image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + if image_file is None: + continue try: storage.delete(image_file.key) except Exception: diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index 75d9e031306381..f5d6406d9cc04f 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index 315b01f157bf13..dfa053a43cbc61 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -4,7 +4,7 @@ from typing import Optional import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index cfc54920e23caa..b025509aebe674 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import Document diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index c3e0ea5d9fbb77..45a612c74550cd 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 15e1e50076e8c9..f30a1cc7acfd6c 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 18316913932874..ac4e81f95d127e 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 734dd2478a9847..21b571b6cb5bd4 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from configs import dify_config from core.indexing_runner import DocumentIsPausedError, IndexingRunner diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 1a52a6636b1d17..5f1e9a892f54e3 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index f4c3dbd2e2860c..6db2620eb6eef0 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from configs import dify_config from core.indexing_runner import DocumentIsPausedError, IndexingRunner @@ -26,6 +26,8 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): start_at = time.perf_counter() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if dataset is None: + raise ValueError("Dataset not found") # check document limit features = FeatureService.get_features(dataset.tenant_id) diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 12639db9392677..2f6eb7b82a0633 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/external_document_indexing_task.py b/api/tasks/external_document_indexing_task.py index 6fc719ae8d085a..a45b3030bf253a 100644 --- a/api/tasks/external_document_indexing_task.py +++ b/api/tasks/external_document_indexing_task.py @@ -3,9 +3,9 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore -from core.indexing_runner import DocumentIsPausedException +from core.indexing_runner import DocumentIsPausedError from extensions.ext_database import db from extensions.ext_storage import storage from models.dataset import Dataset, ExternalKnowledgeApis @@ -68,11 +68,9 @@ def external_document_indexing_task( settings = ExternalDatasetService.get_external_knowledge_api_settings( json.loads(external_knowledge_api.settings) ) - # assemble headers - headers = ExternalDatasetService.assembling_headers(settings.authorization, settings.headers) # do http request - response = ExternalDatasetService.process_external_api(settings, headers, process_parameter, files) + response = ExternalDatasetService.process_external_api(settings, files) job_id = response.json().get("job_id") if job_id: # save job_id to dataset @@ -86,7 +84,7 @@ def external_document_indexing_task( fg="green", ) ) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py index d78fc2b8915520..5dc935548f90b8 100644 --- a/api/tasks/mail_email_code_login.py +++ b/api/tasks/mail_email_code_login.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from flask import render_template from extensions.ext_mail import mail diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index c7dfb9bf6063ff..3094527fd40945 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from flask import render_template from configs import dify_config diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index 8596ca07cfcee3..d5be94431b6221 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from flask import render_template from extensions.ext_mail import mail diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 34c62dc9237fc0..bb3b9e17ead6d2 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -1,7 +1,7 @@ import json import logging -from celery import shared_task +from celery import shared_task # type: ignore from flask import current_app from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 934eb7430c90c3..b603d689ba9d8e 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 66f78636ecca60..c3910e2be3a499 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -3,7 +3,7 @@ from collections.abc import Callable import click -from celery import shared_task +from celery import shared_task # type: ignore from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index 1909eaf3418517..4ba6d1a83e32ae 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 73471fd6e77c9b..485caa5152ea78 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -22,10 +22,13 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): Usage: retry_document_indexing_task.delay(dataset_id, document_id) """ - documents = [] + documents: list[Document] = [] start_at = time.perf_counter() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("Dataset not found") + for document_id in document_ids: retry_indexing_cache_key = "document_{}_is_retried".format(document_id) # check document limit @@ -55,29 +58,31 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): document = ( db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() ) + if not document: + logging.info(click.style("Document not found: {}".format(document_id), fg="yellow")) + return try: - if document: - # clean old data - index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() + # clean old data + index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - # delete from vector index - index_processor.clean(dataset, index_node_ids) + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + # delete from vector index + index_processor.clean(dataset, index_node_ids) - for segment in segments: - db.session.delete(segment) - db.session.commit() - - document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() - db.session.add(document) + for segment in segments: + db.session.delete(segment) db.session.commit() - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - redis_client.delete(retry_indexing_cache_key) + document.indexing_status = "parsing" + document.processing_started_at = datetime.datetime.utcnow() + db.session.add(document) + db.session.commit() + + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + redis_client.delete(retry_indexing_cache_key) except Exception as ex: document.indexing_status = "error" document.error = str(ex) diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 1d2a338c831764..5d6b069cf44919 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -25,6 +25,8 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): start_at = time.perf_counter() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if dataset is None: + raise ValueError("Dataset not found") sync_indexing_cache_key = "document_{}_is_sync".format(document_id) # check document limit @@ -52,29 +54,31 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): logging.info(click.style("Start sync website document: {}".format(document_id), fg="green")) document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + if not document: + logging.info(click.style("Document not found: {}".format(document_id), fg="yellow")) + return try: - if document: - # clean old data - index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() + # clean old data + index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - # delete from vector index - index_processor.clean(dataset, index_node_ids) + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + # delete from vector index + index_processor.clean(dataset, index_node_ids) - for segment in segments: - db.session.delete(segment) - db.session.commit() - - document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() - db.session.add(document) + for segment in segments: + db.session.delete(segment) db.session.commit() - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - redis_client.delete(sync_indexing_cache_key) + document.indexing_status = "parsing" + document.processing_started_at = datetime.datetime.utcnow() + db.session.add(document) + db.session.commit() + + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + redis_client.delete(sync_indexing_cache_key) except Exception as ex: document.indexing_status = "error" document.error = str(ex) diff --git a/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py b/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py index 64f2884c4b828c..57fba317638de8 100644 --- a/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py +++ b/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py @@ -1,6 +1,6 @@ from typing import Any -import toml +import toml # type: ignore def load_api_poetry_configs() -> dict[str, Any]: @@ -38,7 +38,7 @@ def test_group_dependencies_version_operator(): ) -def test_duplicated_dependency_crossing_groups(): +def test_duplicated_dependency_crossing_groups() -> None: all_dependency_names: list[str] = [] for dependencies in load_all_dependency_groups().values(): dependency_names = list(dependencies.keys()) diff --git a/api/tests/integration_tests/controllers/test_controllers.py b/api/tests/integration_tests/controllers/test_controllers.py index 6371694694653e..5e3ee6bedc7ebb 100644 --- a/api/tests/integration_tests/controllers/test_controllers.py +++ b/api/tests/integration_tests/controllers/test_controllers.py @@ -1,6 +1,6 @@ from unittest.mock import patch -from app_fixture import app, mock_user +from app_fixture import mock_user # type: ignore def test_post_requires_login(app): diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py index 402bd9c2c21f69..b90f8b444477d5 100644 --- a/api/tests/integration_tests/model_runtime/__mock/google.py +++ b/api/tests/integration_tests/model_runtime/__mock/google.py @@ -1,16 +1,16 @@ from collections.abc import Generator +from unittest.mock import MagicMock -import google.generativeai.types.generation_types as generation_config_types +import google.generativeai.types.generation_types as generation_config_types # type: ignore import pytest from _pytest.monkeypatch import MonkeyPatch from google.ai import generativelanguage as glm from google.ai.generativelanguage_v1beta.types import content as gag_content from google.generativeai import GenerativeModel -from google.generativeai.client import _ClientManager, configure from google.generativeai.types import GenerateContentResponse, content_types, safety_types from google.generativeai.types.generation_types import BaseGenerateContentResponse -current_api_key = "" +from extensions import ext_redis class MockGoogleResponseClass: @@ -45,7 +45,7 @@ def generate_content_sync() -> GenerateContentResponse: return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]) @staticmethod - def generate_content_stream() -> Generator[GenerateContentResponse, None, None]: + def generate_content_stream() -> MockGoogleResponseClass: return MockGoogleResponseClass() def generate_content( @@ -57,11 +57,6 @@ def generate_content( stream: bool = False, **kwargs, ) -> GenerateContentResponse: - global current_api_key - - if len(current_api_key) < 16: - raise Exception("Invalid API key") - if stream: return MockGoogleClass.generate_content_stream() @@ -75,33 +70,29 @@ def generative_response_text(self) -> str: def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]: return [MockGoogleResponseCandidateClass()] - def make_client(self: _ClientManager, name: str): - global current_api_key - if name.endswith("_async"): - name = name.split("_")[0] - cls = getattr(glm, name.title() + "ServiceAsyncClient") - else: - cls = getattr(glm, name.title() + "ServiceClient") +def mock_configure(api_key: str): + if len(api_key) < 16: + raise Exception("Invalid API key") + + +class MockFileState: + def __init__(self): + self.name = "FINISHED" - # Attempt to configure using defaults. - if not self.client_config: - configure() - client_options = self.client_config.get("client_options", None) - if client_options: - current_api_key = client_options.api_key +class MockGoogleFile: + def __init__(self, name: str = "mock_file_name"): + self.name = name + self.state = MockFileState() - def nop(self, *args, **kwargs): - pass - original_init = cls.__init__ - cls.__init__ = nop - client: glm.GenerativeServiceClient = cls(**self.client_config) - cls.__init__ = original_init +def mock_get_file(name: str) -> MockGoogleFile: + return MockGoogleFile(name) - if not self.default_metadata: - return client + +def mock_upload_file(path: str, mime_type: str) -> MockGoogleFile: + return MockGoogleFile() @pytest.fixture @@ -109,8 +100,17 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch): monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text) monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates) monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content) - monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client) + monkeypatch.setattr("google.generativeai.configure", mock_configure) + monkeypatch.setattr("google.generativeai.get_file", mock_get_file) + monkeypatch.setattr("google.generativeai.upload_file", mock_upload_file) yield monkeypatch.undo() + + +@pytest.fixture +def setup_mock_redis() -> None: + ext_redis.redis_client.get = MagicMock(return_value=None) + ext_redis.redis_client.setex = MagicMock(return_value=None) + ext_redis.redis_client.exists = MagicMock(return_value=True) diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface.py b/api/tests/integration_tests/model_runtime/__mock/huggingface.py index 97038ef5963e87..4de52514408a06 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface.py @@ -2,7 +2,7 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from huggingface_hub import InferenceClient +from huggingface_hub import InferenceClient # type: ignore from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py index 9ee76c935c9873..77c7e7f5e4089c 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py @@ -3,15 +3,15 @@ from typing import Any, Literal, Optional, Union from _pytest.monkeypatch import MonkeyPatch -from huggingface_hub import InferenceClient -from huggingface_hub.inference._text_generation import ( +from huggingface_hub import InferenceClient # type: ignore +from huggingface_hub.inference._text_generation import ( # type: ignore Details, StreamDetails, TextGenerationResponse, TextGenerationStreamResponse, Token, ) -from huggingface_hub.utils import BadRequestError +from huggingface_hub.utils import BadRequestError # type: ignore class MockHuggingfaceChatClass: diff --git a/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py index 6a25398cbf069a..4e00660a29162f 100644 --- a/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py +++ b/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py @@ -6,7 +6,7 @@ # import monkeypatch from _pytest.monkeypatch import MonkeyPatch -from nomic import embed +from nomic import embed # type: ignore def create_embedding(texts: list[str], model: str, **kwargs: Any) -> dict: diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py index 5f7dad50c10f11..e2abaa52b939a6 100644 --- a/api/tests/integration_tests/model_runtime/__mock/xinference.py +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -6,14 +6,14 @@ from _pytest.monkeypatch import MonkeyPatch from requests import Response from requests.sessions import Session -from xinference_client.client.restful.restful_client import ( +from xinference_client.client.restful.restful_client import ( # type: ignore Client, RESTfulChatModelHandle, RESTfulEmbeddingModelHandle, RESTfulGenerateModelHandle, RESTfulRerankModelHandle, ) -from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage +from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage # type: ignore class MockXinferenceClass: @@ -21,13 +21,13 @@ def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulGenerateModelHa if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url): raise RuntimeError("404 Not Found") - if "generate" == model_uid: + if model_uid == "generate": return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if "chat" == model_uid: + if model_uid == "chat": return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if "embedding" == model_uid: + if model_uid == "embedding": return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if "rerank" == model_uid: + if model_uid == "rerank": return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={}) raise RuntimeError("404 Not Found") diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py index 8f50ebf7a6d03f..216c50a1823c8d 100644 --- a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py @@ -199,7 +199,9 @@ def test_invoke_chat_model_with_vision(setup_openai_mock): data="Hello World!", ), ImagePromptMessageContent( - data="" + mime_type="image/png", + format="png", + base64_data="iVBORw0KGgoAAAANSUhEUgAAAE4AAABMCAYAAADDYoEWAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBoAQSkhN4EkRpASggt9I4gKiEJEEqMgaBiRxcVXLuIgA1dFVGwAmJBETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMmbllAFA7zhGJclF1APKEBeLYYH/6uOQUOukpIAEdoAy0gA2Hmy9iRkeHA1iG2r+Xd9cBIm2v2Eu1/tn/X4sGj5/PBQCJhjidl8/Ng/gAAHg1VyQuAIAo5c2mFoikGFagJYYBQrxIijPluFqK0+V4j8wmPpYFcTsASiocjjgTANVLkKcXcjOhhmo/xI5CnkAIgBodYp+8vMk8iNMgtoY2Ioil+oz0H3Qy/6aZPqzJ4WQOY/lcZEUpQJAvyuVM/z/T8b9LXq5kyIclrCpZ4pBY6Zxh3m7mTA6TYhWI+4TpkVEQa0L8QcCT2UOMUrIkIQlye9SAm8+COYMrDVBHHicgDGIDiIOEuZHhCj49QxDEhhjuEHSaoIAdD7EuxIv4+YFxCptN4smxCl9oY4aYxVTwZzlimV+pr/uSnASmQv91Fp+t0MdUi7LikyCmQGxeKEiMhFgVYof8nLgwhc3YoixW5JCNWBIrjd8c4li+MNhfro8VZoiDYhX2pXn5Q/PFNmUJ2JEKvK8gKz5Enh+sncuRxQ/ngl3iC5kJQzr8/HHhQ3Ph8QMC5XPHnvGFCXEKnQ+iAv9Y+VicIsqNVtjjpvzcYClvCrFLfmGcYiyeWAA3pFwfzxAVRMfL48SLsjmh0fJ48OUgHLBAAKADCazpYDLIBoLOvqY+eCfvCQIcIAaZgA/sFczQiCRZjxBe40AR+BMiPsgfHucv6+WDQsh/HWblV3uQIestlI3IAU8gzgNhIBfeS2SjhMPeEsFjyAj+4Z0DKxfGmwurtP/f80Psd4YJmXAFIxnySFcbsiQGEgOIIcQgog2uj/vgXng4vPrB6oQzcI+heXy3JzwhdBEeEq4Rugm3JgmKxT9FGQG6oX6QIhfpP+YCt4Sarrg/7g3VoTKug+sDe9wF+mHivtCzK2RZirilWaH/pP23GfywGgo7siMZJY8g+5Gtfx6paqvqOqwizfWP+ZHHmj6cb9Zwz8/+WT9knwfbsJ8tsUXYfuwMdgI7hx3BmgAda8WasQ7sqBQP767Hst015C1WFk8O1BH8w9/Qykozme9Y59jr+EXeV8CfJn1HA9Zk0XSxIDOrgM6EXwQ+nS3kOoyiOzk6OQMg/b7IX19vYmTfDUSn4zs3/w8AvFsHBwcPf+dCWwHY6w4f/0PfOWsG/HQoA3D2EFciLpRzuPRCgG8JNfik6QEjYAas4XycgBvwAn4gEISCKBAPksFEGH0W3OdiMBXMBPNACSgDy8EaUAk2gi1gB9gN9oEmcAScAKfBBXAJXAN34O7pAS9AP3gHPiMIQkKoCA3RQ4wRC8QOcUIYiA8SiIQjsUgykoZkIkJEgsxE5iNlyEqkEtmM1CJ7kUPICeQc0oXcQh4gvchr5BOKoSqoFmqIWqKjUQbKRMPQeHQCmolOQYvQBehStAKtQXehjegJ9AJ6De1GX6ADGMCUMR3MBLPHGBgLi8JSsAxMjM3GSrFyrAarx1rgOl/BurE+7CNOxGk4HbeHOzgET8C5+BR8Nr4Er8R34I14O34Ff4D3498IVIIBwY7gSWATxhEyCVMJJYRywjbCQcIp+Cz1EN4RiUQdohXRHT6LycRs4gziEuJ6YgPxOLGL+Ig4QCKR9Eh2JG9SFIlDKiCVkNaRdpFaSZdJPaQPSspKxkpOSkFKKUpCpWKlcqWdSseULis9VfpMVidbkD3JUWQeeTp5GXkruYV8kdxD/kzRoFhRvCnxlGzKPEoFpZ5yinKX8kZZWdlU2UM5RlmgPFe5QnmP8lnlB8ofVTRVbFVYKqkqEpWlKttVjqvcUnlDpVItqX7UFGoBdSm1lnqSep/6QZWm6qDKVuWpzlGtUm1Uvaz6Uo2sZqHGVJuoVqRWrrZf7aJanzpZ3VKdpc5Rn61epX5I/Yb6gAZNY4xGlEaexhKNnRrnNJ5pkjQtNQM1eZoLNLdontR8RMNoZjQWjUubT9tKO0Xr0SJqWWmxtbK1yrR2a3Vq9WtrartoJ2pP067SPqrdrYPpWOqwdXJ1luns07mu82mE4QjmCP6IxSPqR1we8V53pK6fLl+3VLdB95ruJz26XqBejt4KvSa9e/q4vq1+jP5U/Q36p/T7RmqN9BrJHVk6ct/I2waoga1BrMEMgy0GHQYDhkaGwYYiw3WGJw37jHSM/IyyjVYbHTPqNaYZ+xgLjFcbtxo/p2vTmfRcegW9nd5vYmASYiIx2WzSafLZ1Mo0wbTYtMH0nhnFjGGWYbbarM2s39zYPMJ8pnmd+W0LsgXDIstircUZi/eWVpZJlgstmyyfWelasa2KrOqs7lpTrX2tp1jXWF+1IdowbHJs1ttcskVtXW2zbKtsL9qhdm52Arv1dl2jCKM8RglH1Yy6Ya9iz7QvtK+zf+Cg4xDuUOzQ5PBytPnolNErRp8Z/c3R1THXcavjnTGaY0LHFI9pGfPaydaJ61TldNWZ6hzkPMe52fmVi50L32WDy01XmmuE60LXNtevbu5uYrd6t153c/c092r3GwwtRjRjCeOsB8HD32OOxxGPj55ungWe+zz/8rL3yvHa6fVsrNVY/titYx95m3pzvDd7d/vQfdJ8Nvl0+5r4cnxrfB/6mfnx/Lb5PWXaMLOZu5gv/R39xf4H/d+zPFmzWMcDsIDggNKAzkDNwITAysD7QaZBmUF1Qf3BrsEzgo+HEELCQlaE3GAbsrnsWnZ/qHvorND2MJWwuLDKsIfhtuHi8JYINCI0YlXE3UiLSGFkUxSIYketiroXbRU9JfpwDDEmOqYq5knsmNiZsWfiaHGT4nbGvYv3j18WfyfBOkGS0JaolpiaWJv4PikgaWVS97jR42aNu5CsnyxIbk4hpSSmbEsZGB84fs34nlTX1JLU6xOsJkybcG6i/sTciUcnqU3iTNqfRkhLStuZ9oUTxanhDKSz06vT+7ks7lruC54fbzWvl+/NX8l/muGdsTLjWaZ35qrM3izfrPKsPgFLUCl4lR2SvTH7fU5Uzvacwdyk3IY8pby0vENCTWGOsH2y0eRpk7tEdqISUfcUzylrpvSLw8Tb8pH8CfnNBVrwR75DYi35RfKg0KewqvDD1MSp+6dpTBNO65huO33x9KdFQUW/zcBncGe0zTSZOW/mg1nMWZtnI7PTZ7fNMZuzYE7P3OC5O+ZR5uXM+73YsXhl8dv5SfNbFhgumLvg0S/Bv9SVqJaIS24s9Fq4cRG+SLCoc7Hz4nWLv5XySs+XOZaVl31Zwl1y/tcxv1b8Org0Y2nnMrdlG5YTlwuXX1/hu2LHSo2VRSsfrYpY1biavrp09ds1k9acK3cp37iWslaytrsivKJ5nfm65eu+VGZVXqvyr2qoNqheXP1+PW/95Q1+G+o3Gm4s2/hpk2DTzc3BmxtrLGvKtxC3FG55sjVx65nfGL/VbtPfVrbt63bh9u4dsTvaa91ra3ca7FxWh9ZJ6np3pe66tDtgd3O9ff3mBp2Gsj1gj2TP871pe6/vC9vXtp+xv/6AxYHqg7SDpY1I4/TG/qaspu7m5OauQ6GH2lq8Wg4edji8/YjJkaqj2keXHaMcW3BssLWodeC46HjficwTj9omtd05Oe7k1faY9s5TYafOng46ffIM80zrWe+zR855njt0nnG+6YLbhcYO146Dv7v+frDTrbPxovvF5ksel1q6xnYdu+x7+cSVgCunr7KvXrgWea3resL1mzdSb3Tf5N18div31qvbhbc/35l7l3C39J76vfL7Bvdr/rD5o6Hbrfvog4AHHQ/jHt55xH304nH+4y89C55Qn5Q/NX5a+8zp2ZHeoN5Lz8c/73khevG5r+RPjT+rX1q/PPCX318d/eP6e16JXw2+XvJG7832ty5v2waiB+6/y3v3+X3pB70POz4yPp75lPTp6eepX0hfKr7afG35Fvbt7mDe4KCII+bIfgUwWNGMDABebweAmgwADZ7PKOPl5z9ZQeRnVhkC/wnLz4iy4gZAPfx/j+mDfzc3ANizFR6/oL5aKgDRVADiPQDq7Dxch85qsnOltBDhOWBT5Nf0vHTwb4r8zPlD3D+3QKrqAn5u/wWdZ3xtG7qP3QAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAATqADAAQAAAABAAAATAAAAADhTXUdAAARnUlEQVR4Ae2c245bR3aGi4fulizFHgUzQAYIggBB5klymfeaZ8hDBYjvAiRxkMAGkowRWx7JktjcZL7vX1Uku62Burkl5YbV5q7Tqqq1/v3XqgMpL95tbvftEh6NwPLRLS4NgsAFuDOJcAHuAtyZCJzZ7MK4C3BnInBmswvjLsCdicCZzS6MOxO49Znt0uz3//CPbbv6srXFrq0W9Q6Wi0VbLPn4R8x/jSLiu3nrl8s9dcartlwtKdmTbm21XranN6v27Mm6XV8t25fP1+3Pn1+1r4if3Czbk+t9u1rR6f9jmAXc1P6sbaevQGbfdgGJeA8ke0AQsCYYgiYgPR1QyVO+3wvcMm2WO0G2PeWkX79btp839AG4//UjYC62gDsB2rI9f7pov3q2bX/9F1ftBWAufTufOcwCrnTtR90dOdHoNgCJeAbUkuM5TsWAW5W9gfkE83ZkUHg0oAyAwbm927a2ebVoP/xx2f7jD1uYuG9/89tF+/VXK1hq+88TZgG32O1g2r7tpRdBM8fUTM7pyR8SYddgxkJErUszHti7U44CpzyEo16syNtx+qgy+1og7RMetpev9+3rb3bt+c2u/ebFsv3uL1ftiqn+qcMs4HY7jNQpEfadNU5VqeHUTJkgUbaPDxRADdZ8jU9LHoJYnwLUtgWN4ObDC7Kdr8Hp7d9qMTW8gt23V1zyvPrD1H56e9t+99vr9uJLprBDfaIw69U4dQRCIw2JdVIjbUzecj+7qYyPpZHiAbDaJwsXyMhQEQ0pq6sAp7hMS2XGqykdA2iy4EUtF6v206ur9k/fbNo//+frtt2OaW/rjxtmAaeNGqihBY5xfVQzQEZfoSH0KHgkrbD/CX6vPIqlSTU61vVCovRSbEwbIS851vj23Q+tff3vu/bzu5I7tvs4qVnADTa5FCbNC86qCLN2E1MxKKroYB2pgSz2RLbbVcVkSJhOKxIDjGxn+nSuqes2JlKuG8fA/IzPXazbj68X7et/27UfX7GifORwOuSju47h/c3beKfRFO74CNA04YP0ZT2/YzERFGojc9pmDG47/wyDZwJjiX4wwJNer1dZPJbs5/xzK5Ppzp7SQZBszNy22U7tX7/dtFdvJrv8aGE2cDJLoPycBgHSgICJUQLo8nmUo6y7oH0S5Lu/FGhDQULCfIooATw3yyOQQ46eYVpYiaBMTFtAFPR307r9y3fbdvsRfd5Rg6HJI2Lt1qaAF6TEqoxWdVdYSHawezCvAHLjW7Jh2QGcUkDDT4Og2OfSFRVkxipcAJUZARC5FVRbeRpB1hVY6r25XQHexIZ96Hfa++PTs4Dbi8rQg7imWQG27/uEgCTCssk/WWg7GwJWwDQ36PceGzQ+x7jOtgNogkIIpsZiFMdXoEfOPUlh3l5ulu2/X6bJ7Mc84Bw+xgOKzJqM0VKm8WYlVMqt61gFKNtQKeZ6o7Ls/aqEeYooJXDIZ9uiT0uZ5UxPUJNlYdoAK62qHfM7unz3/bb9/Ha+v3u/tn3AD0XOrnxAZdpNYZILgoxyGk4BqMCbssq66dXv6RdFkiB6Rj2u3N1npiMw1dQjF4oJW/kzy6VdMRFA9Xd8VvhCLxCyYUYkvhHZb7+fotvdUR6XmwXcYI1DangAA6yspgBj/dRjp6L+RbmSPaaxuuMnGEeVAhBF4pSapAFG5gUo60rAHmpVtcz0sR2aBZW8NAB9+W7dXr9N0dmPmUcu10pWrq7kQQvBQXn1dUsgoM4ej12TtyBknG51PEMGOV2TLLVZ/GLvLMBYHsYJhg7fuMBx6tq3LFu7aBxxD9jKFiO7Thbwcv7n5dS+/ML0eWEWcBqoptk+mEQp2aTG+rbmBYA+D6MyMwMAdepKsX5QpnglFZyZ5k4tDYsI/Y1pF7CRq22HoHXgGEOwgodvgH79INnW3tlFIVVQvkBXg1dvF3z27fkTGzw+zALOPZluVoVkV4yLHoBB3VBJUNyo6uEWXAyIkruC2OQjbVeppxkm8+iti2mySsM1EPYGKBcEyul3LKTW1+pr+wLRstwP0J8a2K95Txf/+6q1ZzeUDEXt/oFhHnA4fJYCBtawYlWmlsrJBEHhP43bi9Rq1Z0ymlK3Z/QCRqA5YfaNLZJWEACn929eluXlUGO8CgMrHWYi441S2tsFebLRL5RWL0e0nL64SEEf2sjMR4ZZwA0Ddfziclz1eN8yDn1qAaHSq3G0FEQXjABDo51sJVNyGnA0QlAPL4LOApzMo0mY1sUFbQBj8xTzYhKrROYF5VGIftR1uW3+3uiWU8XnBw7l3HIYVG/P/djYgMZoyrTJrci0n2qPZVnNFV913viW6btGzsXBT6aW3VKmsauVTFOc2DxpP5YJYLBBeCUixE71IlGBR2EF+6OugHbP12Ddoj29HgIPj+cxDiPDFGINzB8sKhLh0Ui4gOgDI8deb8FiwYxlteWhLHWTlmOzhkxLAObPIkFqS8+bbG5BdgWiAmJTwXdqZ7oysktzdKC/BWMWiAJNpyP0ZPTMItRy7fTi2RB4eDwLuIkpCma1gob/Dsw7zcKAMf3txiCot8c42ZCDPu3WAqRMJAGEk4cACaLzSZsFRhAE9QoAtXcwTX92XDT0sxTQXJYHdDJin0KfVN8PmzNvnOYBx5XNlik4giumihb7tJ60ezgNhgXuXgRNttxunZYAj7uzbL3nUA67rm5KJWrJCyTfIVwBMh3bTkD8TqFYp6uv8RwrgJpAZmHHScqv0qWeKT48NujhAuELekyYBdz9gXJQ53DvDh3tU62xTtN8bQhzzE9OccAK8wA2ez2k3cNtN7wM/RZs9M5NkNZoee0H2rmhLr8miPV9roAZtN1RHV/gDb7EoUtXKeXjYXUBN0oeFs8CbrtlhZRGPZSSZNyI9gA+TBFkelFNWxgEgCtG3wDiFqEr5Jz6y/U1DAM4QLxi2l7DNhl3w/epNTUFWGbXC7HrMQMz7WUbf8AaDQ46DYXuxLoJX6CFRzvuiPyJzCzgZIoKyqgKAx1yAGPQUWfa+GoDsqwDJNnHLF9juSz0i5VrpvqSwmsQul5dtyfrfX1zL3i0WdHHSjaKVjf0T5k7ABtxlEHbwxusgjydAY8N84BjvAx5GLfMqBW0VJEZ+pwKskQnbpnFHPzpwWo/bzkGvX51296+bu1v/+qL9usXT9rTJ07Bzh9k9HEPsxNhwhh6xLXKo3fXWf3iMkrBBz9nAbflbHm6ONxhXp8/NW26lkSleIEV9FBVI+o6ihjmffPDt+3v/+5Z+82vnsZw/fyercweB2d7wzA8mfuPEknpXTnHvQsoPd1v/aD8LODw+AxbAw/QjnEfv69u5kz6dtOiW2R6YmW7vd0C3qK94wcjf/zxZ1bRXfvqGT6U3f2G/Z6AesqotgJX477PNVmTmxfiwTSS5irqz2ybEHD6PzbMAk7lS/0BxgkTqPAUYBiAkQpTLLdKxe1D4Lbsp968uW1vXk+ZrnpsN7yL1TbmbvCl4GcPPPStZWyNcM9s++9y92ruZu2CT21q7lZ9KDcLuC3WbmGG42uA30EISOVkFynt1BBialOliF/wZHqGTa1tOfq8fbMHPL6N2iBPW2d7HfxZdWnreiN49UL0dfhLR6tBSVVwNo+TQ1U5IsHvQU4Dcry7bGNOix+SngVcwAhYpZjTQxaNMABLLLtUFEAMEwi4kk63fGDbLTcVm82ubd7hNylzEXCa6SPdz2Vf5iUobe0jAFIq8+JHT8CjGeUjHFOj5E7MIO4THxvOaHIcwu2IOKiznyg89BTEXi6WssO8B36vkLa33Pv7/QRbEtm21c/BtIm9Yb4ho19PDg4g09aeucySdpzq3BfVx6WQqh7MkLOSkHLf2olEKni4n7xznh0VH4jnAYdy6hfVSZTvUmF54f2cU9d9XmlhvUyTlbkxIT0BWtgH4wRRgPMy7EFbAwi8ojzbNyqtH/7coWxnUHyE+rmYjbs3NCnqdwIbbM/GZ4RZwDleVskO3viSBhWjSu2Pxj7JU4bsqrzTU5YZQ7xKu73Bb8bAbo+s28NStxEyb8e+K1UAKXhOVivK7x0RUANf3zEw/smJpsr37cad9RlhFnCbzQYwfN36I+5qwxgVwRA/vOHxlneeMiaux9lymN5tTTttkZN5mbZwCYsLM550taA+zJM5gsdHsGSdQTbngN7ZlC/JrRhXIcorRJvVcp2pnjzdy+0nnErOCbOAE5x8d4oVCy4xMSFGetjfgWJ3MQFHdomxZbUwwC4B84YlzBNojUEmxmqO1tVC4VcVopUzKuXK+XArUeDVTyq85wv7xKqHsel1dfIUkl8zUXcFm8eUH7IPjWcBp8J5mYxWcWmbclhlyEIAMJm2HbSwDCHZGD9IuR1UH4MhaZ4HOAIQIJOrIxfjxOFRUMNQq8wI9EH5WNVJdcEje22ofxs3K6PlQ+OZwA2ghrFSKhiEVSqh/5JJcfodKBnntLac7wb5CKLpAs+0RguYuAhoNh2CRV1dTVFhqWhRn/u+tOsMtTph6JhOkAWsQDz1K3NHeHyYBZyK70BG5oy3SyqGumoaAhr1Aiggnm8FzXr3cQWSq++p8seM10v6LW9Elgh5kyGINXMdi1xspw2LRHwqMjJTV2KdU9c2eQ1SkXDDHL2aYf2MprVp1dFrtcBlAWB/sNuxMoJIzEfRqhMk04qXfM0n8yVDaa/DRLp1GuGSKhNz65ZEOQUSdyD0Y/adRSojsxjoz2jnNFdN3l/S+sUvnqbDsx+zgCvQMJzhPaCrlouCLBvbA43x68DhsAc7DxpTr0y39VAMBCfpSlpSUMggzRe8X4bIAWRYJqVJj6t7feMV/9Bkfeb+bYw2Czg78S3GwWtEQEPRWFMMEDAZhVTiMaWLnZZRxSexfaStPR9DAXbMj5Qs479Dm8PqqYCNEpUTVAe/GpLC3vH16hI64zkLuB1XQVsdFkED8ps40oLjj2sMAdbFwGlKRjbW6UHAFZaRJVegIpeWVafZhQ4yHahUm+5VyfOwXYFHTX8DKUNSn+fCcsN3qOd8AT3GGPEs4EYnxho9YlOnU1WTUj98GbLKWCawI5wk71DiBMoh+qjYfgXUc+nNlW+rXuqjOrknPAs4sRoHcvvNguDZNEChYOoBUUZ175z9nMBZnQ6cnncgS7uDnt3BJ49Y8axqPYLZ0gVEb2DaICyHtOUM5t2eP7AJexWaGWYBVzcdsqneoAAViyzzo3ZsC1Jeq2qBKVhlkIxDsuSRrSY6/6S6eaaFjD+B4BGmMo9X9M06kcAdMq0qU5eT+lBBc8+GqaVmCc989iHP6yVvOcr4qE8ZLijVZ8VleC/5xWDWFmN6ow6aIKX75EfdL5rfKxBJgAcwwV/zeXrFjyqqo3uy52dnMa5oU4O7svo7YMNgWrFKdsk6WBXmmS82HuKsuADjHZFGi5iBIv+9qnn/qt+qSh3JTFNjPvWDiqpnA0SexYB/ijm6q5qP85wFnIZrXQHgillpVesHh9QVaAWWAJccfo/VNrOcbmrbYn/vCR9gy2m1aUH2WOa/rv4UoKnhPODowC2Gx6jQo4Nox4ZinDL392ssIHFSZWa1rTZJD/wSy0Kn34eDpwZvP1w96+dmH25zrsQs4KSLP4GAawWSjhnFZZQFmUZxOZSTj/ne2yUhIHCjRIlFKcIU0x852RjZTGGlDdaQrkxk7MPrJr/gzg17r4vgJ3rMAk4/wmQDE7wJhg+fFV1xaMGiMqnXaFc5jd4FjCCIRAEmAO5aPE7lzsw0ZelHYJB0PCWscErqOJcsrbllGmhmzE/7mAXcPof544Wlqg6wTuORtvKQzjV2gVC+shaNMhc24v8iIloGmS3ogc7bD9sS884Oi0kEP89jFnDX++/hCtPVtT7kwaxOkZpmxQ/L9vgdj1r+NCtAwQ6/A9DXMXnBqZgoHDdXP7Wna/Id6PRCum7DiREqcg1UPw9Yp6MsLv/HwlM4Hp7WQ1/CGQhcgDsDNJtcgLsAdyYCZza7MO4C3JkInNnswrgLcGcicGazC+POBO7/AH5zPa/ivytzAAAAAElFTkSuQmCC", ), ] ), diff --git a/api/tests/integration_tests/model_runtime/google/test_llm.py b/api/tests/integration_tests/model_runtime/google/test_llm.py index 2877fa150764eb..65357be6586143 100644 --- a/api/tests/integration_tests/model_runtime/google/test_llm.py +++ b/api/tests/integration_tests/model_runtime/google/test_llm.py @@ -13,7 +13,7 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel -from tests.integration_tests.model_runtime.__mock.google import setup_google_mock +from tests.integration_tests.model_runtime.__mock.google import setup_google_mock, setup_mock_redis @pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) @@ -95,7 +95,7 @@ def test_invoke_stream_model(setup_google_mock): @pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) -def test_invoke_chat_model_with_vision(setup_google_mock): +def test_invoke_chat_model_with_vision(setup_google_mock, setup_mock_redis): model = GoogleLargeLanguageModel() result = model.invoke( @@ -109,7 +109,9 @@ def test_invoke_chat_model_with_vision(setup_google_mock): content=[ TextPromptMessageContent(data="what do you see?"), ImagePromptMessageContent( - data="" + mime_type="image/png", + format="png", + base64_data="iVBORw0KGgoAAAANSUhEUgAAAE4AAABMCAYAAADDYoEWAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBoAQSkhN4EkRpASggt9I4gKiEJEEqMgaBiRxcVXLuIgA1dFVGwAmJBETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMmbllAFA7zhGJclF1APKEBeLYYH/6uOQUOukpIAEdoAy0gA2Hmy9iRkeHA1iG2r+Xd9cBIm2v2Eu1/tn/X4sGj5/PBQCJhjidl8/Ng/gAAHg1VyQuAIAo5c2mFoikGFagJYYBQrxIijPluFqK0+V4j8wmPpYFcTsASiocjjgTANVLkKcXcjOhhmo/xI5CnkAIgBodYp+8vMk8iNMgtoY2Ioil+oz0H3Qy/6aZPqzJ4WQOY/lcZEUpQJAvyuVM/z/T8b9LXq5kyIclrCpZ4pBY6Zxh3m7mTA6TYhWI+4TpkVEQa0L8QcCT2UOMUrIkIQlye9SAm8+COYMrDVBHHicgDGIDiIOEuZHhCj49QxDEhhjuEHSaoIAdD7EuxIv4+YFxCptN4smxCl9oY4aYxVTwZzlimV+pr/uSnASmQv91Fp+t0MdUi7LikyCmQGxeKEiMhFgVYof8nLgwhc3YoixW5JCNWBIrjd8c4li+MNhfro8VZoiDYhX2pXn5Q/PFNmUJ2JEKvK8gKz5Enh+sncuRxQ/ngl3iC5kJQzr8/HHhQ3Ph8QMC5XPHnvGFCXEKnQ+iAv9Y+VicIsqNVtjjpvzcYClvCrFLfmGcYiyeWAA3pFwfzxAVRMfL48SLsjmh0fJ48OUgHLBAAKADCazpYDLIBoLOvqY+eCfvCQIcIAaZgA/sFczQiCRZjxBe40AR+BMiPsgfHucv6+WDQsh/HWblV3uQIestlI3IAU8gzgNhIBfeS2SjhMPeEsFjyAj+4Z0DKxfGmwurtP/f80Psd4YJmXAFIxnySFcbsiQGEgOIIcQgog2uj/vgXng4vPrB6oQzcI+heXy3JzwhdBEeEq4Rugm3JgmKxT9FGQG6oX6QIhfpP+YCt4Sarrg/7g3VoTKug+sDe9wF+mHivtCzK2RZirilWaH/pP23GfywGgo7siMZJY8g+5Gtfx6paqvqOqwizfWP+ZHHmj6cb9Zwz8/+WT9knwfbsJ8tsUXYfuwMdgI7hx3BmgAda8WasQ7sqBQP767Hst015C1WFk8O1BH8w9/Qykozme9Y59jr+EXeV8CfJn1HA9Zk0XSxIDOrgM6EXwQ+nS3kOoyiOzk6OQMg/b7IX19vYmTfDUSn4zs3/w8AvFsHBwcPf+dCWwHY6w4f/0PfOWsG/HQoA3D2EFciLpRzuPRCgG8JNfik6QEjYAas4XycgBvwAn4gEISCKBAPksFEGH0W3OdiMBXMBPNACSgDy8EaUAk2gi1gB9gN9oEmcAScAKfBBXAJXAN34O7pAS9AP3gHPiMIQkKoCA3RQ4wRC8QOcUIYiA8SiIQjsUgykoZkIkJEgsxE5iNlyEqkEtmM1CJ7kUPICeQc0oXcQh4gvchr5BOKoSqoFmqIWqKjUQbKRMPQeHQCmolOQYvQBehStAKtQXehjegJ9AJ6De1GX6ADGMCUMR3MBLPHGBgLi8JSsAxMjM3GSrFyrAarx1rgOl/BurE+7CNOxGk4HbeHOzgET8C5+BR8Nr4Er8R34I14O34Ff4D3498IVIIBwY7gSWATxhEyCVMJJYRywjbCQcIp+Cz1EN4RiUQdohXRHT6LycRs4gziEuJ6YgPxOLGL+Ig4QCKR9Eh2JG9SFIlDKiCVkNaRdpFaSZdJPaQPSspKxkpOSkFKKUpCpWKlcqWdSseULis9VfpMVidbkD3JUWQeeTp5GXkruYV8kdxD/kzRoFhRvCnxlGzKPEoFpZ5yinKX8kZZWdlU2UM5RlmgPFe5QnmP8lnlB8ofVTRVbFVYKqkqEpWlKttVjqvcUnlDpVItqX7UFGoBdSm1lnqSep/6QZWm6qDKVuWpzlGtUm1Uvaz6Uo2sZqHGVJuoVqRWrrZf7aJanzpZ3VKdpc5Rn61epX5I/Yb6gAZNY4xGlEaexhKNnRrnNJ5pkjQtNQM1eZoLNLdontR8RMNoZjQWjUubT9tKO0Xr0SJqWWmxtbK1yrR2a3Vq9WtrartoJ2pP067SPqrdrYPpWOqwdXJ1luns07mu82mE4QjmCP6IxSPqR1we8V53pK6fLl+3VLdB95ruJz26XqBejt4KvSa9e/q4vq1+jP5U/Q36p/T7RmqN9BrJHVk6ct/I2waoga1BrMEMgy0GHQYDhkaGwYYiw3WGJw37jHSM/IyyjVYbHTPqNaYZ+xgLjFcbtxo/p2vTmfRcegW9nd5vYmASYiIx2WzSafLZ1Mo0wbTYtMH0nhnFjGGWYbbarM2s39zYPMJ8pnmd+W0LsgXDIstircUZi/eWVpZJlgstmyyfWelasa2KrOqs7lpTrX2tp1jXWF+1IdowbHJs1ttcskVtXW2zbKtsL9qhdm52Arv1dl2jCKM8RglH1Yy6Ya9iz7QvtK+zf+Cg4xDuUOzQ5PBytPnolNErRp8Z/c3R1THXcavjnTGaY0LHFI9pGfPaydaJ61TldNWZ6hzkPMe52fmVi50L32WDy01XmmuE60LXNtevbu5uYrd6t153c/c092r3GwwtRjRjCeOsB8HD32OOxxGPj55ungWe+zz/8rL3yvHa6fVsrNVY/titYx95m3pzvDd7d/vQfdJ8Nvl0+5r4cnxrfB/6mfnx/Lb5PWXaMLOZu5gv/R39xf4H/d+zPFmzWMcDsIDggNKAzkDNwITAysD7QaZBmUF1Qf3BrsEzgo+HEELCQlaE3GAbsrnsWnZ/qHvorND2MJWwuLDKsIfhtuHi8JYINCI0YlXE3UiLSGFkUxSIYketiroXbRU9JfpwDDEmOqYq5knsmNiZsWfiaHGT4nbGvYv3j18WfyfBOkGS0JaolpiaWJv4PikgaWVS97jR42aNu5CsnyxIbk4hpSSmbEsZGB84fs34nlTX1JLU6xOsJkybcG6i/sTciUcnqU3iTNqfRkhLStuZ9oUTxanhDKSz06vT+7ks7lruC54fbzWvl+/NX8l/muGdsTLjWaZ35qrM3izfrPKsPgFLUCl4lR2SvTH7fU5Uzvacwdyk3IY8pby0vENCTWGOsH2y0eRpk7tEdqISUfcUzylrpvSLw8Tb8pH8CfnNBVrwR75DYi35RfKg0KewqvDD1MSp+6dpTBNO65huO33x9KdFQUW/zcBncGe0zTSZOW/mg1nMWZtnI7PTZ7fNMZuzYE7P3OC5O+ZR5uXM+73YsXhl8dv5SfNbFhgumLvg0S/Bv9SVqJaIS24s9Fq4cRG+SLCoc7Hz4nWLv5XySs+XOZaVl31Zwl1y/tcxv1b8Org0Y2nnMrdlG5YTlwuXX1/hu2LHSo2VRSsfrYpY1biavrp09ds1k9acK3cp37iWslaytrsivKJ5nfm65eu+VGZVXqvyr2qoNqheXP1+PW/95Q1+G+o3Gm4s2/hpk2DTzc3BmxtrLGvKtxC3FG55sjVx65nfGL/VbtPfVrbt63bh9u4dsTvaa91ra3ca7FxWh9ZJ6np3pe66tDtgd3O9ff3mBp2Gsj1gj2TP871pe6/vC9vXtp+xv/6AxYHqg7SDpY1I4/TG/qaspu7m5OauQ6GH2lq8Wg4edji8/YjJkaqj2keXHaMcW3BssLWodeC46HjficwTj9omtd05Oe7k1faY9s5TYafOng46ffIM80zrWe+zR855njt0nnG+6YLbhcYO146Dv7v+frDTrbPxovvF5ksel1q6xnYdu+x7+cSVgCunr7KvXrgWea3resL1mzdSb3Tf5N18div31qvbhbc/35l7l3C39J76vfL7Bvdr/rD5o6Hbrfvog4AHHQ/jHt55xH304nH+4y89C55Qn5Q/NX5a+8zp2ZHeoN5Lz8c/73khevG5r+RPjT+rX1q/PPCX318d/eP6e16JXw2+XvJG7832ty5v2waiB+6/y3v3+X3pB70POz4yPp75lPTp6eepX0hfKr7afG35Fvbt7mDe4KCII+bIfgUwWNGMDABebweAmgwADZ7PKOPl5z9ZQeRnVhkC/wnLz4iy4gZAPfx/j+mDfzc3ANizFR6/oL5aKgDRVADiPQDq7Dxch85qsnOltBDhOWBT5Nf0vHTwb4r8zPlD3D+3QKrqAn5u/wWdZ3xtG7qP3QAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAATqADAAQAAAABAAAATAAAAADhTXUdAAARnUlEQVR4Ae2c245bR3aGi4fulizFHgUzQAYIggBB5klymfeaZ8hDBYjvAiRxkMAGkowRWx7JktjcZL7vX1Uku62Burkl5YbV5q7Tqqq1/v3XqgMpL95tbvftEh6NwPLRLS4NgsAFuDOJcAHuAtyZCJzZ7MK4C3BnInBmswvjLsCdicCZzS6MOxO49Znt0uz3//CPbbv6srXFrq0W9Q6Wi0VbLPn4R8x/jSLiu3nrl8s9dcartlwtKdmTbm21XranN6v27Mm6XV8t25fP1+3Pn1+1r4if3Czbk+t9u1rR6f9jmAXc1P6sbaevQGbfdgGJeA8ke0AQsCYYgiYgPR1QyVO+3wvcMm2WO0G2PeWkX79btp839AG4//UjYC62gDsB2rI9f7pov3q2bX/9F1ftBWAufTufOcwCrnTtR90dOdHoNgCJeAbUkuM5TsWAW5W9gfkE83ZkUHg0oAyAwbm927a2ebVoP/xx2f7jD1uYuG9/89tF+/VXK1hq+88TZgG32O1g2r7tpRdBM8fUTM7pyR8SYddgxkJErUszHti7U44CpzyEo16syNtx+qgy+1og7RMetpev9+3rb3bt+c2u/ebFsv3uL1ftiqn+qcMs4HY7jNQpEfadNU5VqeHUTJkgUbaPDxRADdZ8jU9LHoJYnwLUtgWN4ObDC7Kdr8Hp7d9qMTW8gt23V1zyvPrD1H56e9t+99vr9uJLprBDfaIw69U4dQRCIw2JdVIjbUzecj+7qYyPpZHiAbDaJwsXyMhQEQ0pq6sAp7hMS2XGqykdA2iy4EUtF6v206ur9k/fbNo//+frtt2OaW/rjxtmAaeNGqihBY5xfVQzQEZfoSH0KHgkrbD/CX6vPIqlSTU61vVCovRSbEwbIS851vj23Q+tff3vu/bzu5I7tvs4qVnADTa5FCbNC86qCLN2E1MxKKroYB2pgSz2RLbbVcVkSJhOKxIDjGxn+nSuqes2JlKuG8fA/IzPXazbj68X7et/27UfX7GifORwOuSju47h/c3beKfRFO74CNA04YP0ZT2/YzERFGojc9pmDG47/wyDZwJjiX4wwJNer1dZPJbs5/xzK5Ppzp7SQZBszNy22U7tX7/dtFdvJrv8aGE2cDJLoPycBgHSgICJUQLo8nmUo6y7oH0S5Lu/FGhDQULCfIooATw3yyOQQ46eYVpYiaBMTFtAFPR307r9y3fbdvsRfd5Rg6HJI2Lt1qaAF6TEqoxWdVdYSHawezCvAHLjW7Jh2QGcUkDDT4Og2OfSFRVkxipcAJUZARC5FVRbeRpB1hVY6r25XQHexIZ96Hfa++PTs4Dbi8rQg7imWQG27/uEgCTCssk/WWg7GwJWwDQ36PceGzQ+x7jOtgNogkIIpsZiFMdXoEfOPUlh3l5ulu2/X6bJ7Mc84Bw+xgOKzJqM0VKm8WYlVMqt61gFKNtQKeZ6o7Ls/aqEeYooJXDIZ9uiT0uZ5UxPUJNlYdoAK62qHfM7unz3/bb9/Ha+v3u/tn3AD0XOrnxAZdpNYZILgoxyGk4BqMCbssq66dXv6RdFkiB6Rj2u3N1npiMw1dQjF4oJW/kzy6VdMRFA9Xd8VvhCLxCyYUYkvhHZb7+fotvdUR6XmwXcYI1DangAA6yspgBj/dRjp6L+RbmSPaaxuuMnGEeVAhBF4pSapAFG5gUo60rAHmpVtcz0sR2aBZW8NAB9+W7dXr9N0dmPmUcu10pWrq7kQQvBQXn1dUsgoM4ej12TtyBknG51PEMGOV2TLLVZ/GLvLMBYHsYJhg7fuMBx6tq3LFu7aBxxD9jKFiO7Thbwcv7n5dS+/ML0eWEWcBqoptk+mEQp2aTG+rbmBYA+D6MyMwMAdepKsX5QpnglFZyZ5k4tDYsI/Y1pF7CRq22HoHXgGEOwgodvgH79INnW3tlFIVVQvkBXg1dvF3z27fkTGzw+zALOPZluVoVkV4yLHoBB3VBJUNyo6uEWXAyIkruC2OQjbVeppxkm8+iti2mySsM1EPYGKBcEyul3LKTW1+pr+wLRstwP0J8a2K95Txf/+6q1ZzeUDEXt/oFhHnA4fJYCBtawYlWmlsrJBEHhP43bi9Rq1Z0ymlK3Z/QCRqA5YfaNLZJWEACn929eluXlUGO8CgMrHWYi441S2tsFebLRL5RWL0e0nL64SEEf2sjMR4ZZwA0Ddfziclz1eN8yDn1qAaHSq3G0FEQXjABDo51sJVNyGnA0QlAPL4LOApzMo0mY1sUFbQBj8xTzYhKrROYF5VGIftR1uW3+3uiWU8XnBw7l3HIYVG/P/djYgMZoyrTJrci0n2qPZVnNFV913viW6btGzsXBT6aW3VKmsauVTFOc2DxpP5YJYLBBeCUixE71IlGBR2EF+6OugHbP12Ddoj29HgIPj+cxDiPDFGINzB8sKhLh0Ui4gOgDI8deb8FiwYxlteWhLHWTlmOzhkxLAObPIkFqS8+bbG5BdgWiAmJTwXdqZ7oysktzdKC/BWMWiAJNpyP0ZPTMItRy7fTi2RB4eDwLuIkpCma1gob/Dsw7zcKAMf3txiCot8c42ZCDPu3WAqRMJAGEk4cACaLzSZsFRhAE9QoAtXcwTX92XDT0sxTQXJYHdDJin0KfVN8PmzNvnOYBx5XNlik4giumihb7tJ60ezgNhgXuXgRNttxunZYAj7uzbL3nUA67rm5KJWrJCyTfIVwBMh3bTkD8TqFYp6uv8RwrgJpAZmHHScqv0qWeKT48NujhAuELekyYBdz9gXJQ53DvDh3tU62xTtN8bQhzzE9OccAK8wA2ez2k3cNtN7wM/RZs9M5NkNZoee0H2rmhLr8miPV9roAZtN1RHV/gDb7EoUtXKeXjYXUBN0oeFs8CbrtlhZRGPZSSZNyI9gA+TBFkelFNWxgEgCtG3wDiFqEr5Jz6y/U1DAM4QLxi2l7DNhl3w/epNTUFWGbXC7HrMQMz7WUbf8AaDQ46DYXuxLoJX6CFRzvuiPyJzCzgZIoKyqgKAx1yAGPQUWfa+GoDsqwDJNnHLF9juSz0i5VrpvqSwmsQul5dtyfrfX1zL3i0WdHHSjaKVjf0T5k7ABtxlEHbwxusgjydAY8N84BjvAx5GLfMqBW0VJEZ+pwKskQnbpnFHPzpwWo/bzkGvX51296+bu1v/+qL9usXT9rTJ07Bzh9k9HEPsxNhwhh6xLXKo3fXWf3iMkrBBz9nAbflbHm6ONxhXp8/NW26lkSleIEV9FBVI+o6ihjmffPDt+3v/+5Z+82vnsZw/fyercweB2d7wzA8mfuPEknpXTnHvQsoPd1v/aD8LODw+AxbAw/QjnEfv69u5kz6dtOiW2R6YmW7vd0C3qK94wcjf/zxZ1bRXfvqGT6U3f2G/Z6AesqotgJX477PNVmTmxfiwTSS5irqz2ybEHD6PzbMAk7lS/0BxgkTqPAUYBiAkQpTLLdKxe1D4Lbsp968uW1vXk+ZrnpsN7yL1TbmbvCl4GcPPPStZWyNcM9s++9y92ruZu2CT21q7lZ9KDcLuC3WbmGG42uA30EISOVkFynt1BBialOliF/wZHqGTa1tOfq8fbMHPL6N2iBPW2d7HfxZdWnreiN49UL0dfhLR6tBSVVwNo+TQ1U5IsHvQU4Dcry7bGNOix+SngVcwAhYpZjTQxaNMABLLLtUFEAMEwi4kk63fGDbLTcVm82ubd7hNylzEXCa6SPdz2Vf5iUobe0jAFIq8+JHT8CjGeUjHFOj5E7MIO4THxvOaHIcwu2IOKiznyg89BTEXi6WssO8B36vkLa33Pv7/QRbEtm21c/BtIm9Yb4ho19PDg4g09aeucySdpzq3BfVx6WQqh7MkLOSkHLf2olEKni4n7xznh0VH4jnAYdy6hfVSZTvUmF54f2cU9d9XmlhvUyTlbkxIT0BWtgH4wRRgPMy7EFbAwi8ojzbNyqtH/7coWxnUHyE+rmYjbs3NCnqdwIbbM/GZ4RZwDleVskO3viSBhWjSu2Pxj7JU4bsqrzTU5YZQ7xKu73Bb8bAbo+s28NStxEyb8e+K1UAKXhOVivK7x0RUANf3zEw/smJpsr37cad9RlhFnCbzQYwfN36I+5qwxgVwRA/vOHxlneeMiaux9lymN5tTTttkZN5mbZwCYsLM550taA+zJM5gsdHsGSdQTbngN7ZlC/JrRhXIcorRJvVcp2pnjzdy+0nnErOCbOAE5x8d4oVCy4xMSFGetjfgWJ3MQFHdomxZbUwwC4B84YlzBNojUEmxmqO1tVC4VcVopUzKuXK+XArUeDVTyq85wv7xKqHsel1dfIUkl8zUXcFm8eUH7IPjWcBp8J5mYxWcWmbclhlyEIAMJm2HbSwDCHZGD9IuR1UH4MhaZ4HOAIQIJOrIxfjxOFRUMNQq8wI9EH5WNVJdcEje22ofxs3K6PlQ+OZwA2ghrFSKhiEVSqh/5JJcfodKBnntLac7wb5CKLpAs+0RguYuAhoNh2CRV1dTVFhqWhRn/u+tOsMtTph6JhOkAWsQDz1K3NHeHyYBZyK70BG5oy3SyqGumoaAhr1Aiggnm8FzXr3cQWSq++p8seM10v6LW9Elgh5kyGINXMdi1xspw2LRHwqMjJTV2KdU9c2eQ1SkXDDHL2aYf2MprVp1dFrtcBlAWB/sNuxMoJIzEfRqhMk04qXfM0n8yVDaa/DRLp1GuGSKhNz65ZEOQUSdyD0Y/adRSojsxjoz2jnNFdN3l/S+sUvnqbDsx+zgCvQMJzhPaCrlouCLBvbA43x68DhsAc7DxpTr0y39VAMBCfpSlpSUMggzRe8X4bIAWRYJqVJj6t7feMV/9Bkfeb+bYw2Czg78S3GwWtEQEPRWFMMEDAZhVTiMaWLnZZRxSexfaStPR9DAXbMj5Qs479Dm8PqqYCNEpUTVAe/GpLC3vH16hI64zkLuB1XQVsdFkED8ps40oLjj2sMAdbFwGlKRjbW6UHAFZaRJVegIpeWVafZhQ4yHahUm+5VyfOwXYFHTX8DKUNSn+fCcsN3qOd8AT3GGPEs4EYnxho9YlOnU1WTUj98GbLKWCawI5wk71DiBMoh+qjYfgXUc+nNlW+rXuqjOrknPAs4sRoHcvvNguDZNEChYOoBUUZ175z9nMBZnQ6cnncgS7uDnt3BJ49Y8axqPYLZ0gVEb2DaICyHtOUM5t2eP7AJexWaGWYBVzcdsqneoAAViyzzo3ZsC1Jeq2qBKVhlkIxDsuSRrSY6/6S6eaaFjD+B4BGmMo9X9M06kcAdMq0qU5eT+lBBc8+GqaVmCc989iHP6yVvOcr4qE8ZLijVZ8VleC/5xWDWFmN6ow6aIKX75EfdL5rfKxBJgAcwwV/zeXrFjyqqo3uy52dnMa5oU4O7svo7YMNgWrFKdsk6WBXmmS82HuKsuADjHZFGi5iBIv+9qnn/qt+qSh3JTFNjPvWDiqpnA0SexYB/ijm6q5qP85wFnIZrXQHgillpVesHh9QVaAWWAJccfo/VNrOcbmrbYn/vCR9gy2m1aUH2WOa/rv4UoKnhPODowC2Gx6jQo4Nox4ZinDL392ssIHFSZWa1rTZJD/wSy0Kn34eDpwZvP1w96+dmH25zrsQs4KSLP4GAawWSjhnFZZQFmUZxOZSTj/ne2yUhIHCjRIlFKcIU0x852RjZTGGlDdaQrkxk7MPrJr/gzg17r4vgJ3rMAk4/wmQDE7wJhg+fFV1xaMGiMqnXaFc5jd4FjCCIRAEmAO5aPE7lzsw0ZelHYJB0PCWscErqOJcsrbllGmhmzE/7mAXcPof544Wlqg6wTuORtvKQzjV2gVC+shaNMhc24v8iIloGmS3ogc7bD9sS884Oi0kEP89jFnDX++/hCtPVtT7kwaxOkZpmxQ/L9vgdj1r+NCtAwQ6/A9DXMXnBqZgoHDdXP7Wna/Id6PRCum7DiREqcg1UPw9Yp6MsLv/HwlM4Hp7WQ1/CGQhcgDsDNJtcgLsAdyYCZza7MO4C3JkInNnswrgLcGcicGazC+POBO7/AH5zPa/ivytzAAAAAElFTkSuQmCC", ), ] ), @@ -124,7 +126,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock): @pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) -def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock): +def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock, setup_mock_redis): model = GoogleLargeLanguageModel() result = model.invoke( @@ -136,7 +138,9 @@ def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock): content=[ TextPromptMessageContent(data="what do you see?"), ImagePromptMessageContent( - data="" + mime_type="image/png", + format="png", + base64_data="iVBORw0KGgoAAAANSUhEUgAAAE4AAABMCAYAAADDYoEWAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBoAQSkhN4EkRpASggt9I4gKiEJEEqMgaBiRxcVXLuIgA1dFVGwAmJBETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMmbllAFA7zhGJclF1APKEBeLYYH/6uOQUOukpIAEdoAy0gA2Hmy9iRkeHA1iG2r+Xd9cBIm2v2Eu1/tn/X4sGj5/PBQCJhjidl8/Ng/gAAHg1VyQuAIAo5c2mFoikGFagJYYBQrxIijPluFqK0+V4j8wmPpYFcTsASiocjjgTANVLkKcXcjOhhmo/xI5CnkAIgBodYp+8vMk8iNMgtoY2Ioil+oz0H3Qy/6aZPqzJ4WQOY/lcZEUpQJAvyuVM/z/T8b9LXq5kyIclrCpZ4pBY6Zxh3m7mTA6TYhWI+4TpkVEQa0L8QcCT2UOMUrIkIQlye9SAm8+COYMrDVBHHicgDGIDiIOEuZHhCj49QxDEhhjuEHSaoIAdD7EuxIv4+YFxCptN4smxCl9oY4aYxVTwZzlimV+pr/uSnASmQv91Fp+t0MdUi7LikyCmQGxeKEiMhFgVYof8nLgwhc3YoixW5JCNWBIrjd8c4li+MNhfro8VZoiDYhX2pXn5Q/PFNmUJ2JEKvK8gKz5Enh+sncuRxQ/ngl3iC5kJQzr8/HHhQ3Ph8QMC5XPHnvGFCXEKnQ+iAv9Y+VicIsqNVtjjpvzcYClvCrFLfmGcYiyeWAA3pFwfzxAVRMfL48SLsjmh0fJ48OUgHLBAAKADCazpYDLIBoLOvqY+eCfvCQIcIAaZgA/sFczQiCRZjxBe40AR+BMiPsgfHucv6+WDQsh/HWblV3uQIestlI3IAU8gzgNhIBfeS2SjhMPeEsFjyAj+4Z0DKxfGmwurtP/f80Psd4YJmXAFIxnySFcbsiQGEgOIIcQgog2uj/vgXng4vPrB6oQzcI+heXy3JzwhdBEeEq4Rugm3JgmKxT9FGQG6oX6QIhfpP+YCt4Sarrg/7g3VoTKug+sDe9wF+mHivtCzK2RZirilWaH/pP23GfywGgo7siMZJY8g+5Gtfx6paqvqOqwizfWP+ZHHmj6cb9Zwz8/+WT9knwfbsJ8tsUXYfuwMdgI7hx3BmgAda8WasQ7sqBQP767Hst015C1WFk8O1BH8w9/Qykozme9Y59jr+EXeV8CfJn1HA9Zk0XSxIDOrgM6EXwQ+nS3kOoyiOzk6OQMg/b7IX19vYmTfDUSn4zs3/w8AvFsHBwcPf+dCWwHY6w4f/0PfOWsG/HQoA3D2EFciLpRzuPRCgG8JNfik6QEjYAas4XycgBvwAn4gEISCKBAPksFEGH0W3OdiMBXMBPNACSgDy8EaUAk2gi1gB9gN9oEmcAScAKfBBXAJXAN34O7pAS9AP3gHPiMIQkKoCA3RQ4wRC8QOcUIYiA8SiIQjsUgykoZkIkJEgsxE5iNlyEqkEtmM1CJ7kUPICeQc0oXcQh4gvchr5BOKoSqoFmqIWqKjUQbKRMPQeHQCmolOQYvQBehStAKtQXehjegJ9AJ6De1GX6ADGMCUMR3MBLPHGBgLi8JSsAxMjM3GSrFyrAarx1rgOl/BurE+7CNOxGk4HbeHOzgET8C5+BR8Nr4Er8R34I14O34Ff4D3498IVIIBwY7gSWATxhEyCVMJJYRywjbCQcIp+Cz1EN4RiUQdohXRHT6LycRs4gziEuJ6YgPxOLGL+Ig4QCKR9Eh2JG9SFIlDKiCVkNaRdpFaSZdJPaQPSspKxkpOSkFKKUpCpWKlcqWdSseULis9VfpMVidbkD3JUWQeeTp5GXkruYV8kdxD/kzRoFhRvCnxlGzKPEoFpZ5yinKX8kZZWdlU2UM5RlmgPFe5QnmP8lnlB8ofVTRVbFVYKqkqEpWlKttVjqvcUnlDpVItqX7UFGoBdSm1lnqSep/6QZWm6qDKVuWpzlGtUm1Uvaz6Uo2sZqHGVJuoVqRWrrZf7aJanzpZ3VKdpc5Rn61epX5I/Yb6gAZNY4xGlEaexhKNnRrnNJ5pkjQtNQM1eZoLNLdontR8RMNoZjQWjUubT9tKO0Xr0SJqWWmxtbK1yrR2a3Vq9WtrartoJ2pP067SPqrdrYPpWOqwdXJ1luns07mu82mE4QjmCP6IxSPqR1we8V53pK6fLl+3VLdB95ruJz26XqBejt4KvSa9e/q4vq1+jP5U/Q36p/T7RmqN9BrJHVk6ct/I2waoga1BrMEMgy0GHQYDhkaGwYYiw3WGJw37jHSM/IyyjVYbHTPqNaYZ+xgLjFcbtxo/p2vTmfRcegW9nd5vYmASYiIx2WzSafLZ1Mo0wbTYtMH0nhnFjGGWYbbarM2s39zYPMJ8pnmd+W0LsgXDIstircUZi/eWVpZJlgstmyyfWelasa2KrOqs7lpTrX2tp1jXWF+1IdowbHJs1ttcskVtXW2zbKtsL9qhdm52Arv1dl2jCKM8RglH1Yy6Ya9iz7QvtK+zf+Cg4xDuUOzQ5PBytPnolNErRp8Z/c3R1THXcavjnTGaY0LHFI9pGfPaydaJ61TldNWZ6hzkPMe52fmVi50L32WDy01XmmuE60LXNtevbu5uYrd6t153c/c092r3GwwtRjRjCeOsB8HD32OOxxGPj55ungWe+zz/8rL3yvHa6fVsrNVY/titYx95m3pzvDd7d/vQfdJ8Nvl0+5r4cnxrfB/6mfnx/Lb5PWXaMLOZu5gv/R39xf4H/d+zPFmzWMcDsIDggNKAzkDNwITAysD7QaZBmUF1Qf3BrsEzgo+HEELCQlaE3GAbsrnsWnZ/qHvorND2MJWwuLDKsIfhtuHi8JYINCI0YlXE3UiLSGFkUxSIYketiroXbRU9JfpwDDEmOqYq5knsmNiZsWfiaHGT4nbGvYv3j18WfyfBOkGS0JaolpiaWJv4PikgaWVS97jR42aNu5CsnyxIbk4hpSSmbEsZGB84fs34nlTX1JLU6xOsJkybcG6i/sTciUcnqU3iTNqfRkhLStuZ9oUTxanhDKSz06vT+7ks7lruC54fbzWvl+/NX8l/muGdsTLjWaZ35qrM3izfrPKsPgFLUCl4lR2SvTH7fU5Uzvacwdyk3IY8pby0vENCTWGOsH2y0eRpk7tEdqISUfcUzylrpvSLw8Tb8pH8CfnNBVrwR75DYi35RfKg0KewqvDD1MSp+6dpTBNO65huO33x9KdFQUW/zcBncGe0zTSZOW/mg1nMWZtnI7PTZ7fNMZuzYE7P3OC5O+ZR5uXM+73YsXhl8dv5SfNbFhgumLvg0S/Bv9SVqJaIS24s9Fq4cRG+SLCoc7Hz4nWLv5XySs+XOZaVl31Zwl1y/tcxv1b8Org0Y2nnMrdlG5YTlwuXX1/hu2LHSo2VRSsfrYpY1biavrp09ds1k9acK3cp37iWslaytrsivKJ5nfm65eu+VGZVXqvyr2qoNqheXP1+PW/95Q1+G+o3Gm4s2/hpk2DTzc3BmxtrLGvKtxC3FG55sjVx65nfGL/VbtPfVrbt63bh9u4dsTvaa91ra3ca7FxWh9ZJ6np3pe66tDtgd3O9ff3mBp2Gsj1gj2TP871pe6/vC9vXtp+xv/6AxYHqg7SDpY1I4/TG/qaspu7m5OauQ6GH2lq8Wg4edji8/YjJkaqj2keXHaMcW3BssLWodeC46HjficwTj9omtd05Oe7k1faY9s5TYafOng46ffIM80zrWe+zR855njt0nnG+6YLbhcYO146Dv7v+frDTrbPxovvF5ksel1q6xnYdu+x7+cSVgCunr7KvXrgWea3resL1mzdSb3Tf5N18div31qvbhbc/35l7l3C39J76vfL7Bvdr/rD5o6Hbrfvog4AHHQ/jHt55xH304nH+4y89C55Qn5Q/NX5a+8zp2ZHeoN5Lz8c/73khevG5r+RPjT+rX1q/PPCX318d/eP6e16JXw2+XvJG7832ty5v2waiB+6/y3v3+X3pB70POz4yPp75lPTp6eepX0hfKr7afG35Fvbt7mDe4KCII+bIfgUwWNGMDABebweAmgwADZ7PKOPl5z9ZQeRnVhkC/wnLz4iy4gZAPfx/j+mDfzc3ANizFR6/oL5aKgDRVADiPQDq7Dxch85qsnOltBDhOWBT5Nf0vHTwb4r8zPlD3D+3QKrqAn5u/wWdZ3xtG7qP3QAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAATqADAAQAAAABAAAATAAAAADhTXUdAAARnUlEQVR4Ae2c245bR3aGi4fulizFHgUzQAYIggBB5klymfeaZ8hDBYjvAiRxkMAGkowRWx7JktjcZL7vX1Uku62Burkl5YbV5q7Tqqq1/v3XqgMpL95tbvftEh6NwPLRLS4NgsAFuDOJcAHuAtyZCJzZ7MK4C3BnInBmswvjLsCdicCZzS6MOxO49Znt0uz3//CPbbv6srXFrq0W9Q6Wi0VbLPn4R8x/jSLiu3nrl8s9dcartlwtKdmTbm21XranN6v27Mm6XV8t25fP1+3Pn1+1r4if3Czbk+t9u1rR6f9jmAXc1P6sbaevQGbfdgGJeA8ke0AQsCYYgiYgPR1QyVO+3wvcMm2WO0G2PeWkX79btp839AG4//UjYC62gDsB2rI9f7pov3q2bX/9F1ftBWAufTufOcwCrnTtR90dOdHoNgCJeAbUkuM5TsWAW5W9gfkE83ZkUHg0oAyAwbm927a2ebVoP/xx2f7jD1uYuG9/89tF+/VXK1hq+88TZgG32O1g2r7tpRdBM8fUTM7pyR8SYddgxkJErUszHti7U44CpzyEo16syNtx+qgy+1og7RMetpev9+3rb3bt+c2u/ebFsv3uL1ftiqn+qcMs4HY7jNQpEfadNU5VqeHUTJkgUbaPDxRADdZ8jU9LHoJYnwLUtgWN4ObDC7Kdr8Hp7d9qMTW8gt23V1zyvPrD1H56e9t+99vr9uJLprBDfaIw69U4dQRCIw2JdVIjbUzecj+7qYyPpZHiAbDaJwsXyMhQEQ0pq6sAp7hMS2XGqykdA2iy4EUtF6v206ur9k/fbNo//+frtt2OaW/rjxtmAaeNGqihBY5xfVQzQEZfoSH0KHgkrbD/CX6vPIqlSTU61vVCovRSbEwbIS851vj23Q+tff3vu/bzu5I7tvs4qVnADTa5FCbNC86qCLN2E1MxKKroYB2pgSz2RLbbVcVkSJhOKxIDjGxn+nSuqes2JlKuG8fA/IzPXazbj68X7et/27UfX7GifORwOuSju47h/c3beKfRFO74CNA04YP0ZT2/YzERFGojc9pmDG47/wyDZwJjiX4wwJNer1dZPJbs5/xzK5Ppzp7SQZBszNy22U7tX7/dtFdvJrv8aGE2cDJLoPycBgHSgICJUQLo8nmUo6y7oH0S5Lu/FGhDQULCfIooATw3yyOQQ46eYVpYiaBMTFtAFPR307r9y3fbdvsRfd5Rg6HJI2Lt1qaAF6TEqoxWdVdYSHawezCvAHLjW7Jh2QGcUkDDT4Og2OfSFRVkxipcAJUZARC5FVRbeRpB1hVY6r25XQHexIZ96Hfa++PTs4Dbi8rQg7imWQG27/uEgCTCssk/WWg7GwJWwDQ36PceGzQ+x7jOtgNogkIIpsZiFMdXoEfOPUlh3l5ulu2/X6bJ7Mc84Bw+xgOKzJqM0VKm8WYlVMqt61gFKNtQKeZ6o7Ls/aqEeYooJXDIZ9uiT0uZ5UxPUJNlYdoAK62qHfM7unz3/bb9/Ha+v3u/tn3AD0XOrnxAZdpNYZILgoxyGk4BqMCbssq66dXv6RdFkiB6Rj2u3N1npiMw1dQjF4oJW/kzy6VdMRFA9Xd8VvhCLxCyYUYkvhHZb7+fotvdUR6XmwXcYI1DangAA6yspgBj/dRjp6L+RbmSPaaxuuMnGEeVAhBF4pSapAFG5gUo60rAHmpVtcz0sR2aBZW8NAB9+W7dXr9N0dmPmUcu10pWrq7kQQvBQXn1dUsgoM4ej12TtyBknG51PEMGOV2TLLVZ/GLvLMBYHsYJhg7fuMBx6tq3LFu7aBxxD9jKFiO7Thbwcv7n5dS+/ML0eWEWcBqoptk+mEQp2aTG+rbmBYA+D6MyMwMAdepKsX5QpnglFZyZ5k4tDYsI/Y1pF7CRq22HoHXgGEOwgodvgH79INnW3tlFIVVQvkBXg1dvF3z27fkTGzw+zALOPZluVoVkV4yLHoBB3VBJUNyo6uEWXAyIkruC2OQjbVeppxkm8+iti2mySsM1EPYGKBcEyul3LKTW1+pr+wLRstwP0J8a2K95Txf/+6q1ZzeUDEXt/oFhHnA4fJYCBtawYlWmlsrJBEHhP43bi9Rq1Z0ymlK3Z/QCRqA5YfaNLZJWEACn929eluXlUGO8CgMrHWYi441S2tsFebLRL5RWL0e0nL64SEEf2sjMR4ZZwA0Ddfziclz1eN8yDn1qAaHSq3G0FEQXjABDo51sJVNyGnA0QlAPL4LOApzMo0mY1sUFbQBj8xTzYhKrROYF5VGIftR1uW3+3uiWU8XnBw7l3HIYVG/P/djYgMZoyrTJrci0n2qPZVnNFV913viW6btGzsXBT6aW3VKmsauVTFOc2DxpP5YJYLBBeCUixE71IlGBR2EF+6OugHbP12Ddoj29HgIPj+cxDiPDFGINzB8sKhLh0Ui4gOgDI8deb8FiwYxlteWhLHWTlmOzhkxLAObPIkFqS8+bbG5BdgWiAmJTwXdqZ7oysktzdKC/BWMWiAJNpyP0ZPTMItRy7fTi2RB4eDwLuIkpCma1gob/Dsw7zcKAMf3txiCot8c42ZCDPu3WAqRMJAGEk4cACaLzSZsFRhAE9QoAtXcwTX92XDT0sxTQXJYHdDJin0KfVN8PmzNvnOYBx5XNlik4giumihb7tJ60ezgNhgXuXgRNttxunZYAj7uzbL3nUA67rm5KJWrJCyTfIVwBMh3bTkD8TqFYp6uv8RwrgJpAZmHHScqv0qWeKT48NujhAuELekyYBdz9gXJQ53DvDh3tU62xTtN8bQhzzE9OccAK8wA2ez2k3cNtN7wM/RZs9M5NkNZoee0H2rmhLr8miPV9roAZtN1RHV/gDb7EoUtXKeXjYXUBN0oeFs8CbrtlhZRGPZSSZNyI9gA+TBFkelFNWxgEgCtG3wDiFqEr5Jz6y/U1DAM4QLxi2l7DNhl3w/epNTUFWGbXC7HrMQMz7WUbf8AaDQ46DYXuxLoJX6CFRzvuiPyJzCzgZIoKyqgKAx1yAGPQUWfa+GoDsqwDJNnHLF9juSz0i5VrpvqSwmsQul5dtyfrfX1zL3i0WdHHSjaKVjf0T5k7ABtxlEHbwxusgjydAY8N84BjvAx5GLfMqBW0VJEZ+pwKskQnbpnFHPzpwWo/bzkGvX51296+bu1v/+qL9usXT9rTJ07Bzh9k9HEPsxNhwhh6xLXKo3fXWf3iMkrBBz9nAbflbHm6ONxhXp8/NW26lkSleIEV9FBVI+o6ihjmffPDt+3v/+5Z+82vnsZw/fyercweB2d7wzA8mfuPEknpXTnHvQsoPd1v/aD8LODw+AxbAw/QjnEfv69u5kz6dtOiW2R6YmW7vd0C3qK94wcjf/zxZ1bRXfvqGT6U3f2G/Z6AesqotgJX477PNVmTmxfiwTSS5irqz2ybEHD6PzbMAk7lS/0BxgkTqPAUYBiAkQpTLLdKxe1D4Lbsp968uW1vXk+ZrnpsN7yL1TbmbvCl4GcPPPStZWyNcM9s++9y92ruZu2CT21q7lZ9KDcLuC3WbmGG42uA30EISOVkFynt1BBialOliF/wZHqGTa1tOfq8fbMHPL6N2iBPW2d7HfxZdWnreiN49UL0dfhLR6tBSVVwNo+TQ1U5IsHvQU4Dcry7bGNOix+SngVcwAhYpZjTQxaNMABLLLtUFEAMEwi4kk63fGDbLTcVm82ubd7hNylzEXCa6SPdz2Vf5iUobe0jAFIq8+JHT8CjGeUjHFOj5E7MIO4THxvOaHIcwu2IOKiznyg89BTEXi6WssO8B36vkLa33Pv7/QRbEtm21c/BtIm9Yb4ho19PDg4g09aeucySdpzq3BfVx6WQqh7MkLOSkHLf2olEKni4n7xznh0VH4jnAYdy6hfVSZTvUmF54f2cU9d9XmlhvUyTlbkxIT0BWtgH4wRRgPMy7EFbAwi8ojzbNyqtH/7coWxnUHyE+rmYjbs3NCnqdwIbbM/GZ4RZwDleVskO3viSBhWjSu2Pxj7JU4bsqrzTU5YZQ7xKu73Bb8bAbo+s28NStxEyb8e+K1UAKXhOVivK7x0RUANf3zEw/smJpsr37cad9RlhFnCbzQYwfN36I+5qwxgVwRA/vOHxlneeMiaux9lymN5tTTttkZN5mbZwCYsLM550taA+zJM5gsdHsGSdQTbngN7ZlC/JrRhXIcorRJvVcp2pnjzdy+0nnErOCbOAE5x8d4oVCy4xMSFGetjfgWJ3MQFHdomxZbUwwC4B84YlzBNojUEmxmqO1tVC4VcVopUzKuXK+XArUeDVTyq85wv7xKqHsel1dfIUkl8zUXcFm8eUH7IPjWcBp8J5mYxWcWmbclhlyEIAMJm2HbSwDCHZGD9IuR1UH4MhaZ4HOAIQIJOrIxfjxOFRUMNQq8wI9EH5WNVJdcEje22ofxs3K6PlQ+OZwA2ghrFSKhiEVSqh/5JJcfodKBnntLac7wb5CKLpAs+0RguYuAhoNh2CRV1dTVFhqWhRn/u+tOsMtTph6JhOkAWsQDz1K3NHeHyYBZyK70BG5oy3SyqGumoaAhr1Aiggnm8FzXr3cQWSq++p8seM10v6LW9Elgh5kyGINXMdi1xspw2LRHwqMjJTV2KdU9c2eQ1SkXDDHL2aYf2MprVp1dFrtcBlAWB/sNuxMoJIzEfRqhMk04qXfM0n8yVDaa/DRLp1GuGSKhNz65ZEOQUSdyD0Y/adRSojsxjoz2jnNFdN3l/S+sUvnqbDsx+zgCvQMJzhPaCrlouCLBvbA43x68DhsAc7DxpTr0y39VAMBCfpSlpSUMggzRe8X4bIAWRYJqVJj6t7feMV/9Bkfeb+bYw2Czg78S3GwWtEQEPRWFMMEDAZhVTiMaWLnZZRxSexfaStPR9DAXbMj5Qs479Dm8PqqYCNEpUTVAe/GpLC3vH16hI64zkLuB1XQVsdFkED8ps40oLjj2sMAdbFwGlKRjbW6UHAFZaRJVegIpeWVafZhQ4yHahUm+5VyfOwXYFHTX8DKUNSn+fCcsN3qOd8AT3GGPEs4EYnxho9YlOnU1WTUj98GbLKWCawI5wk71DiBMoh+qjYfgXUc+nNlW+rXuqjOrknPAs4sRoHcvvNguDZNEChYOoBUUZ175z9nMBZnQ6cnncgS7uDnt3BJ49Y8axqPYLZ0gVEb2DaICyHtOUM5t2eP7AJexWaGWYBVzcdsqneoAAViyzzo3ZsC1Jeq2qBKVhlkIxDsuSRrSY6/6S6eaaFjD+B4BGmMo9X9M06kcAdMq0qU5eT+lBBc8+GqaVmCc989iHP6yVvOcr4qE8ZLijVZ8VleC/5xWDWFmN6ow6aIKX75EfdL5rfKxBJgAcwwV/zeXrFjyqqo3uy52dnMa5oU4O7svo7YMNgWrFKdsk6WBXmmS82HuKsuADjHZFGi5iBIv+9qnn/qt+qSh3JTFNjPvWDiqpnA0SexYB/ijm6q5qP85wFnIZrXQHgillpVesHh9QVaAWWAJccfo/VNrOcbmrbYn/vCR9gy2m1aUH2WOa/rv4UoKnhPODowC2Gx6jQo4Nox4ZinDL392ssIHFSZWa1rTZJD/wSy0Kn34eDpwZvP1w96+dmH25zrsQs4KSLP4GAawWSjhnFZZQFmUZxOZSTj/ne2yUhIHCjRIlFKcIU0x852RjZTGGlDdaQrkxk7MPrJr/gzg17r4vgJ3rMAk4/wmQDE7wJhg+fFV1xaMGiMqnXaFc5jd4FjCCIRAEmAO5aPE7lzsw0ZelHYJB0PCWscErqOJcsrbllGmhmzE/7mAXcPof544Wlqg6wTuORtvKQzjV2gVC+shaNMhc24v8iIloGmS3ogc7bD9sS884Oi0kEP89jFnDX++/hCtPVtT7kwaxOkZpmxQ/L9vgdj1r+NCtAwQ6/A9DXMXnBqZgoHDdXP7Wna/Id6PRCum7DiREqcg1UPw9Yp6MsLv/HwlM4Hp7WQ1/CGQhcgDsDNJtcgLsAdyYCZza7MO4C3JkInNnswrgLcGcicGazC+POBO7/AH5zPa/ivytzAAAAAElFTkSuQmCC", ), ] ), @@ -145,7 +149,9 @@ def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock): content=[ TextPromptMessageContent(data="what about now?"), ImagePromptMessageContent( - data="" + mime_type="image/png", + format="png", + base64_data="iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAACXBIWXMAAAABAAAAAQBPJcTWAAADl0lEQVR4nC3Uf0zUdRjA8S9W6w//bGs1DUd5RT+gIY0oYeEqY0QCy5EbAnF4IEgyAnGuCBANWOjih6YOlK0BbtLAX+iAENFgUBLMkzs8uDuO+wEcxx3cgdx9v3fvvn/0x+v5PM+z56/n2T6CIAgIQUEECVsICnqOoC0v8PyLW3n5lW28GhLG9hAFwYowdoRsJ+Tzv3hdEcpOxVvsfDscheI1BIXKy5t7OwiPiCI8IZaIL+OISPKxK/IDdiU6ifwqjqj4WKISP5VN8mHSFNHJA7KnfJQYh7A7+g1i9hXw2dcX2JuSxhcJnxCfnEJ8ygESqtfYl3qA5O/1pKaX8E2Rn7R0JWnKXFkRaX0OhIOqUtJVRWQoj5ChyiOjb4XMQ0fIVB0lM6eEzMO5ZN5x8W1xD1nZh1Fm55OtzOdQTgEqZR6CSi5UjSI5hTnk3bWSX/gj+ccaKCgspaDkNIWlpygc3OTYtZc4fqKcE5Vn+eFkDWUp8ZS1ryOUn66lvGmCyt/8nLwxTlXZcapqL1Nd10B1Uy01FbnUnFVS+2sLvzTWUXfRRMOAgcb6KhovdSA0XnHRdL6Zcy1/0lyTS3NfgJbWNq6cu0nrPyu0FSlpu9pF21037ZFhXLtYT+eNIbp61+jq70bofv8drvf0c2vQz+3O3+nRrNI78JD+/psMfLefe0MG7p+a5v6tP3g48ojhC7mMXP2Y0YoZRitnEcbkMPaglzEnPAoNZrw4hXH1LBOtOiYfa3gcugO1+gnqZwGeaHRMTcyhaduKRjOBxiJfQSsnWq0W7YwVrd3PtH6BaeMST40adJ3V6OwBZlR7mNUvMWswYsiKxTA1gWHOgsGiRzCmRGOcW8QoD855JObWJUxmHSb5nfd4Mc+ZMFv1MjtmuWepSMNiMmAxz2LN2o1gbdmDdV6NdVnE1p6EzajHZp7BtjCLbSnAgsMtE1k8H8OiwyuTWPL4sLduwz5vRLA7XCzbLCw7PTiswzgWJnBsijhNwzhtw6xmRLLmdLC27sU9dBC324un/iieSyF4rPIS1/8eZOOego0NL898Epv14Wz2nMHrsOB12/Glh+Mrfg/fqgufKCHmxSC21SE6JxFdKwjihhFxw4O4aUf0bSKVRyN1pyKNXEcaDUbS3EZan5Sp/zeFtLGO5LUiSRKCJAXwZ0bg73oXv+kBfrsOv8uOXxIJ/JRG4N/9sjME1B3QXAjzd8CqhqWfkT8C4T8Z5+ciRtwo8gAAAABJRU5ErkJggg==", ), ] ), diff --git a/api/tests/integration_tests/model_runtime/ollama/test_llm.py b/api/tests/integration_tests/model_runtime/ollama/test_llm.py index 58a1339f506458..979751afceaca4 100644 --- a/api/tests/integration_tests/model_runtime/ollama/test_llm.py +++ b/api/tests/integration_tests/model_runtime/ollama/test_llm.py @@ -160,7 +160,9 @@ def test_invoke_completion_model_with_vision(): data="What is this in this picture?", ), ImagePromptMessageContent( - data="" + mime_type="image/png", + format="png", + base64_data="iVBORw0KGgoAAAANSUhEUgAAAE4AAABMCAYAAADDYoEWAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBoAQSkhN4EkRpASggt9I4gKiEJEEqMgaBiRxcVXLuIgA1dFVGwAmJBETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMmbllAFA7zhGJclF1APKEBeLYYH/6uOQUOukpIAEdoAy0gA2Hmy9iRkeHA1iG2r+Xd9cBIm2v2Eu1/tn/X4sGj5/PBQCJhjidl8/Ng/gAAHg1VyQuAIAo5c2mFoikGFagJYYBQrxIijPluFqK0+V4j8wmPpYFcTsASiocjjgTANVLkKcXcjOhhmo/xI5CnkAIgBodYp+8vMk8iNMgtoY2Ioil+oz0H3Qy/6aZPqzJ4WQOY/lcZEUpQJAvyuVM/z/T8b9LXq5kyIclrCpZ4pBY6Zxh3m7mTA6TYhWI+4TpkVEQa0L8QcCT2UOMUrIkIQlye9SAm8+COYMrDVBHHicgDGIDiIOEuZHhCj49QxDEhhjuEHSaoIAdD7EuxIv4+YFxCptN4smxCl9oY4aYxVTwZzlimV+pr/uSnASmQv91Fp+t0MdUi7LikyCmQGxeKEiMhFgVYof8nLgwhc3YoixW5JCNWBIrjd8c4li+MNhfro8VZoiDYhX2pXn5Q/PFNmUJ2JEKvK8gKz5Enh+sncuRxQ/ngl3iC5kJQzr8/HHhQ3Ph8QMC5XPHnvGFCXEKnQ+iAv9Y+VicIsqNVtjjpvzcYClvCrFLfmGcYiyeWAA3pFwfzxAVRMfL48SLsjmh0fJ48OUgHLBAAKADCazpYDLIBoLOvqY+eCfvCQIcIAaZgA/sFczQiCRZjxBe40AR+BMiPsgfHucv6+WDQsh/HWblV3uQIestlI3IAU8gzgNhIBfeS2SjhMPeEsFjyAj+4Z0DKxfGmwurtP/f80Psd4YJmXAFIxnySFcbsiQGEgOIIcQgog2uj/vgXng4vPrB6oQzcI+heXy3JzwhdBEeEq4Rugm3JgmKxT9FGQG6oX6QIhfpP+YCt4Sarrg/7g3VoTKug+sDe9wF+mHivtCzK2RZirilWaH/pP23GfywGgo7siMZJY8g+5Gtfx6paqvqOqwizfWP+ZHHmj6cb9Zwz8/+WT9knwfbsJ8tsUXYfuwMdgI7hx3BmgAda8WasQ7sqBQP767Hst015C1WFk8O1BH8w9/Qykozme9Y59jr+EXeV8CfJn1HA9Zk0XSxIDOrgM6EXwQ+nS3kOoyiOzk6OQMg/b7IX19vYmTfDUSn4zs3/w8AvFsHBwcPf+dCWwHY6w4f/0PfOWsG/HQoA3D2EFciLpRzuPRCgG8JNfik6QEjYAas4XycgBvwAn4gEISCKBAPksFEGH0W3OdiMBXMBPNACSgDy8EaUAk2gi1gB9gN9oEmcAScAKfBBXAJXAN34O7pAS9AP3gHPiMIQkKoCA3RQ4wRC8QOcUIYiA8SiIQjsUgykoZkIkJEgsxE5iNlyEqkEtmM1CJ7kUPICeQc0oXcQh4gvchr5BOKoSqoFmqIWqKjUQbKRMPQeHQCmolOQYvQBehStAKtQXehjegJ9AJ6De1GX6ADGMCUMR3MBLPHGBgLi8JSsAxMjM3GSrFyrAarx1rgOl/BurE+7CNOxGk4HbeHOzgET8C5+BR8Nr4Er8R34I14O34Ff4D3498IVIIBwY7gSWATxhEyCVMJJYRywjbCQcIp+Cz1EN4RiUQdohXRHT6LycRs4gziEuJ6YgPxOLGL+Ig4QCKR9Eh2JG9SFIlDKiCVkNaRdpFaSZdJPaQPSspKxkpOSkFKKUpCpWKlcqWdSseULis9VfpMVidbkD3JUWQeeTp5GXkruYV8kdxD/kzRoFhRvCnxlGzKPEoFpZ5yinKX8kZZWdlU2UM5RlmgPFe5QnmP8lnlB8ofVTRVbFVYKqkqEpWlKttVjqvcUnlDpVItqX7UFGoBdSm1lnqSep/6QZWm6qDKVuWpzlGtUm1Uvaz6Uo2sZqHGVJuoVqRWrrZf7aJanzpZ3VKdpc5Rn61epX5I/Yb6gAZNY4xGlEaexhKNnRrnNJ5pkjQtNQM1eZoLNLdontR8RMNoZjQWjUubT9tKO0Xr0SJqWWmxtbK1yrR2a3Vq9WtrartoJ2pP067SPqrdrYPpWOqwdXJ1luns07mu82mE4QjmCP6IxSPqR1we8V53pK6fLl+3VLdB95ruJz26XqBejt4KvSa9e/q4vq1+jP5U/Q36p/T7RmqN9BrJHVk6ct/I2waoga1BrMEMgy0GHQYDhkaGwYYiw3WGJw37jHSM/IyyjVYbHTPqNaYZ+xgLjFcbtxo/p2vTmfRcegW9nd5vYmASYiIx2WzSafLZ1Mo0wbTYtMH0nhnFjGGWYbbarM2s39zYPMJ8pnmd+W0LsgXDIstircUZi/eWVpZJlgstmyyfWelasa2KrOqs7lpTrX2tp1jXWF+1IdowbHJs1ttcskVtXW2zbKtsL9qhdm52Arv1dl2jCKM8RglH1Yy6Ya9iz7QvtK+zf+Cg4xDuUOzQ5PBytPnolNErRp8Z/c3R1THXcavjnTGaY0LHFI9pGfPaydaJ61TldNWZ6hzkPMe52fmVi50L32WDy01XmmuE60LXNtevbu5uYrd6t153c/c092r3GwwtRjRjCeOsB8HD32OOxxGPj55ungWe+zz/8rL3yvHa6fVsrNVY/titYx95m3pzvDd7d/vQfdJ8Nvl0+5r4cnxrfB/6mfnx/Lb5PWXaMLOZu5gv/R39xf4H/d+zPFmzWMcDsIDggNKAzkDNwITAysD7QaZBmUF1Qf3BrsEzgo+HEELCQlaE3GAbsrnsWnZ/qHvorND2MJWwuLDKsIfhtuHi8JYINCI0YlXE3UiLSGFkUxSIYketiroXbRU9JfpwDDEmOqYq5knsmNiZsWfiaHGT4nbGvYv3j18WfyfBOkGS0JaolpiaWJv4PikgaWVS97jR42aNu5CsnyxIbk4hpSSmbEsZGB84fs34nlTX1JLU6xOsJkybcG6i/sTciUcnqU3iTNqfRkhLStuZ9oUTxanhDKSz06vT+7ks7lruC54fbzWvl+/NX8l/muGdsTLjWaZ35qrM3izfrPKsPgFLUCl4lR2SvTH7fU5Uzvacwdyk3IY8pby0vENCTWGOsH2y0eRpk7tEdqISUfcUzylrpvSLw8Tb8pH8CfnNBVrwR75DYi35RfKg0KewqvDD1MSp+6dpTBNO65huO33x9KdFQUW/zcBncGe0zTSZOW/mg1nMWZtnI7PTZ7fNMZuzYE7P3OC5O+ZR5uXM+73YsXhl8dv5SfNbFhgumLvg0S/Bv9SVqJaIS24s9Fq4cRG+SLCoc7Hz4nWLv5XySs+XOZaVl31Zwl1y/tcxv1b8Org0Y2nnMrdlG5YTlwuXX1/hu2LHSo2VRSsfrYpY1biavrp09ds1k9acK3cp37iWslaytrsivKJ5nfm65eu+VGZVXqvyr2qoNqheXP1+PW/95Q1+G+o3Gm4s2/hpk2DTzc3BmxtrLGvKtxC3FG55sjVx65nfGL/VbtPfVrbt63bh9u4dsTvaa91ra3ca7FxWh9ZJ6np3pe66tDtgd3O9ff3mBp2Gsj1gj2TP871pe6/vC9vXtp+xv/6AxYHqg7SDpY1I4/TG/qaspu7m5OauQ6GH2lq8Wg4edji8/YjJkaqj2keXHaMcW3BssLWodeC46HjficwTj9omtd05Oe7k1faY9s5TYafOng46ffIM80zrWe+zR855njt0nnG+6YLbhcYO146Dv7v+frDTrbPxovvF5ksel1q6xnYdu+x7+cSVgCunr7KvXrgWea3resL1mzdSb3Tf5N18div31qvbhbc/35l7l3C39J76vfL7Bvdr/rD5o6Hbrfvog4AHHQ/jHt55xH304nH+4y89C55Qn5Q/NX5a+8zp2ZHeoN5Lz8c/73khevG5r+RPjT+rX1q/PPCX318d/eP6e16JXw2+XvJG7832ty5v2waiB+6/y3v3+X3pB70POz4yPp75lPTp6eepX0hfKr7afG35Fvbt7mDe4KCII+bIfgUwWNGMDABebweAmgwADZ7PKOPl5z9ZQeRnVhkC/wnLz4iy4gZAPfx/j+mDfzc3ANizFR6/oL5aKgDRVADiPQDq7Dxch85qsnOltBDhOWBT5Nf0vHTwb4r8zPlD3D+3QKrqAn5u/wWdZ3xtG7qP3QAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAATqADAAQAAAABAAAATAAAAADhTXUdAAARnUlEQVR4Ae2c245bR3aGi4fulizFHgUzQAYIggBB5klymfeaZ8hDBYjvAiRxkMAGkowRWx7JktjcZL7vX1Uku62Burkl5YbV5q7Tqqq1/v3XqgMpL95tbvftEh6NwPLRLS4NgsAFuDOJcAHuAtyZCJzZ7MK4C3BnInBmswvjLsCdicCZzS6MOxO49Znt0uz3//CPbbv6srXFrq0W9Q6Wi0VbLPn4R8x/jSLiu3nrl8s9dcartlwtKdmTbm21XranN6v27Mm6XV8t25fP1+3Pn1+1r4if3Czbk+t9u1rR6f9jmAXc1P6sbaevQGbfdgGJeA8ke0AQsCYYgiYgPR1QyVO+3wvcMm2WO0G2PeWkX79btp839AG4//UjYC62gDsB2rI9f7pov3q2bX/9F1ftBWAufTufOcwCrnTtR90dOdHoNgCJeAbUkuM5TsWAW5W9gfkE83ZkUHg0oAyAwbm927a2ebVoP/xx2f7jD1uYuG9/89tF+/VXK1hq+88TZgG32O1g2r7tpRdBM8fUTM7pyR8SYddgxkJErUszHti7U44CpzyEo16syNtx+qgy+1og7RMetpev9+3rb3bt+c2u/ebFsv3uL1ftiqn+qcMs4HY7jNQpEfadNU5VqeHUTJkgUbaPDxRADdZ8jU9LHoJYnwLUtgWN4ObDC7Kdr8Hp7d9qMTW8gt23V1zyvPrD1H56e9t+99vr9uJLprBDfaIw69U4dQRCIw2JdVIjbUzecj+7qYyPpZHiAbDaJwsXyMhQEQ0pq6sAp7hMS2XGqykdA2iy4EUtF6v206ur9k/fbNo//+frtt2OaW/rjxtmAaeNGqihBY5xfVQzQEZfoSH0KHgkrbD/CX6vPIqlSTU61vVCovRSbEwbIS851vj23Q+tff3vu/bzu5I7tvs4qVnADTa5FCbNC86qCLN2E1MxKKroYB2pgSz2RLbbVcVkSJhOKxIDjGxn+nSuqes2JlKuG8fA/IzPXazbj68X7et/27UfX7GifORwOuSju47h/c3beKfRFO74CNA04YP0ZT2/YzERFGojc9pmDG47/wyDZwJjiX4wwJNer1dZPJbs5/xzK5Ppzp7SQZBszNy22U7tX7/dtFdvJrv8aGE2cDJLoPycBgHSgICJUQLo8nmUo6y7oH0S5Lu/FGhDQULCfIooATw3yyOQQ46eYVpYiaBMTFtAFPR307r9y3fbdvsRfd5Rg6HJI2Lt1qaAF6TEqoxWdVdYSHawezCvAHLjW7Jh2QGcUkDDT4Og2OfSFRVkxipcAJUZARC5FVRbeRpB1hVY6r25XQHexIZ96Hfa++PTs4Dbi8rQg7imWQG27/uEgCTCssk/WWg7GwJWwDQ36PceGzQ+x7jOtgNogkIIpsZiFMdXoEfOPUlh3l5ulu2/X6bJ7Mc84Bw+xgOKzJqM0VKm8WYlVMqt61gFKNtQKeZ6o7Ls/aqEeYooJXDIZ9uiT0uZ5UxPUJNlYdoAK62qHfM7unz3/bb9/Ha+v3u/tn3AD0XOrnxAZdpNYZILgoxyGk4BqMCbssq66dXv6RdFkiB6Rj2u3N1npiMw1dQjF4oJW/kzy6VdMRFA9Xd8VvhCLxCyYUYkvhHZb7+fotvdUR6XmwXcYI1DangAA6yspgBj/dRjp6L+RbmSPaaxuuMnGEeVAhBF4pSapAFG5gUo60rAHmpVtcz0sR2aBZW8NAB9+W7dXr9N0dmPmUcu10pWrq7kQQvBQXn1dUsgoM4ej12TtyBknG51PEMGOV2TLLVZ/GLvLMBYHsYJhg7fuMBx6tq3LFu7aBxxD9jKFiO7Thbwcv7n5dS+/ML0eWEWcBqoptk+mEQp2aTG+rbmBYA+D6MyMwMAdepKsX5QpnglFZyZ5k4tDYsI/Y1pF7CRq22HoHXgGEOwgodvgH79INnW3tlFIVVQvkBXg1dvF3z27fkTGzw+zALOPZluVoVkV4yLHoBB3VBJUNyo6uEWXAyIkruC2OQjbVeppxkm8+iti2mySsM1EPYGKBcEyul3LKTW1+pr+wLRstwP0J8a2K95Txf/+6q1ZzeUDEXt/oFhHnA4fJYCBtawYlWmlsrJBEHhP43bi9Rq1Z0ymlK3Z/QCRqA5YfaNLZJWEACn929eluXlUGO8CgMrHWYi441S2tsFebLRL5RWL0e0nL64SEEf2sjMR4ZZwA0Ddfziclz1eN8yDn1qAaHSq3G0FEQXjABDo51sJVNyGnA0QlAPL4LOApzMo0mY1sUFbQBj8xTzYhKrROYF5VGIftR1uW3+3uiWU8XnBw7l3HIYVG/P/djYgMZoyrTJrci0n2qPZVnNFV913viW6btGzsXBT6aW3VKmsauVTFOc2DxpP5YJYLBBeCUixE71IlGBR2EF+6OugHbP12Ddoj29HgIPj+cxDiPDFGINzB8sKhLh0Ui4gOgDI8deb8FiwYxlteWhLHWTlmOzhkxLAObPIkFqS8+bbG5BdgWiAmJTwXdqZ7oysktzdKC/BWMWiAJNpyP0ZPTMItRy7fTi2RB4eDwLuIkpCma1gob/Dsw7zcKAMf3txiCot8c42ZCDPu3WAqRMJAGEk4cACaLzSZsFRhAE9QoAtXcwTX92XDT0sxTQXJYHdDJin0KfVN8PmzNvnOYBx5XNlik4giumihb7tJ60ezgNhgXuXgRNttxunZYAj7uzbL3nUA67rm5KJWrJCyTfIVwBMh3bTkD8TqFYp6uv8RwrgJpAZmHHScqv0qWeKT48NujhAuELekyYBdz9gXJQ53DvDh3tU62xTtN8bQhzzE9OccAK8wA2ez2k3cNtN7wM/RZs9M5NkNZoee0H2rmhLr8miPV9roAZtN1RHV/gDb7EoUtXKeXjYXUBN0oeFs8CbrtlhZRGPZSSZNyI9gA+TBFkelFNWxgEgCtG3wDiFqEr5Jz6y/U1DAM4QLxi2l7DNhl3w/epNTUFWGbXC7HrMQMz7WUbf8AaDQ46DYXuxLoJX6CFRzvuiPyJzCzgZIoKyqgKAx1yAGPQUWfa+GoDsqwDJNnHLF9juSz0i5VrpvqSwmsQul5dtyfrfX1zL3i0WdHHSjaKVjf0T5k7ABtxlEHbwxusgjydAY8N84BjvAx5GLfMqBW0VJEZ+pwKskQnbpnFHPzpwWo/bzkGvX51296+bu1v/+qL9usXT9rTJ07Bzh9k9HEPsxNhwhh6xLXKo3fXWf3iMkrBBz9nAbflbHm6ONxhXp8/NW26lkSleIEV9FBVI+o6ihjmffPDt+3v/+5Z+82vnsZw/fyercweB2d7wzA8mfuPEknpXTnHvQsoPd1v/aD8LODw+AxbAw/QjnEfv69u5kz6dtOiW2R6YmW7vd0C3qK94wcjf/zxZ1bRXfvqGT6U3f2G/Z6AesqotgJX477PNVmTmxfiwTSS5irqz2ybEHD6PzbMAk7lS/0BxgkTqPAUYBiAkQpTLLdKxe1D4Lbsp968uW1vXk+ZrnpsN7yL1TbmbvCl4GcPPPStZWyNcM9s++9y92ruZu2CT21q7lZ9KDcLuC3WbmGG42uA30EISOVkFynt1BBialOliF/wZHqGTa1tOfq8fbMHPL6N2iBPW2d7HfxZdWnreiN49UL0dfhLR6tBSVVwNo+TQ1U5IsHvQU4Dcry7bGNOix+SngVcwAhYpZjTQxaNMABLLLtUFEAMEwi4kk63fGDbLTcVm82ubd7hNylzEXCa6SPdz2Vf5iUobe0jAFIq8+JHT8CjGeUjHFOj5E7MIO4THxvOaHIcwu2IOKiznyg89BTEXi6WssO8B36vkLa33Pv7/QRbEtm21c/BtIm9Yb4ho19PDg4g09aeucySdpzq3BfVx6WQqh7MkLOSkHLf2olEKni4n7xznh0VH4jnAYdy6hfVSZTvUmF54f2cU9d9XmlhvUyTlbkxIT0BWtgH4wRRgPMy7EFbAwi8ojzbNyqtH/7coWxnUHyE+rmYjbs3NCnqdwIbbM/GZ4RZwDleVskO3viSBhWjSu2Pxj7JU4bsqrzTU5YZQ7xKu73Bb8bAbo+s28NStxEyb8e+K1UAKXhOVivK7x0RUANf3zEw/smJpsr37cad9RlhFnCbzQYwfN36I+5qwxgVwRA/vOHxlneeMiaux9lymN5tTTttkZN5mbZwCYsLM550taA+zJM5gsdHsGSdQTbngN7ZlC/JrRhXIcorRJvVcp2pnjzdy+0nnErOCbOAE5x8d4oVCy4xMSFGetjfgWJ3MQFHdomxZbUwwC4B84YlzBNojUEmxmqO1tVC4VcVopUzKuXK+XArUeDVTyq85wv7xKqHsel1dfIUkl8zUXcFm8eUH7IPjWcBp8J5mYxWcWmbclhlyEIAMJm2HbSwDCHZGD9IuR1UH4MhaZ4HOAIQIJOrIxfjxOFRUMNQq8wI9EH5WNVJdcEje22ofxs3K6PlQ+OZwA2ghrFSKhiEVSqh/5JJcfodKBnntLac7wb5CKLpAs+0RguYuAhoNh2CRV1dTVFhqWhRn/u+tOsMtTph6JhOkAWsQDz1K3NHeHyYBZyK70BG5oy3SyqGumoaAhr1Aiggnm8FzXr3cQWSq++p8seM10v6LW9Elgh5kyGINXMdi1xspw2LRHwqMjJTV2KdU9c2eQ1SkXDDHL2aYf2MprVp1dFrtcBlAWB/sNuxMoJIzEfRqhMk04qXfM0n8yVDaa/DRLp1GuGSKhNz65ZEOQUSdyD0Y/adRSojsxjoz2jnNFdN3l/S+sUvnqbDsx+zgCvQMJzhPaCrlouCLBvbA43x68DhsAc7DxpTr0y39VAMBCfpSlpSUMggzRe8X4bIAWRYJqVJj6t7feMV/9Bkfeb+bYw2Czg78S3GwWtEQEPRWFMMEDAZhVTiMaWLnZZRxSexfaStPR9DAXbMj5Qs479Dm8PqqYCNEpUTVAe/GpLC3vH16hI64zkLuB1XQVsdFkED8ps40oLjj2sMAdbFwGlKRjbW6UHAFZaRJVegIpeWVafZhQ4yHahUm+5VyfOwXYFHTX8DKUNSn+fCcsN3qOd8AT3GGPEs4EYnxho9YlOnU1WTUj98GbLKWCawI5wk71DiBMoh+qjYfgXUc+nNlW+rXuqjOrknPAs4sRoHcvvNguDZNEChYOoBUUZ175z9nMBZnQ6cnncgS7uDnt3BJ49Y8axqPYLZ0gVEb2DaICyHtOUM5t2eP7AJexWaGWYBVzcdsqneoAAViyzzo3ZsC1Jeq2qBKVhlkIxDsuSRrSY6/6S6eaaFjD+B4BGmMo9X9M06kcAdMq0qU5eT+lBBc8+GqaVmCc989iHP6yVvOcr4qE8ZLijVZ8VleC/5xWDWFmN6ow6aIKX75EfdL5rfKxBJgAcwwV/zeXrFjyqqo3uy52dnMa5oU4O7svo7YMNgWrFKdsk6WBXmmS82HuKsuADjHZFGi5iBIv+9qnn/qt+qSh3JTFNjPvWDiqpnA0SexYB/ijm6q5qP85wFnIZrXQHgillpVesHh9QVaAWWAJccfo/VNrOcbmrbYn/vCR9gy2m1aUH2WOa/rv4UoKnhPODowC2Gx6jQo4Nox4ZinDL392ssIHFSZWa1rTZJD/wSy0Kn34eDpwZvP1w96+dmH25zrsQs4KSLP4GAawWSjhnFZZQFmUZxOZSTj/ne2yUhIHCjRIlFKcIU0x852RjZTGGlDdaQrkxk7MPrJr/gzg17r4vgJ3rMAk4/wmQDE7wJhg+fFV1xaMGiMqnXaFc5jd4FjCCIRAEmAO5aPE7lzsw0ZelHYJB0PCWscErqOJcsrbllGmhmzE/7mAXcPof544Wlqg6wTuORtvKQzjV2gVC+shaNMhc24v8iIloGmS3ogc7bD9sS884Oi0kEP89jFnDX++/hCtPVtT7kwaxOkZpmxQ/L9vgdj1r+NCtAwQ6/A9DXMXnBqZgoHDdXP7Wna/Id6PRCum7DiREqcg1UPw9Yp6MsLv/HwlM4Hp7WQ1/CGQhcgDsDNJtcgLsAdyYCZza7MO4C3JkInNnswrgLcGcicGazC+POBO7/AH5zPa/ivytzAAAAAElFTkSuQmCC", ), ] ) @@ -191,7 +193,9 @@ def test_invoke_chat_model_with_vision(): data="What is this in this picture?", ), ImagePromptMessageContent( - data="" + mime_type="image/png", + format="png", + base64_data="iVBORw0KGgoAAAANSUhEUgAAAE4AAABMCAYAAADDYoEWAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBoAQSkhN4EkRpASggt9I4gKiEJEEqMgaBiRxcVXLuIgA1dFVGwAmJBETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMmbllAFA7zhGJclF1APKEBeLYYH/6uOQUOukpIAEdoAy0gA2Hmy9iRkeHA1iG2r+Xd9cBIm2v2Eu1/tn/X4sGj5/PBQCJhjidl8/Ng/gAAHg1VyQuAIAo5c2mFoikGFagJYYBQrxIijPluFqK0+V4j8wmPpYFcTsASiocjjgTANVLkKcXcjOhhmo/xI5CnkAIgBodYp+8vMk8iNMgtoY2Ioil+oz0H3Qy/6aZPqzJ4WQOY/lcZEUpQJAvyuVM/z/T8b9LXq5kyIclrCpZ4pBY6Zxh3m7mTA6TYhWI+4TpkVEQa0L8QcCT2UOMUrIkIQlye9SAm8+COYMrDVBHHicgDGIDiIOEuZHhCj49QxDEhhjuEHSaoIAdD7EuxIv4+YFxCptN4smxCl9oY4aYxVTwZzlimV+pr/uSnASmQv91Fp+t0MdUi7LikyCmQGxeKEiMhFgVYof8nLgwhc3YoixW5JCNWBIrjd8c4li+MNhfro8VZoiDYhX2pXn5Q/PFNmUJ2JEKvK8gKz5Enh+sncuRxQ/ngl3iC5kJQzr8/HHhQ3Ph8QMC5XPHnvGFCXEKnQ+iAv9Y+VicIsqNVtjjpvzcYClvCrFLfmGcYiyeWAA3pFwfzxAVRMfL48SLsjmh0fJ48OUgHLBAAKADCazpYDLIBoLOvqY+eCfvCQIcIAaZgA/sFczQiCRZjxBe40AR+BMiPsgfHucv6+WDQsh/HWblV3uQIestlI3IAU8gzgNhIBfeS2SjhMPeEsFjyAj+4Z0DKxfGmwurtP/f80Psd4YJmXAFIxnySFcbsiQGEgOIIcQgog2uj/vgXng4vPrB6oQzcI+heXy3JzwhdBEeEq4Rugm3JgmKxT9FGQG6oX6QIhfpP+YCt4Sarrg/7g3VoTKug+sDe9wF+mHivtCzK2RZirilWaH/pP23GfywGgo7siMZJY8g+5Gtfx6paqvqOqwizfWP+ZHHmj6cb9Zwz8/+WT9knwfbsJ8tsUXYfuwMdgI7hx3BmgAda8WasQ7sqBQP767Hst015C1WFk8O1BH8w9/Qykozme9Y59jr+EXeV8CfJn1HA9Zk0XSxIDOrgM6EXwQ+nS3kOoyiOzk6OQMg/b7IX19vYmTfDUSn4zs3/w8AvFsHBwcPf+dCWwHY6w4f/0PfOWsG/HQoA3D2EFciLpRzuPRCgG8JNfik6QEjYAas4XycgBvwAn4gEISCKBAPksFEGH0W3OdiMBXMBPNACSgDy8EaUAk2gi1gB9gN9oEmcAScAKfBBXAJXAN34O7pAS9AP3gHPiMIQkKoCA3RQ4wRC8QOcUIYiA8SiIQjsUgykoZkIkJEgsxE5iNlyEqkEtmM1CJ7kUPICeQc0oXcQh4gvchr5BOKoSqoFmqIWqKjUQbKRMPQeHQCmolOQYvQBehStAKtQXehjegJ9AJ6De1GX6ADGMCUMR3MBLPHGBgLi8JSsAxMjM3GSrFyrAarx1rgOl/BurE+7CNOxGk4HbeHOzgET8C5+BR8Nr4Er8R34I14O34Ff4D3498IVIIBwY7gSWATxhEyCVMJJYRywjbCQcIp+Cz1EN4RiUQdohXRHT6LycRs4gziEuJ6YgPxOLGL+Ig4QCKR9Eh2JG9SFIlDKiCVkNaRdpFaSZdJPaQPSspKxkpOSkFKKUpCpWKlcqWdSseULis9VfpMVidbkD3JUWQeeTp5GXkruYV8kdxD/kzRoFhRvCnxlGzKPEoFpZ5yinKX8kZZWdlU2UM5RlmgPFe5QnmP8lnlB8ofVTRVbFVYKqkqEpWlKttVjqvcUnlDpVItqX7UFGoBdSm1lnqSep/6QZWm6qDKVuWpzlGtUm1Uvaz6Uo2sZqHGVJuoVqRWrrZf7aJanzpZ3VKdpc5Rn61epX5I/Yb6gAZNY4xGlEaexhKNnRrnNJ5pkjQtNQM1eZoLNLdontR8RMNoZjQWjUubT9tKO0Xr0SJqWWmxtbK1yrR2a3Vq9WtrartoJ2pP067SPqrdrYPpWOqwdXJ1luns07mu82mE4QjmCP6IxSPqR1we8V53pK6fLl+3VLdB95ruJz26XqBejt4KvSa9e/q4vq1+jP5U/Q36p/T7RmqN9BrJHVk6ct/I2waoga1BrMEMgy0GHQYDhkaGwYYiw3WGJw37jHSM/IyyjVYbHTPqNaYZ+xgLjFcbtxo/p2vTmfRcegW9nd5vYmASYiIx2WzSafLZ1Mo0wbTYtMH0nhnFjGGWYbbarM2s39zYPMJ8pnmd+W0LsgXDIstircUZi/eWVpZJlgstmyyfWelasa2KrOqs7lpTrX2tp1jXWF+1IdowbHJs1ttcskVtXW2zbKtsL9qhdm52Arv1dl2jCKM8RglH1Yy6Ya9iz7QvtK+zf+Cg4xDuUOzQ5PBytPnolNErRp8Z/c3R1THXcavjnTGaY0LHFI9pGfPaydaJ61TldNWZ6hzkPMe52fmVi50L32WDy01XmmuE60LXNtevbu5uYrd6t153c/c092r3GwwtRjRjCeOsB8HD32OOxxGPj55ungWe+zz/8rL3yvHa6fVsrNVY/titYx95m3pzvDd7d/vQfdJ8Nvl0+5r4cnxrfB/6mfnx/Lb5PWXaMLOZu5gv/R39xf4H/d+zPFmzWMcDsIDggNKAzkDNwITAysD7QaZBmUF1Qf3BrsEzgo+HEELCQlaE3GAbsrnsWnZ/qHvorND2MJWwuLDKsIfhtuHi8JYINCI0YlXE3UiLSGFkUxSIYketiroXbRU9JfpwDDEmOqYq5knsmNiZsWfiaHGT4nbGvYv3j18WfyfBOkGS0JaolpiaWJv4PikgaWVS97jR42aNu5CsnyxIbk4hpSSmbEsZGB84fs34nlTX1JLU6xOsJkybcG6i/sTciUcnqU3iTNqfRkhLStuZ9oUTxanhDKSz06vT+7ks7lruC54fbzWvl+/NX8l/muGdsTLjWaZ35qrM3izfrPKsPgFLUCl4lR2SvTH7fU5Uzvacwdyk3IY8pby0vENCTWGOsH2y0eRpk7tEdqISUfcUzylrpvSLw8Tb8pH8CfnNBVrwR75DYi35RfKg0KewqvDD1MSp+6dpTBNO65huO33x9KdFQUW/zcBncGe0zTSZOW/mg1nMWZtnI7PTZ7fNMZuzYE7P3OC5O+ZR5uXM+73YsXhl8dv5SfNbFhgumLvg0S/Bv9SVqJaIS24s9Fq4cRG+SLCoc7Hz4nWLv5XySs+XOZaVl31Zwl1y/tcxv1b8Org0Y2nnMrdlG5YTlwuXX1/hu2LHSo2VRSsfrYpY1biavrp09ds1k9acK3cp37iWslaytrsivKJ5nfm65eu+VGZVXqvyr2qoNqheXP1+PW/95Q1+G+o3Gm4s2/hpk2DTzc3BmxtrLGvKtxC3FG55sjVx65nfGL/VbtPfVrbt63bh9u4dsTvaa91ra3ca7FxWh9ZJ6np3pe66tDtgd3O9ff3mBp2Gsj1gj2TP871pe6/vC9vXtp+xv/6AxYHqg7SDpY1I4/TG/qaspu7m5OauQ6GH2lq8Wg4edji8/YjJkaqj2keXHaMcW3BssLWodeC46HjficwTj9omtd05Oe7k1faY9s5TYafOng46ffIM80zrWe+zR855njt0nnG+6YLbhcYO146Dv7v+frDTrbPxovvF5ksel1q6xnYdu+x7+cSVgCunr7KvXrgWea3resL1mzdSb3Tf5N18div31qvbhbc/35l7l3C39J76vfL7Bvdr/rD5o6Hbrfvog4AHHQ/jHt55xH304nH+4y89C55Qn5Q/NX5a+8zp2ZHeoN5Lz8c/73khevG5r+RPjT+rX1q/PPCX318d/eP6e16JXw2+XvJG7832ty5v2waiB+6/y3v3+X3pB70POz4yPp75lPTp6eepX0hfKr7afG35Fvbt7mDe4KCII+bIfgUwWNGMDABebweAmgwADZ7PKOPl5z9ZQeRnVhkC/wnLz4iy4gZAPfx/j+mDfzc3ANizFR6/oL5aKgDRVADiPQDq7Dxch85qsnOltBDhOWBT5Nf0vHTwb4r8zPlD3D+3QKrqAn5u/wWdZ3xtG7qP3QAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAATqADAAQAAAABAAAATAAAAADhTXUdAAARnUlEQVR4Ae2c245bR3aGi4fulizFHgUzQAYIggBB5klymfeaZ8hDBYjvAiRxkMAGkowRWx7JktjcZL7vX1Uku62Burkl5YbV5q7Tqqq1/v3XqgMpL95tbvftEh6NwPLRLS4NgsAFuDOJcAHuAtyZCJzZ7MK4C3BnInBmswvjLsCdicCZzS6MOxO49Znt0uz3//CPbbv6srXFrq0W9Q6Wi0VbLPn4R8x/jSLiu3nrl8s9dcartlwtKdmTbm21XranN6v27Mm6XV8t25fP1+3Pn1+1r4if3Czbk+t9u1rR6f9jmAXc1P6sbaevQGbfdgGJeA8ke0AQsCYYgiYgPR1QyVO+3wvcMm2WO0G2PeWkX79btp839AG4//UjYC62gDsB2rI9f7pov3q2bX/9F1ftBWAufTufOcwCrnTtR90dOdHoNgCJeAbUkuM5TsWAW5W9gfkE83ZkUHg0oAyAwbm927a2ebVoP/xx2f7jD1uYuG9/89tF+/VXK1hq+88TZgG32O1g2r7tpRdBM8fUTM7pyR8SYddgxkJErUszHti7U44CpzyEo16syNtx+qgy+1og7RMetpev9+3rb3bt+c2u/ebFsv3uL1ftiqn+qcMs4HY7jNQpEfadNU5VqeHUTJkgUbaPDxRADdZ8jU9LHoJYnwLUtgWN4ObDC7Kdr8Hp7d9qMTW8gt23V1zyvPrD1H56e9t+99vr9uJLprBDfaIw69U4dQRCIw2JdVIjbUzecj+7qYyPpZHiAbDaJwsXyMhQEQ0pq6sAp7hMS2XGqykdA2iy4EUtF6v206ur9k/fbNo//+frtt2OaW/rjxtmAaeNGqihBY5xfVQzQEZfoSH0KHgkrbD/CX6vPIqlSTU61vVCovRSbEwbIS851vj23Q+tff3vu/bzu5I7tvs4qVnADTa5FCbNC86qCLN2E1MxKKroYB2pgSz2RLbbVcVkSJhOKxIDjGxn+nSuqes2JlKuG8fA/IzPXazbj68X7et/27UfX7GifORwOuSju47h/c3beKfRFO74CNA04YP0ZT2/YzERFGojc9pmDG47/wyDZwJjiX4wwJNer1dZPJbs5/xzK5Ppzp7SQZBszNy22U7tX7/dtFdvJrv8aGE2cDJLoPycBgHSgICJUQLo8nmUo6y7oH0S5Lu/FGhDQULCfIooATw3yyOQQ46eYVpYiaBMTFtAFPR307r9y3fbdvsRfd5Rg6HJI2Lt1qaAF6TEqoxWdVdYSHawezCvAHLjW7Jh2QGcUkDDT4Og2OfSFRVkxipcAJUZARC5FVRbeRpB1hVY6r25XQHexIZ96Hfa++PTs4Dbi8rQg7imWQG27/uEgCTCssk/WWg7GwJWwDQ36PceGzQ+x7jOtgNogkIIpsZiFMdXoEfOPUlh3l5ulu2/X6bJ7Mc84Bw+xgOKzJqM0VKm8WYlVMqt61gFKNtQKeZ6o7Ls/aqEeYooJXDIZ9uiT0uZ5UxPUJNlYdoAK62qHfM7unz3/bb9/Ha+v3u/tn3AD0XOrnxAZdpNYZILgoxyGk4BqMCbssq66dXv6RdFkiB6Rj2u3N1npiMw1dQjF4oJW/kzy6VdMRFA9Xd8VvhCLxCyYUYkvhHZb7+fotvdUR6XmwXcYI1DangAA6yspgBj/dRjp6L+RbmSPaaxuuMnGEeVAhBF4pSapAFG5gUo60rAHmpVtcz0sR2aBZW8NAB9+W7dXr9N0dmPmUcu10pWrq7kQQvBQXn1dUsgoM4ej12TtyBknG51PEMGOV2TLLVZ/GLvLMBYHsYJhg7fuMBx6tq3LFu7aBxxD9jKFiO7Thbwcv7n5dS+/ML0eWEWcBqoptk+mEQp2aTG+rbmBYA+D6MyMwMAdepKsX5QpnglFZyZ5k4tDYsI/Y1pF7CRq22HoHXgGEOwgodvgH79INnW3tlFIVVQvkBXg1dvF3z27fkTGzw+zALOPZluVoVkV4yLHoBB3VBJUNyo6uEWXAyIkruC2OQjbVeppxkm8+iti2mySsM1EPYGKBcEyul3LKTW1+pr+wLRstwP0J8a2K95Txf/+6q1ZzeUDEXt/oFhHnA4fJYCBtawYlWmlsrJBEHhP43bi9Rq1Z0ymlK3Z/QCRqA5YfaNLZJWEACn929eluXlUGO8CgMrHWYi441S2tsFebLRL5RWL0e0nL64SEEf2sjMR4ZZwA0Ddfziclz1eN8yDn1qAaHSq3G0FEQXjABDo51sJVNyGnA0QlAPL4LOApzMo0mY1sUFbQBj8xTzYhKrROYF5VGIftR1uW3+3uiWU8XnBw7l3HIYVG/P/djYgMZoyrTJrci0n2qPZVnNFV913viW6btGzsXBT6aW3VKmsauVTFOc2DxpP5YJYLBBeCUixE71IlGBR2EF+6OugHbP12Ddoj29HgIPj+cxDiPDFGINzB8sKhLh0Ui4gOgDI8deb8FiwYxlteWhLHWTlmOzhkxLAObPIkFqS8+bbG5BdgWiAmJTwXdqZ7oysktzdKC/BWMWiAJNpyP0ZPTMItRy7fTi2RB4eDwLuIkpCma1gob/Dsw7zcKAMf3txiCot8c42ZCDPu3WAqRMJAGEk4cACaLzSZsFRhAE9QoAtXcwTX92XDT0sxTQXJYHdDJin0KfVN8PmzNvnOYBx5XNlik4giumihb7tJ60ezgNhgXuXgRNttxunZYAj7uzbL3nUA67rm5KJWrJCyTfIVwBMh3bTkD8TqFYp6uv8RwrgJpAZmHHScqv0qWeKT48NujhAuELekyYBdz9gXJQ53DvDh3tU62xTtN8bQhzzE9OccAK8wA2ez2k3cNtN7wM/RZs9M5NkNZoee0H2rmhLr8miPV9roAZtN1RHV/gDb7EoUtXKeXjYXUBN0oeFs8CbrtlhZRGPZSSZNyI9gA+TBFkelFNWxgEgCtG3wDiFqEr5Jz6y/U1DAM4QLxi2l7DNhl3w/epNTUFWGbXC7HrMQMz7WUbf8AaDQ46DYXuxLoJX6CFRzvuiPyJzCzgZIoKyqgKAx1yAGPQUWfa+GoDsqwDJNnHLF9juSz0i5VrpvqSwmsQul5dtyfrfX1zL3i0WdHHSjaKVjf0T5k7ABtxlEHbwxusgjydAY8N84BjvAx5GLfMqBW0VJEZ+pwKskQnbpnFHPzpwWo/bzkGvX51296+bu1v/+qL9usXT9rTJ07Bzh9k9HEPsxNhwhh6xLXKo3fXWf3iMkrBBz9nAbflbHm6ONxhXp8/NW26lkSleIEV9FBVI+o6ihjmffPDt+3v/+5Z+82vnsZw/fyercweB2d7wzA8mfuPEknpXTnHvQsoPd1v/aD8LODw+AxbAw/QjnEfv69u5kz6dtOiW2R6YmW7vd0C3qK94wcjf/zxZ1bRXfvqGT6U3f2G/Z6AesqotgJX477PNVmTmxfiwTSS5irqz2ybEHD6PzbMAk7lS/0BxgkTqPAUYBiAkQpTLLdKxe1D4Lbsp968uW1vXk+ZrnpsN7yL1TbmbvCl4GcPPPStZWyNcM9s++9y92ruZu2CT21q7lZ9KDcLuC3WbmGG42uA30EISOVkFynt1BBialOliF/wZHqGTa1tOfq8fbMHPL6N2iBPW2d7HfxZdWnreiN49UL0dfhLR6tBSVVwNo+TQ1U5IsHvQU4Dcry7bGNOix+SngVcwAhYpZjTQxaNMABLLLtUFEAMEwi4kk63fGDbLTcVm82ubd7hNylzEXCa6SPdz2Vf5iUobe0jAFIq8+JHT8CjGeUjHFOj5E7MIO4THxvOaHIcwu2IOKiznyg89BTEXi6WssO8B36vkLa33Pv7/QRbEtm21c/BtIm9Yb4ho19PDg4g09aeucySdpzq3BfVx6WQqh7MkLOSkHLf2olEKni4n7xznh0VH4jnAYdy6hfVSZTvUmF54f2cU9d9XmlhvUyTlbkxIT0BWtgH4wRRgPMy7EFbAwi8ojzbNyqtH/7coWxnUHyE+rmYjbs3NCnqdwIbbM/GZ4RZwDleVskO3viSBhWjSu2Pxj7JU4bsqrzTU5YZQ7xKu73Bb8bAbo+s28NStxEyb8e+K1UAKXhOVivK7x0RUANf3zEw/smJpsr37cad9RlhFnCbzQYwfN36I+5qwxgVwRA/vOHxlneeMiaux9lymN5tTTttkZN5mbZwCYsLM550taA+zJM5gsdHsGSdQTbngN7ZlC/JrRhXIcorRJvVcp2pnjzdy+0nnErOCbOAE5x8d4oVCy4xMSFGetjfgWJ3MQFHdomxZbUwwC4B84YlzBNojUEmxmqO1tVC4VcVopUzKuXK+XArUeDVTyq85wv7xKqHsel1dfIUkl8zUXcFm8eUH7IPjWcBp8J5mYxWcWmbclhlyEIAMJm2HbSwDCHZGD9IuR1UH4MhaZ4HOAIQIJOrIxfjxOFRUMNQq8wI9EH5WNVJdcEje22ofxs3K6PlQ+OZwA2ghrFSKhiEVSqh/5JJcfodKBnntLac7wb5CKLpAs+0RguYuAhoNh2CRV1dTVFhqWhRn/u+tOsMtTph6JhOkAWsQDz1K3NHeHyYBZyK70BG5oy3SyqGumoaAhr1Aiggnm8FzXr3cQWSq++p8seM10v6LW9Elgh5kyGINXMdi1xspw2LRHwqMjJTV2KdU9c2eQ1SkXDDHL2aYf2MprVp1dFrtcBlAWB/sNuxMoJIzEfRqhMk04qXfM0n8yVDaa/DRLp1GuGSKhNz65ZEOQUSdyD0Y/adRSojsxjoz2jnNFdN3l/S+sUvnqbDsx+zgCvQMJzhPaCrlouCLBvbA43x68DhsAc7DxpTr0y39VAMBCfpSlpSUMggzRe8X4bIAWRYJqVJj6t7feMV/9Bkfeb+bYw2Czg78S3GwWtEQEPRWFMMEDAZhVTiMaWLnZZRxSexfaStPR9DAXbMj5Qs479Dm8PqqYCNEpUTVAe/GpLC3vH16hI64zkLuB1XQVsdFkED8ps40oLjj2sMAdbFwGlKRjbW6UHAFZaRJVegIpeWVafZhQ4yHahUm+5VyfOwXYFHTX8DKUNSn+fCcsN3qOd8AT3GGPEs4EYnxho9YlOnU1WTUj98GbLKWCawI5wk71DiBMoh+qjYfgXUc+nNlW+rXuqjOrknPAs4sRoHcvvNguDZNEChYOoBUUZ175z9nMBZnQ6cnncgS7uDnt3BJ49Y8axqPYLZ0gVEb2DaICyHtOUM5t2eP7AJexWaGWYBVzcdsqneoAAViyzzo3ZsC1Jeq2qBKVhlkIxDsuSRrSY6/6S6eaaFjD+B4BGmMo9X9M06kcAdMq0qU5eT+lBBc8+GqaVmCc989iHP6yVvOcr4qE8ZLijVZ8VleC/5xWDWFmN6ow6aIKX75EfdL5rfKxBJgAcwwV/zeXrFjyqqo3uy52dnMa5oU4O7svo7YMNgWrFKdsk6WBXmmS82HuKsuADjHZFGi5iBIv+9qnn/qt+qSh3JTFNjPvWDiqpnA0SexYB/ijm6q5qP85wFnIZrXQHgillpVesHh9QVaAWWAJccfo/VNrOcbmrbYn/vCR9gy2m1aUH2WOa/rv4UoKnhPODowC2Gx6jQo4Nox4ZinDL392ssIHFSZWa1rTZJD/wSy0Kn34eDpwZvP1w96+dmH25zrsQs4KSLP4GAawWSjhnFZZQFmUZxOZSTj/ne2yUhIHCjRIlFKcIU0x852RjZTGGlDdaQrkxk7MPrJr/gzg17r4vgJ3rMAk4/wmQDE7wJhg+fFV1xaMGiMqnXaFc5jd4FjCCIRAEmAO5aPE7lzsw0ZelHYJB0PCWscErqOJcsrbllGmhmzE/7mAXcPof544Wlqg6wTuORtvKQzjV2gVC+shaNMhc24v8iIloGmS3ogc7bD9sS884Oi0kEP89jFnDX++/hCtPVtT7kwaxOkZpmxQ/L9vgdj1r+NCtAwQ6/A9DXMXnBqZgoHDdXP7Wna/Id6PRCum7DiREqcg1UPw9Yp6MsLv/HwlM4Hp7WQ1/CGQhcgDsDNJtcgLsAdyYCZza7MO4C3JkInNnswrgLcGcicGazC+POBO7/AH5zPa/ivytzAAAAAElFTkSuQmCC", ), ] ) diff --git a/api/tests/integration_tests/model_runtime/openai/test_llm.py b/api/tests/integration_tests/model_runtime/openai/test_llm.py index 41c99f68756eb5..9e83b9d434359d 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openai/test_llm.py @@ -139,7 +139,9 @@ def test_invoke_chat_model_with_vision(setup_openai_mock): data="Hello World!", ), ImagePromptMessageContent( - data="" + mime_type="image/png", + format="png", + base64_data="iVBORw0KGgoAAAANSUhEUgAAAE4AAABMCAYAAADDYoEWAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBoAQSkhN4EkRpASggt9I4gKiEJEEqMgaBiRxcVXLuIgA1dFVGwAmJBETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMmbllAFA7zhGJclF1APKEBeLYYH/6uOQUOukpIAEdoAy0gA2Hmy9iRkeHA1iG2r+Xd9cBIm2v2Eu1/tn/X4sGj5/PBQCJhjidl8/Ng/gAAHg1VyQuAIAo5c2mFoikGFagJYYBQrxIijPluFqK0+V4j8wmPpYFcTsASiocjjgTANVLkKcXcjOhhmo/xI5CnkAIgBodYp+8vMk8iNMgtoY2Ioil+oz0H3Qy/6aZPqzJ4WQOY/lcZEUpQJAvyuVM/z/T8b9LXq5kyIclrCpZ4pBY6Zxh3m7mTA6TYhWI+4TpkVEQa0L8QcCT2UOMUrIkIQlye9SAm8+COYMrDVBHHicgDGIDiIOEuZHhCj49QxDEhhjuEHSaoIAdD7EuxIv4+YFxCptN4smxCl9oY4aYxVTwZzlimV+pr/uSnASmQv91Fp+t0MdUi7LikyCmQGxeKEiMhFgVYof8nLgwhc3YoixW5JCNWBIrjd8c4li+MNhfro8VZoiDYhX2pXn5Q/PFNmUJ2JEKvK8gKz5Enh+sncuRxQ/ngl3iC5kJQzr8/HHhQ3Ph8QMC5XPHnvGFCXEKnQ+iAv9Y+VicIsqNVtjjpvzcYClvCrFLfmGcYiyeWAA3pFwfzxAVRMfL48SLsjmh0fJ48OUgHLBAAKADCazpYDLIBoLOvqY+eCfvCQIcIAaZgA/sFczQiCRZjxBe40AR+BMiPsgfHucv6+WDQsh/HWblV3uQIestlI3IAU8gzgNhIBfeS2SjhMPeEsFjyAj+4Z0DKxfGmwurtP/f80Psd4YJmXAFIxnySFcbsiQGEgOIIcQgog2uj/vgXng4vPrB6oQzcI+heXy3JzwhdBEeEq4Rugm3JgmKxT9FGQG6oX6QIhfpP+YCt4Sarrg/7g3VoTKug+sDe9wF+mHivtCzK2RZirilWaH/pP23GfywGgo7siMZJY8g+5Gtfx6paqvqOqwizfWP+ZHHmj6cb9Zwz8/+WT9knwfbsJ8tsUXYfuwMdgI7hx3BmgAda8WasQ7sqBQP767Hst015C1WFk8O1BH8w9/Qykozme9Y59jr+EXeV8CfJn1HA9Zk0XSxIDOrgM6EXwQ+nS3kOoyiOzk6OQMg/b7IX19vYmTfDUSn4zs3/w8AvFsHBwcPf+dCWwHY6w4f/0PfOWsG/HQoA3D2EFciLpRzuPRCgG8JNfik6QEjYAas4XycgBvwAn4gEISCKBAPksFEGH0W3OdiMBXMBPNACSgDy8EaUAk2gi1gB9gN9oEmcAScAKfBBXAJXAN34O7pAS9AP3gHPiMIQkKoCA3RQ4wRC8QOcUIYiA8SiIQjsUgykoZkIkJEgsxE5iNlyEqkEtmM1CJ7kUPICeQc0oXcQh4gvchr5BOKoSqoFmqIWqKjUQbKRMPQeHQCmolOQYvQBehStAKtQXehjegJ9AJ6De1GX6ADGMCUMR3MBLPHGBgLi8JSsAxMjM3GSrFyrAarx1rgOl/BurE+7CNOxGk4HbeHOzgET8C5+BR8Nr4Er8R34I14O34Ff4D3498IVIIBwY7gSWATxhEyCVMJJYRywjbCQcIp+Cz1EN4RiUQdohXRHT6LycRs4gziEuJ6YgPxOLGL+Ig4QCKR9Eh2JG9SFIlDKiCVkNaRdpFaSZdJPaQPSspKxkpOSkFKKUpCpWKlcqWdSseULis9VfpMVidbkD3JUWQeeTp5GXkruYV8kdxD/kzRoFhRvCnxlGzKPEoFpZ5yinKX8kZZWdlU2UM5RlmgPFe5QnmP8lnlB8ofVTRVbFVYKqkqEpWlKttVjqvcUnlDpVItqX7UFGoBdSm1lnqSep/6QZWm6qDKVuWpzlGtUm1Uvaz6Uo2sZqHGVJuoVqRWrrZf7aJanzpZ3VKdpc5Rn61epX5I/Yb6gAZNY4xGlEaexhKNnRrnNJ5pkjQtNQM1eZoLNLdontR8RMNoZjQWjUubT9tKO0Xr0SJqWWmxtbK1yrR2a3Vq9WtrartoJ2pP067SPqrdrYPpWOqwdXJ1luns07mu82mE4QjmCP6IxSPqR1we8V53pK6fLl+3VLdB95ruJz26XqBejt4KvSa9e/q4vq1+jP5U/Q36p/T7RmqN9BrJHVk6ct/I2waoga1BrMEMgy0GHQYDhkaGwYYiw3WGJw37jHSM/IyyjVYbHTPqNaYZ+xgLjFcbtxo/p2vTmfRcegW9nd5vYmASYiIx2WzSafLZ1Mo0wbTYtMH0nhnFjGGWYbbarM2s39zYPMJ8pnmd+W0LsgXDIstircUZi/eWVpZJlgstmyyfWelasa2KrOqs7lpTrX2tp1jXWF+1IdowbHJs1ttcskVtXW2zbKtsL9qhdm52Arv1dl2jCKM8RglH1Yy6Ya9iz7QvtK+zf+Cg4xDuUOzQ5PBytPnolNErRp8Z/c3R1THXcavjnTGaY0LHFI9pGfPaydaJ61TldNWZ6hzkPMe52fmVi50L32WDy01XmmuE60LXNtevbu5uYrd6t153c/c092r3GwwtRjRjCeOsB8HD32OOxxGPj55ungWe+zz/8rL3yvHa6fVsrNVY/titYx95m3pzvDd7d/vQfdJ8Nvl0+5r4cnxrfB/6mfnx/Lb5PWXaMLOZu5gv/R39xf4H/d+zPFmzWMcDsIDggNKAzkDNwITAysD7QaZBmUF1Qf3BrsEzgo+HEELCQlaE3GAbsrnsWnZ/qHvorND2MJWwuLDKsIfhtuHi8JYINCI0YlXE3UiLSGFkUxSIYketiroXbRU9JfpwDDEmOqYq5knsmNiZsWfiaHGT4nbGvYv3j18WfyfBOkGS0JaolpiaWJv4PikgaWVS97jR42aNu5CsnyxIbk4hpSSmbEsZGB84fs34nlTX1JLU6xOsJkybcG6i/sTciUcnqU3iTNqfRkhLStuZ9oUTxanhDKSz06vT+7ks7lruC54fbzWvl+/NX8l/muGdsTLjWaZ35qrM3izfrPKsPgFLUCl4lR2SvTH7fU5Uzvacwdyk3IY8pby0vENCTWGOsH2y0eRpk7tEdqISUfcUzylrpvSLw8Tb8pH8CfnNBVrwR75DYi35RfKg0KewqvDD1MSp+6dpTBNO65huO33x9KdFQUW/zcBncGe0zTSZOW/mg1nMWZtnI7PTZ7fNMZuzYE7P3OC5O+ZR5uXM+73YsXhl8dv5SfNbFhgumLvg0S/Bv9SVqJaIS24s9Fq4cRG+SLCoc7Hz4nWLv5XySs+XOZaVl31Zwl1y/tcxv1b8Org0Y2nnMrdlG5YTlwuXX1/hu2LHSo2VRSsfrYpY1biavrp09ds1k9acK3cp37iWslaytrsivKJ5nfm65eu+VGZVXqvyr2qoNqheXP1+PW/95Q1+G+o3Gm4s2/hpk2DTzc3BmxtrLGvKtxC3FG55sjVx65nfGL/VbtPfVrbt63bh9u4dsTvaa91ra3ca7FxWh9ZJ6np3pe66tDtgd3O9ff3mBp2Gsj1gj2TP871pe6/vC9vXtp+xv/6AxYHqg7SDpY1I4/TG/qaspu7m5OauQ6GH2lq8Wg4edji8/YjJkaqj2keXHaMcW3BssLWodeC46HjficwTj9omtd05Oe7k1faY9s5TYafOng46ffIM80zrWe+zR855njt0nnG+6YLbhcYO146Dv7v+frDTrbPxovvF5ksel1q6xnYdu+x7+cSVgCunr7KvXrgWea3resL1mzdSb3Tf5N18div31qvbhbc/35l7l3C39J76vfL7Bvdr/rD5o6Hbrfvog4AHHQ/jHt55xH304nH+4y89C55Qn5Q/NX5a+8zp2ZHeoN5Lz8c/73khevG5r+RPjT+rX1q/PPCX318d/eP6e16JXw2+XvJG7832ty5v2waiB+6/y3v3+X3pB70POz4yPp75lPTp6eepX0hfKr7afG35Fvbt7mDe4KCII+bIfgUwWNGMDABebweAmgwADZ7PKOPl5z9ZQeRnVhkC/wnLz4iy4gZAPfx/j+mDfzc3ANizFR6/oL5aKgDRVADiPQDq7Dxch85qsnOltBDhOWBT5Nf0vHTwb4r8zPlD3D+3QKrqAn5u/wWdZ3xtG7qP3QAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAATqADAAQAAAABAAAATAAAAADhTXUdAAARnUlEQVR4Ae2c245bR3aGi4fulizFHgUzQAYIggBB5klymfeaZ8hDBYjvAiRxkMAGkowRWx7JktjcZL7vX1Uku62Burkl5YbV5q7Tqqq1/v3XqgMpL95tbvftEh6NwPLRLS4NgsAFuDOJcAHuAtyZCJzZ7MK4C3BnInBmswvjLsCdicCZzS6MOxO49Znt0uz3//CPbbv6srXFrq0W9Q6Wi0VbLPn4R8x/jSLiu3nrl8s9dcartlwtKdmTbm21XranN6v27Mm6XV8t25fP1+3Pn1+1r4if3Czbk+t9u1rR6f9jmAXc1P6sbaevQGbfdgGJeA8ke0AQsCYYgiYgPR1QyVO+3wvcMm2WO0G2PeWkX79btp839AG4//UjYC62gDsB2rI9f7pov3q2bX/9F1ftBWAufTufOcwCrnTtR90dOdHoNgCJeAbUkuM5TsWAW5W9gfkE83ZkUHg0oAyAwbm927a2ebVoP/xx2f7jD1uYuG9/89tF+/VXK1hq+88TZgG32O1g2r7tpRdBM8fUTM7pyR8SYddgxkJErUszHti7U44CpzyEo16syNtx+qgy+1og7RMetpev9+3rb3bt+c2u/ebFsv3uL1ftiqn+qcMs4HY7jNQpEfadNU5VqeHUTJkgUbaPDxRADdZ8jU9LHoJYnwLUtgWN4ObDC7Kdr8Hp7d9qMTW8gt23V1zyvPrD1H56e9t+99vr9uJLprBDfaIw69U4dQRCIw2JdVIjbUzecj+7qYyPpZHiAbDaJwsXyMhQEQ0pq6sAp7hMS2XGqykdA2iy4EUtF6v206ur9k/fbNo//+frtt2OaW/rjxtmAaeNGqihBY5xfVQzQEZfoSH0KHgkrbD/CX6vPIqlSTU61vVCovRSbEwbIS851vj23Q+tff3vu/bzu5I7tvs4qVnADTa5FCbNC86qCLN2E1MxKKroYB2pgSz2RLbbVcVkSJhOKxIDjGxn+nSuqes2JlKuG8fA/IzPXazbj68X7et/27UfX7GifORwOuSju47h/c3beKfRFO74CNA04YP0ZT2/YzERFGojc9pmDG47/wyDZwJjiX4wwJNer1dZPJbs5/xzK5Ppzp7SQZBszNy22U7tX7/dtFdvJrv8aGE2cDJLoPycBgHSgICJUQLo8nmUo6y7oH0S5Lu/FGhDQULCfIooATw3yyOQQ46eYVpYiaBMTFtAFPR307r9y3fbdvsRfd5Rg6HJI2Lt1qaAF6TEqoxWdVdYSHawezCvAHLjW7Jh2QGcUkDDT4Og2OfSFRVkxipcAJUZARC5FVRbeRpB1hVY6r25XQHexIZ96Hfa++PTs4Dbi8rQg7imWQG27/uEgCTCssk/WWg7GwJWwDQ36PceGzQ+x7jOtgNogkIIpsZiFMdXoEfOPUlh3l5ulu2/X6bJ7Mc84Bw+xgOKzJqM0VKm8WYlVMqt61gFKNtQKeZ6o7Ls/aqEeYooJXDIZ9uiT0uZ5UxPUJNlYdoAK62qHfM7unz3/bb9/Ha+v3u/tn3AD0XOrnxAZdpNYZILgoxyGk4BqMCbssq66dXv6RdFkiB6Rj2u3N1npiMw1dQjF4oJW/kzy6VdMRFA9Xd8VvhCLxCyYUYkvhHZb7+fotvdUR6XmwXcYI1DangAA6yspgBj/dRjp6L+RbmSPaaxuuMnGEeVAhBF4pSapAFG5gUo60rAHmpVtcz0sR2aBZW8NAB9+W7dXr9N0dmPmUcu10pWrq7kQQvBQXn1dUsgoM4ej12TtyBknG51PEMGOV2TLLVZ/GLvLMBYHsYJhg7fuMBx6tq3LFu7aBxxD9jKFiO7Thbwcv7n5dS+/ML0eWEWcBqoptk+mEQp2aTG+rbmBYA+D6MyMwMAdepKsX5QpnglFZyZ5k4tDYsI/Y1pF7CRq22HoHXgGEOwgodvgH79INnW3tlFIVVQvkBXg1dvF3z27fkTGzw+zALOPZluVoVkV4yLHoBB3VBJUNyo6uEWXAyIkruC2OQjbVeppxkm8+iti2mySsM1EPYGKBcEyul3LKTW1+pr+wLRstwP0J8a2K95Txf/+6q1ZzeUDEXt/oFhHnA4fJYCBtawYlWmlsrJBEHhP43bi9Rq1Z0ymlK3Z/QCRqA5YfaNLZJWEACn929eluXlUGO8CgMrHWYi441S2tsFebLRL5RWL0e0nL64SEEf2sjMR4ZZwA0Ddfziclz1eN8yDn1qAaHSq3G0FEQXjABDo51sJVNyGnA0QlAPL4LOApzMo0mY1sUFbQBj8xTzYhKrROYF5VGIftR1uW3+3uiWU8XnBw7l3HIYVG/P/djYgMZoyrTJrci0n2qPZVnNFV913viW6btGzsXBT6aW3VKmsauVTFOc2DxpP5YJYLBBeCUixE71IlGBR2EF+6OugHbP12Ddoj29HgIPj+cxDiPDFGINzB8sKhLh0Ui4gOgDI8deb8FiwYxlteWhLHWTlmOzhkxLAObPIkFqS8+bbG5BdgWiAmJTwXdqZ7oysktzdKC/BWMWiAJNpyP0ZPTMItRy7fTi2RB4eDwLuIkpCma1gob/Dsw7zcKAMf3txiCot8c42ZCDPu3WAqRMJAGEk4cACaLzSZsFRhAE9QoAtXcwTX92XDT0sxTQXJYHdDJin0KfVN8PmzNvnOYBx5XNlik4giumihb7tJ60ezgNhgXuXgRNttxunZYAj7uzbL3nUA67rm5KJWrJCyTfIVwBMh3bTkD8TqFYp6uv8RwrgJpAZmHHScqv0qWeKT48NujhAuELekyYBdz9gXJQ53DvDh3tU62xTtN8bQhzzE9OccAK8wA2ez2k3cNtN7wM/RZs9M5NkNZoee0H2rmhLr8miPV9roAZtN1RHV/gDb7EoUtXKeXjYXUBN0oeFs8CbrtlhZRGPZSSZNyI9gA+TBFkelFNWxgEgCtG3wDiFqEr5Jz6y/U1DAM4QLxi2l7DNhl3w/epNTUFWGbXC7HrMQMz7WUbf8AaDQ46DYXuxLoJX6CFRzvuiPyJzCzgZIoKyqgKAx1yAGPQUWfa+GoDsqwDJNnHLF9juSz0i5VrpvqSwmsQul5dtyfrfX1zL3i0WdHHSjaKVjf0T5k7ABtxlEHbwxusgjydAY8N84BjvAx5GLfMqBW0VJEZ+pwKskQnbpnFHPzpwWo/bzkGvX51296+bu1v/+qL9usXT9rTJ07Bzh9k9HEPsxNhwhh6xLXKo3fXWf3iMkrBBz9nAbflbHm6ONxhXp8/NW26lkSleIEV9FBVI+o6ihjmffPDt+3v/+5Z+82vnsZw/fyercweB2d7wzA8mfuPEknpXTnHvQsoPd1v/aD8LODw+AxbAw/QjnEfv69u5kz6dtOiW2R6YmW7vd0C3qK94wcjf/zxZ1bRXfvqGT6U3f2G/Z6AesqotgJX477PNVmTmxfiwTSS5irqz2ybEHD6PzbMAk7lS/0BxgkTqPAUYBiAkQpTLLdKxe1D4Lbsp968uW1vXk+ZrnpsN7yL1TbmbvCl4GcPPPStZWyNcM9s++9y92ruZu2CT21q7lZ9KDcLuC3WbmGG42uA30EISOVkFynt1BBialOliF/wZHqGTa1tOfq8fbMHPL6N2iBPW2d7HfxZdWnreiN49UL0dfhLR6tBSVVwNo+TQ1U5IsHvQU4Dcry7bGNOix+SngVcwAhYpZjTQxaNMABLLLtUFEAMEwi4kk63fGDbLTcVm82ubd7hNylzEXCa6SPdz2Vf5iUobe0jAFIq8+JHT8CjGeUjHFOj5E7MIO4THxvOaHIcwu2IOKiznyg89BTEXi6WssO8B36vkLa33Pv7/QRbEtm21c/BtIm9Yb4ho19PDg4g09aeucySdpzq3BfVx6WQqh7MkLOSkHLf2olEKni4n7xznh0VH4jnAYdy6hfVSZTvUmF54f2cU9d9XmlhvUyTlbkxIT0BWtgH4wRRgPMy7EFbAwi8ojzbNyqtH/7coWxnUHyE+rmYjbs3NCnqdwIbbM/GZ4RZwDleVskO3viSBhWjSu2Pxj7JU4bsqrzTU5YZQ7xKu73Bb8bAbo+s28NStxEyb8e+K1UAKXhOVivK7x0RUANf3zEw/smJpsr37cad9RlhFnCbzQYwfN36I+5qwxgVwRA/vOHxlneeMiaux9lymN5tTTttkZN5mbZwCYsLM550taA+zJM5gsdHsGSdQTbngN7ZlC/JrRhXIcorRJvVcp2pnjzdy+0nnErOCbOAE5x8d4oVCy4xMSFGetjfgWJ3MQFHdomxZbUwwC4B84YlzBNojUEmxmqO1tVC4VcVopUzKuXK+XArUeDVTyq85wv7xKqHsel1dfIUkl8zUXcFm8eUH7IPjWcBp8J5mYxWcWmbclhlyEIAMJm2HbSwDCHZGD9IuR1UH4MhaZ4HOAIQIJOrIxfjxOFRUMNQq8wI9EH5WNVJdcEje22ofxs3K6PlQ+OZwA2ghrFSKhiEVSqh/5JJcfodKBnntLac7wb5CKLpAs+0RguYuAhoNh2CRV1dTVFhqWhRn/u+tOsMtTph6JhOkAWsQDz1K3NHeHyYBZyK70BG5oy3SyqGumoaAhr1Aiggnm8FzXr3cQWSq++p8seM10v6LW9Elgh5kyGINXMdi1xspw2LRHwqMjJTV2KdU9c2eQ1SkXDDHL2aYf2MprVp1dFrtcBlAWB/sNuxMoJIzEfRqhMk04qXfM0n8yVDaa/DRLp1GuGSKhNz65ZEOQUSdyD0Y/adRSojsxjoz2jnNFdN3l/S+sUvnqbDsx+zgCvQMJzhPaCrlouCLBvbA43x68DhsAc7DxpTr0y39VAMBCfpSlpSUMggzRe8X4bIAWRYJqVJj6t7feMV/9Bkfeb+bYw2Czg78S3GwWtEQEPRWFMMEDAZhVTiMaWLnZZRxSexfaStPR9DAXbMj5Qs479Dm8PqqYCNEpUTVAe/GpLC3vH16hI64zkLuB1XQVsdFkED8ps40oLjj2sMAdbFwGlKRjbW6UHAFZaRJVegIpeWVafZhQ4yHahUm+5VyfOwXYFHTX8DKUNSn+fCcsN3qOd8AT3GGPEs4EYnxho9YlOnU1WTUj98GbLKWCawI5wk71DiBMoh+qjYfgXUc+nNlW+rXuqjOrknPAs4sRoHcvvNguDZNEChYOoBUUZ175z9nMBZnQ6cnncgS7uDnt3BJ49Y8axqPYLZ0gVEb2DaICyHtOUM5t2eP7AJexWaGWYBVzcdsqneoAAViyzzo3ZsC1Jeq2qBKVhlkIxDsuSRrSY6/6S6eaaFjD+B4BGmMo9X9M06kcAdMq0qU5eT+lBBc8+GqaVmCc989iHP6yVvOcr4qE8ZLijVZ8VleC/5xWDWFmN6ow6aIKX75EfdL5rfKxBJgAcwwV/zeXrFjyqqo3uy52dnMa5oU4O7svo7YMNgWrFKdsk6WBXmmS82HuKsuADjHZFGi5iBIv+9qnn/qt+qSh3JTFNjPvWDiqpnA0SexYB/ijm6q5qP85wFnIZrXQHgillpVesHh9QVaAWWAJccfo/VNrOcbmrbYn/vCR9gy2m1aUH2WOa/rv4UoKnhPODowC2Gx6jQo4Nox4ZinDL392ssIHFSZWa1rTZJD/wSy0Kn34eDpwZvP1w96+dmH25zrsQs4KSLP4GAawWSjhnFZZQFmUZxOZSTj/ne2yUhIHCjRIlFKcIU0x852RjZTGGlDdaQrkxk7MPrJr/gzg17r4vgJ3rMAk4/wmQDE7wJhg+fFV1xaMGiMqnXaFc5jd4FjCCIRAEmAO5aPE7lzsw0ZelHYJB0PCWscErqOJcsrbllGmhmzE/7mAXcPof544Wlqg6wTuORtvKQzjV2gVC+shaNMhc24v8iIloGmS3ogc7bD9sS884Oi0kEP89jFnDX++/hCtPVtT7kwaxOkZpmxQ/L9vgdj1r+NCtAwQ6/A9DXMXnBqZgoHDdXP7Wna/Id6PRCum7DiREqcg1UPw9Yp6MsLv/HwlM4Hp7WQ1/CGQhcgDsDNJtcgLsAdyYCZza7MO4C3JkInNnswrgLcGcicGazC+POBO7/AH5zPa/ivytzAAAAAElFTkSuQmCC", ), ] ), diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py b/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py index 2dcfb92c63fee2..d37fcf897fc3a8 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py @@ -1,6 +1,6 @@ import os -import dashscope +import dashscope # type: ignore import pytest from core.model_runtime.entities.rerank_entities import RerankResult diff --git a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py index 83f4d70ce9ac2f..2860739f0e30b3 100644 --- a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py +++ b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py @@ -1,5 +1,5 @@ from flask import Flask, request -from flask_restful import Api, Resource +from flask_restful import Api, Resource # type: ignore app = Flask(__name__) api = Api(app) diff --git a/api/tests/integration_tests/tools/api_tool/test_api_tool.py b/api/tests/integration_tests/tools/api_tool/test_api_tool.py index 09729a961eff33..1bd75b91f745d6 100644 --- a/api/tests/integration_tests/tools/api_tool/test_api_tool.py +++ b/api/tests/integration_tests/tools/api_tool/test_api_tool.py @@ -34,9 +34,9 @@ def test_api_tool(setup_http_mock): response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters) assert response.status_code == 200 - assert "/p_param" == response.request.url.path - assert b"query_param=q_param" == response.request.url.query - assert "h_param" == response.request.headers.get("header_param") - assert "application/json" == response.request.headers.get("content-type") - assert "cookie_param=c_param" == response.request.headers.get("cookie") + assert response.request.url.path == "/p_param" + assert response.request.url.query == b"query_param=q_param" + assert response.request.headers.get("header_param") == "h_param" + assert response.request.headers.get("content-type") == "application/json" + assert response.request.headers.get("cookie") == "cookie_param=c_param" assert "b_param" in response.content.decode() diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index 0ea61369c0304e..4af35a8befcaf8 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -4,11 +4,11 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from pymochow import MochowClient -from pymochow.model.database import Database -from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState -from pymochow.model.schema import HNSWParams, VectorIndex -from pymochow.model.table import Table +from pymochow import MochowClient # type: ignore +from pymochow.model.database import Database # type: ignore +from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore +from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore +from pymochow.model.table import Table # type: ignore from requests.adapters import HTTPAdapter diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index 61d6ed16560c09..68a1e290adc120 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -4,12 +4,12 @@ import pytest from _pytest.monkeypatch import MonkeyPatch from requests.adapters import HTTPAdapter -from tcvectordb import VectorDBClient -from tcvectordb.model.database import Collection, Database -from tcvectordb.model.document import Document, Filter -from tcvectordb.model.enum import ReadConsistency -from tcvectordb.model.index import Index -from xinference_client.types import Embedding +from tcvectordb import VectorDBClient # type: ignore +from tcvectordb.model.database import Collection, Database # type: ignore +from tcvectordb.model.document import Document, Filter # type: ignore +from tcvectordb.model.enum import ReadConsistency # type: ignore +from tcvectordb.model.index import Index # type: ignore +from xinference_client.types import Embedding # type: ignore class MockTcvectordbClass: diff --git a/api/tests/integration_tests/vdb/__mock/vikingdb.py b/api/tests/integration_tests/vdb/__mock/vikingdb.py index 0f40337feba6ee..3ad72e55501f58 100644 --- a/api/tests/integration_tests/vdb/__mock/vikingdb.py +++ b/api/tests/integration_tests/vdb/__mock/vikingdb.py @@ -4,7 +4,7 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from volcengine.viking_db import ( +from volcengine.viking_db import ( # type: ignore Collection, Data, DistanceType, diff --git a/api/tests/integration_tests/vdb/lindorm/test_lindorm.py b/api/tests/integration_tests/vdb/lindorm/test_lindorm.py index f8f43ba6ef8ab3..0a26d3ea1c9987 100644 --- a/api/tests/integration_tests/vdb/lindorm/test_lindorm.py +++ b/api/tests/integration_tests/vdb/lindorm/test_lindorm.py @@ -7,9 +7,10 @@ class Config: - SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-*************-proxy-search-pub.lindorm.aliyuncs.com:30070") + SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070") SEARCH_USERNAME = env.str("SEARCH_USERNAME", "ADMIN") - SEARCH_PWD = env.str("SEARCH_PWD", "PWD") + SEARCH_PWD = env.str("SEARCH_PWD", "ADMIN") + USING_UGC = env.bool("USING_UGC", True) class TestLindormVectorStore(AbstractVectorTest): @@ -31,5 +32,27 @@ def get_ids_by_metadata_field(self): assert ids[0] == self.example_doc_id -def test_lindorm_vector(setup_mock_redis): +class TestLindormVectorStoreUGC(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = LindormVectorStore( + collection_name="ugc_index_test", + config=LindormVectorStoreConfig( + hosts=Config.SEARCH_ENDPOINT, + username=Config.SEARCH_USERNAME, + password=Config.SEARCH_PWD, + using_ugc=Config.USING_UGC, + ), + routing_value=self.collection_name, + ) + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id) + assert ids is not None + assert len(ids) == 1 + assert ids[0] == self.example_doc_id + + +def test_lindorm_vector_ugc(setup_mock_redis): TestLindormVectorStore().run_all_tests() + TestLindormVectorStoreUGC().run_all_tests() diff --git a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py index 2a5320c7d5e752..4c83c66bff5057 100644 --- a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py +++ b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py @@ -12,11 +12,11 @@ def tidb_vector(): return TiDBVector( collection_name="test_collection", config=TiDBVectorConfig( - host="xxx.eu-central-1.xxx.aws.tidbcloud.com", - port="4000", - user="xxx.root", - password="xxxxxx", - database="dify", + host="localhost", + port=4000, + user="root", + password="", + database="test", program_name="langgenius/dify", ), ) @@ -27,35 +27,14 @@ def __init__(self, vector): super().__init__() self.vector = vector - def text_exists(self): - exist = self.vector.text_exists(self.example_doc_id) - assert exist == False - - def search_by_vector(self): - hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding) - assert len(hits_by_vector) == 0 - def search_by_full_text(self): hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == 0 def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) - assert len(ids) == 0 + ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id) + assert len(ids) == 1 -def test_tidb_vector(setup_mock_redis, setup_tidbvector_mock, tidb_vector, mock_session): +def test_tidb_vector(setup_mock_redis, tidb_vector): TiDBVectorTest(vector=tidb_vector).run_all_tests() - - -@pytest.fixture -def mock_session(): - with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.Session", new_callable=MagicMock) as mock_session: - yield mock_session - - -@pytest.fixture -def setup_tidbvector_mock(tidb_vector, mock_session): - with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine"): - with patch.object(tidb_vector._engine, "connect"): - yield tidb_vector diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 9eea63f722e51f..0507fc707564dd 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -384,7 +384,7 @@ def test_mock_404(setup_http_mock): assert result.outputs is not None resp = result.outputs - assert 404 == resp.get("status_code") + assert resp.get("status_code") == 404 assert "Not Found" in resp.get("body", "") diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index 0eb310a51a335b..efa9ea89794b92 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -37,7 +37,11 @@ def test_dify_config_undefined_entry(example_env_file): assert config["LOG_LEVEL"] == "INFO" +# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected. +# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`. def test_dify_config(example_env_file): + # clear system environment variables + os.environ.clear() # load dotenv file with pydantic-settings config = DifyConfig(_env_file=example_env_file) @@ -55,6 +59,8 @@ def test_dify_config(example_env_file): # annotated field with configured value assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30 + assert config.WORKFLOW_PARALLEL_DEPTH_LIMIT == 3 + # NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected. # This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`. diff --git a/api/tests/unit_tests/core/app/segments/test_variables.py b/api/tests/unit_tests/core/app/segments/test_variables.py index 0c264c15a03593..426557c7161419 100644 --- a/api/tests/unit_tests/core/app/segments/test_variables.py +++ b/api/tests/unit_tests/core/app/segments/test_variables.py @@ -2,6 +2,8 @@ from pydantic import ValidationError from core.variables import ( + ArrayFileVariable, + ArrayVariable, FloatVariable, IntegerVariable, ObjectVariable, @@ -81,3 +83,8 @@ def test_variable_to_object(): assert var.to_object() == 3.14 var = SecretVariable(name="secret", value="secret_value") assert var.to_object() == "secret_value" + + +def test_array_file_variable_is_array_variable(): + var = ArrayFileVariable(name="files", value=[]) + assert isinstance(var, ArrayVariable) diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 7d19cff3e8ece6..ee0f7672f8c814 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,6 +2,7 @@ import pytest +from configs import dify_config from core.app.app_config.entities import ModelConfigEntity from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig from core.memory.token_buffer_memory import TokenBufferMemory @@ -126,6 +127,7 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args): def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args): model_config_mock, _, messages, inputs, context = get_chat_model_args + dify_config.MULTIMODAL_SEND_FORMAT = "url" files = [ File( @@ -134,13 +136,16 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", + storage_key="", ) ] prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string: - mock_get_encoded_string.return_value = ImagePromptMessageContent(data=str(files[0].remote_url)) + mock_get_encoded_string.return_value = ImagePromptMessageContent( + url=str(files[0].remote_url), format="jpg", mime_type="image/jpg" + ) prompt_messages = prompt_transform._get_chat_model_prompt_messages( prompt_template=messages, inputs=inputs, diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index 4edbc01cc778e8..e02d882780900f 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -1,34 +1,9 @@ import json -from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType, FileUploadConfig +from core.file import File, FileTransferMethod, FileType, FileUploadConfig from models.workflow import Workflow -def test_file_loads_and_dumps(): - file = File( - id="file1", - tenant_id="tenant1", - type=FileType.IMAGE, - transfer_method=FileTransferMethod.REMOTE_URL, - remote_url="https://example.com/image1.jpg", - ) - - file_dict = file.model_dump() - assert file_dict["dify_model_identity"] == FILE_MODEL_IDENTITY - assert file_dict["type"] == file.type.value - assert isinstance(file_dict["type"], str) - assert file_dict["transfer_method"] == file.transfer_method.value - assert isinstance(file_dict["transfer_method"], str) - assert "_extra_config" not in file_dict - - file_obj = File.model_validate(file_dict) - assert file_obj.id == file.id - assert file_obj.tenant_id == file.tenant_id - assert file_obj.type == file.type - assert file_obj.transfer_method == file.transfer_method - assert file_obj.remote_url == file.remote_url - - def test_file_to_dict(): file = File( id="file1", @@ -36,10 +11,11 @@ def test_file_to_dict(): type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", + storage_key="storage_key", ) file_dict = file.to_dict() - assert "_extra_config" not in file_dict + assert "_storage_key" not in file_dict assert "url" in file_dict diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 9f1ba7b6af9c80..b7d8f69e8c52ee 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -488,14 +488,12 @@ def test_run_branch(mock_close, mock_remove): items = [] generator = graph_engine.run() for item in generator: - # print(type(item), item) items.append(item) assert len(items) == 10 assert items[3].route_node_state.node_id == "if-else-1" assert items[4].route_node_state.node_id == "if-else-1" assert isinstance(items[5], NodeRunStreamChunkEvent) - assert items[5].chunk_content == "1 " assert isinstance(items[6], NodeRunStreamChunkEvent) assert items[6].chunk_content == "takato" assert items[7].route_node_state.node_id == "answer-1" diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index 7c19de60783928..58b910e17bae4a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -48,7 +48,7 @@ def test_executor_with_json_body_and_number_variable(): assert executor.method == "post" assert executor.url == "https://api.example.com/data" assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == {} + assert executor.params == [] assert executor.json == {"number": 42} assert executor.data is None assert executor.files is None @@ -101,7 +101,7 @@ def test_executor_with_json_body_and_object_variable(): assert executor.method == "post" assert executor.url == "https://api.example.com/data" assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == {} + assert executor.params == [] assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"} assert executor.data is None assert executor.files is None @@ -156,7 +156,7 @@ def test_executor_with_json_body_and_nested_object_variable(): assert executor.method == "post" assert executor.url == "https://api.example.com/data" assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == {} + assert executor.params == [] assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}} assert executor.data is None assert executor.files is None @@ -195,7 +195,7 @@ def test_extract_selectors_from_template_with_newline(): variable_pool=variable_pool, ) - assert executor.params == {"test": "line1\nline2"} + assert executor.params == [("test", "line1\nline2")] def test_executor_with_form_data(): @@ -244,7 +244,7 @@ def test_executor_with_form_data(): assert executor.url == "https://api.example.com/upload" assert "Content-Type" in executor.headers assert "multipart/form-data" in executor.headers["Content-Type"] - assert executor.params == {} + assert executor.params == [] assert executor.json is None assert executor.files is None assert executor.content is None @@ -265,3 +265,72 @@ def test_executor_with_form_data(): assert "Hello, World!" in raw_request assert "number_field" in raw_request assert "42" in raw_request + + +def test_init_headers(): + def create_executor(headers: str) -> Executor: + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers=headers, + params="", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool()) + + executor = create_executor("aa\n cc:") + executor._init_headers() + assert executor.headers == {"aa": "", "cc": ""} + + executor = create_executor("aa:bb\n cc:dd") + executor._init_headers() + assert executor.headers == {"aa": "bb", "cc": "dd"} + + executor = create_executor("aa:bb\n cc:dd\n") + executor._init_headers() + assert executor.headers == {"aa": "bb", "cc": "dd"} + + executor = create_executor("aa:bb\n\n cc : dd\n\n") + executor._init_headers() + assert executor.headers == {"aa": "bb", "cc": "dd"} + + +def test_init_params(): + def create_executor(params: str) -> Executor: + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers="", + params=params, + authorization=HttpRequestNodeAuthorization(type="no-auth"), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool()) + + # Test basic key-value pairs + executor = create_executor("key1:value1\nkey2:value2") + executor._init_params() + assert executor.params == [("key1", "value1"), ("key2", "value2")] + + # Test empty values + executor = create_executor("key1:\nkey2:") + executor._init_params() + assert executor.params == [("key1", ""), ("key2", "")] + + # Test duplicate keys (which is allowed for params) + executor = create_executor("key1:value1\nkey1:value2") + executor._init_params() + assert executor.params == [("key1", "value1"), ("key1", "value2")] + + # Test whitespace handling + executor = create_executor(" key1 : value1 \n key2 : value2 ") + executor._init_params() + assert executor.params == [("key1", "value1"), ("key2", "value2")] + + # Test empty lines and extra whitespace + executor = create_executor("key1:value1\n\nkey2:value2\n\n") + executor._init_params() + assert executor.params == [("key1", "value1"), ("key2", "value2")] diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 741a3a1894625c..97bacada74572d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -14,18 +14,10 @@ HttpRequestNodeBody, HttpRequestNodeData, ) -from core.workflow.nodes.http_request.executor import _plain_text_to_dict from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType -def test_plain_text_to_dict(): - assert _plain_text_to_dict("aa\n cc:") == {"aa": "", "cc": ""} - assert _plain_text_to_dict("aa:bb\n cc:dd") == {"aa": "bb", "cc": "dd"} - assert _plain_text_to_dict("aa:bb\n cc:dd\n") == {"aa": "bb", "cc": "dd"} - assert _plain_text_to_dict("aa:bb\n\n cc : dd\n\n") == {"aa": "bb", "cc": "dd"} - - def test_http_request_node_binary_file(monkeypatch): data = HttpRequestNodeData( title="test", @@ -59,6 +51,7 @@ def test_http_request_node_binary_file(monkeypatch): type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1111", + storage_key="", ), ), ) @@ -146,6 +139,7 @@ def test_http_request_node_form_with_file(monkeypatch): type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1111", + storage_key="", ), ), ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 9a24d35a1fcdae..76db42ef106dfa 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -18,11 +18,11 @@ TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel -from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment +from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment, StringSegment +from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState from core.workflow.nodes.answer import AnswerStreamGenerateRoute @@ -158,6 +158,7 @@ def test_fetch_files_with_file_segment(llm_node): filename="test.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", + storage_key="", ) llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) @@ -174,6 +175,7 @@ def test_fetch_files_with_array_file_segment(llm_node): filename="test1.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", + storage_key="", ), File( id="2", @@ -182,6 +184,7 @@ def test_fetch_files_with_array_file_segment(llm_node): filename="test2.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="2", + storage_key="", ), ] llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) @@ -225,14 +228,15 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): filename="test1.jpg", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=fake_remote_url, + storage_key="", ) ] fake_query = faker.sentence() prompt_messages, _ = llm_node._fetch_prompt_messages( - user_query=fake_query, - user_files=files, + sys_query=fake_query, + sys_files=files, context=None, memory=None, model_config=model_config, @@ -249,8 +253,7 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): def test_fetch_prompt_messages__basic(faker, llm_node, model_config): # Setup dify config - dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url" - dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url" + dify_config.MULTIMODAL_SEND_FORMAT = "url" # Generate fake values for prompt template fake_assistant_prompt = faker.sentence() @@ -285,8 +288,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): test_scenarios = [ LLMNodeTestScenario( description="No files", - user_query=fake_query, - user_files=[], + sys_query=fake_query, + sys_files=[], features=[], vision_enabled=False, vision_detail=None, @@ -320,14 +323,17 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): ), LLMNodeTestScenario( description="User files", - user_query=fake_query, - user_files=[ + sys_query=fake_query, + sys_files=[ File( tenant_id="test", type=FileType.IMAGE, filename="test1.jpg", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=fake_remote_url, + extension=".jpg", + mime_type="image/jpg", + storage_key="", ) ], vision_enabled=True, @@ -361,15 +367,17 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): UserPromptMessage( content=[ TextPromptMessageContent(data=fake_query), - ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + ImagePromptMessageContent( + url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail + ), ] ), ], ), LLMNodeTestScenario( description="Prompt template with variable selector of File", - user_query=fake_query, - user_files=[], + sys_query=fake_query, + sys_files=[], vision_enabled=False, vision_detail=fake_vision_detail, features=[ModelFeature.VISION], @@ -384,7 +392,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): expected_messages=[ UserPromptMessage( content=[ - ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + ImagePromptMessageContent( + url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail + ), ] ), ] @@ -397,6 +407,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): filename="test1.jpg", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=fake_remote_url, + extension=".jpg", + mime_type="image/jpg", + storage_key="", ) }, ), @@ -411,8 +424,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): # Call the method under test prompt_messages, _ = llm_node._fetch_prompt_messages( - user_query=scenario.user_query, - user_files=scenario.user_files, + sys_query=scenario.sys_query, + sys_files=scenario.sys_files, context=fake_context, memory=memory, model_config=model_config, @@ -429,3 +442,29 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): assert ( prompt_messages == scenario.expected_messages ), f"Message content mismatch in scenario: {scenario.description}" + + +def test_handle_list_messages_basic(llm_node): + messages = [ + LLMNodeChatModelMessage( + text="Hello, {#context#}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ] + context = "world" + jinja2_variables = [] + variable_pool = llm_node.graph_runtime_state.variable_pool + vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH + + result = llm_node._handle_list_messages( + messages=messages, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + vision_detail_config=vision_detail_config, + ) + + assert len(result) == 1 + assert isinstance(result[0], UserPromptMessage) + assert result[0].content == [TextPromptMessageContent(data="Hello, world")] diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py index 8e39445baf5490..21bb857353262c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py @@ -12,8 +12,8 @@ class LLMNodeTestScenario(BaseModel): """Test scenario for LLM node testing.""" description: str = Field(..., description="Description of the test scenario") - user_query: str = Field(..., description="User query input") - user_files: Sequence[File] = Field(default_factory=list, description="List of user files") + sys_query: str = Field(..., description="User query input") + sys_files: Sequence[File] = Field(default_factory=list, description="List of user files") vision_enabled: bool = Field(default=False, description="Whether vision is enabled") vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled") features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features") diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py new file mode 100644 index 00000000000000..2d74be9da9a96c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -0,0 +1,508 @@ +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import ( + GraphRunPartialSucceededEvent, + NodeRunExceptionEvent, + NodeRunStreamChunkEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.graph_engine import GraphEngine +from models.enums import UserFrom +from models.workflow import WorkflowType + + +class ContinueOnErrorTestHelper: + @staticmethod + def get_code_node( + code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {} + ): + """Helper method to create a code node configuration""" + node = { + "id": "node", + "data": { + "outputs": {"result": {"type": "number"}}, + "error_strategy": error_strategy, + "title": "code", + "variables": [], + "code_language": "python3", + "code": "\n".join([line[4:] for line in code.split("\n")]), + "type": "code", + **retry_config, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def get_http_node( + error_strategy: str = "fail-branch", + default_value: dict | None = None, + authorization_success: bool = False, + retry_config: dict = {}, + ): + """Helper method to create a http node configuration""" + authorization = ( + { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + } + if authorization_success + else { + "type": "api-key", + # missing config field + } + ) + node = { + "id": "node", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": authorization, + "headers": "X-Header:123", + "params": "A:b", + "body": None, + "type": "http-request", + "error_strategy": error_strategy, + **retry_config, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def get_error_status_code_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None): + """Helper method to create a http node configuration""" + node = { + "id": "node", + "data": { + "type": "http-request", + "title": "HTTP Request", + "desc": "", + "variables": [], + "method": "get", + "url": "https://api.github.com/issues", + "authorization": {"type": "no-auth", "config": None}, + "headers": "", + "params": "", + "body": {"type": "none", "data": []}, + "timeout": {"max_connect_timeout": 0, "max_read_timeout": 0, "max_write_timeout": 0}, + "error_strategy": error_strategy, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def get_tool_node(error_strategy: str = "fail-branch", default_value: dict | None = None): + """Helper method to create a tool node configuration""" + node = { + "id": "node", + "data": { + "title": "a", + "desc": "a", + "provider_id": "maths", + "provider_type": "builtin", + "provider_name": "maths", + "tool_name": "eval_expression", + "tool_label": "eval_expression", + "tool_configurations": {}, + "tool_parameters": { + "expression": { + "type": "variable", + "value": ["1", "123", "args1"], + } + }, + "type": "tool", + "error_strategy": error_strategy, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def get_llm_node(error_strategy: str = "fail-branch", default_value: dict | None = None): + """Helper method to create a llm node configuration""" + node = { + "id": "node", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_template": [ + {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, + {"role": "user", "text": "{{#sys.query#}}"}, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, + "error_strategy": error_strategy, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None): + """Helper method to create a graph engine instance for testing""" + graph = Graph.init(graph_config=graph_config) + variable_pool = { + "system_variables": { + SystemVariableKey.QUERY: "clear", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + "user_inputs": user_inputs or {"uid": "takato"}, + } + + return GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + + +DEFAULT_VALUE_EDGE = [ + { + "id": "start-source-node-target", + "source": "start", + "target": "node", + "sourceHandle": "source", + }, + { + "id": "node-source-answer-target", + "source": "node", + "target": "answer", + "sourceHandle": "source", + }, +] + +FAIL_BRANCH_EDGES = [ + { + "id": "start-source-node-target", + "source": "start", + "target": "node", + "sourceHandle": "source", + }, + { + "id": "node-true-success-target", + "source": "node", + "target": "success", + "sourceHandle": "source", + }, + { + "id": "node-false-error-target", + "source": "node", + "target": "error", + "sourceHandle": "fail-branch", + }, +] + + +def test_code_default_value_continue_on_error(): + error_code = """ + def main() -> dict: + return { + "result": 1 / 0, + } + """ + + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_code_node( + error_code, "default-value", [{"key": "result", "type": "number", "value": 132123}] + ), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "132123"} for e in events) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_code_fail_branch_continue_on_error(): + error_code = """ + def main() -> dict: + return { + "result": 1 / 0, + } + """ + + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "node node run successfully"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "node node run failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_code_node(error_code), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events + ) + + +def test_http_node_default_value_continue_on_error(): + """Test HTTP node with default value error strategy""" + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_http_node( + "default-value", [{"key": "response", "type": "string", "value": "http node got error response"}] + ), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http node got error response"} + for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_http_node_fail_branch_continue_on_error(): + """Test HTTP node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "HTTP request failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_http_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_tool_node_default_value_continue_on_error(): + """Test tool node with default value error strategy""" + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_tool_node( + "default-value", [{"key": "result", "type": "string", "value": "default tool result"}] + ), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_tool_node_fail_branch_continue_on_error(): + """Test HTTP node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "tool execute successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "tool execute failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_tool_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_llm_node_default_value_continue_on_error(): + """Test LLM node with default value error strategy""" + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_llm_node( + "default-value", [{"key": "answer", "type": "string", "value": "default LLM response"}] + ), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_llm_node_fail_branch_continue_on_error(): + """Test LLM node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "LLM request successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "LLM request failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_llm_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_status_code_error_http_node_fail_branch_continue_on_error(): + """Test HTTP node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "http execute successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_error_status_code_http_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http execute failed"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_variable_pool_error_type_variable(): + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "http execute successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_error_status_code_http_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + list(graph_engine.run()) + error_message = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_message"]) + error_type = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_type"]) + assert error_message != None + assert error_type.value == "HTTPResponseCodeError" + + +def test_no_node_in_fail_branch_continue_on_error(): + """Test HTTP node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES[:-1], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, + "id": "success", + }, + ContinueOnErrorTestHelper.get_http_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index d964d0e3529c69..41e2c5d48468f6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -248,6 +248,7 @@ def test_array_file_contains_file_name(): transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", filename="ab", + storage_key="", ), ], ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index d20dfc5b311698..36116d35404cf5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -57,6 +57,7 @@ def test_filter_files_by_type(list_operator_node): tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related1", + storage_key="", ), File( filename="document1.pdf", @@ -64,6 +65,7 @@ def test_filter_files_by_type(list_operator_node): tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related2", + storage_key="", ), File( filename="image2.png", @@ -71,6 +73,7 @@ def test_filter_files_by_type(list_operator_node): tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related3", + storage_key="", ), File( filename="audio1.mp3", @@ -78,6 +81,7 @@ def test_filter_files_by_type(list_operator_node): tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related4", + storage_key="", ), ] variable = ArrayFileSegment(value=files) @@ -130,6 +134,7 @@ def test_get_file_extract_string_func(): mime_type="text/plain", remote_url="https://example.com/test_file.txt", related_id="test_related_id", + storage_key="", ) # Test each case @@ -150,6 +155,7 @@ def test_get_file_extract_string_func(): mime_type=None, remote_url=None, related_id="test_related_id", + storage_key="", ) assert _get_file_extract_string_func(key="name")(empty_file) == "" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_retry.py b/api/tests/unit_tests/core/workflow/nodes/test_retry.py new file mode 100644 index 00000000000000..c232875ce57d3a --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_retry.py @@ -0,0 +1,73 @@ +from core.workflow.graph_engine.entities.event import ( + GraphRunFailedEvent, + GraphRunPartialSucceededEvent, + GraphRunSucceededEvent, + NodeRunRetryEvent, +) +from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper + +DEFAULT_VALUE_EDGE = [ + { + "id": "start-source-node-target", + "source": "start", + "target": "node", + "sourceHandle": "source", + }, + { + "id": "node-source-answer-target", + "source": "node", + "target": "answer", + "sourceHandle": "source", + }, +] + + +def test_retry_default_value_partial_success(): + """retry default value node with partial success status""" + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_http_node( + "default-value", + [{"key": "result", "type": "string", "value": "http node got error response"}], + retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}}, + ), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2 + assert events[-1].outputs == {"answer": "http node got error response"} + assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events) + assert len(events) == 11 + + +def test_retry_failed(): + """retry failed with success status""" + error_code = """ + def main() -> dict: + return { + "result": 1 / 0, + } + """ + + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_http_node( + None, + None, + retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}}, + ), + ], + } + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2 + assert any(isinstance(e, GraphRunFailedEvent) for e in events) + assert len(events) == 8 diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index 9ea6acac17132d..efbcdc760c6995 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -19,6 +19,7 @@ def file(): related_id="test_related_id", remote_url="test_url", filename="test_file.txt", + storage_key="", ) diff --git a/api/tests/unit_tests/oss/__mock/aliyun_oss.py b/api/tests/unit_tests/oss/__mock/aliyun_oss.py index 27e1c0ad85029b..4f6d8a2f54a4fd 100644 --- a/api/tests/unit_tests/oss/__mock/aliyun_oss.py +++ b/api/tests/unit_tests/oss/__mock/aliyun_oss.py @@ -4,8 +4,8 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from oss2 import Bucket -from oss2.models import GetObjectResult, PutObjectResult +from oss2 import Bucket # type: ignore +from oss2.models import GetObjectResult, PutObjectResult # type: ignore from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/__mock/base.py b/api/tests/unit_tests/oss/__mock/base.py index a1eaaab9c35f16..bb3c9716c3c48b 100644 --- a/api/tests/unit_tests/oss/__mock/base.py +++ b/api/tests/unit_tests/oss/__mock/base.py @@ -6,13 +6,17 @@ def get_example_folder() -> str: - return "/dify" + return "~/dify" def get_example_bucket() -> str: return "dify" +def get_opendal_bucket() -> str: + return "./dify" + + def get_example_filename() -> str: return "test.txt" @@ -22,14 +26,14 @@ def get_example_data() -> bytes: def get_example_filepath() -> str: - return "/test" + return "~/test" class BaseStorageTest: @pytest.fixture(autouse=True) - def setup_method(self): + def setup_method(self, *args, **kwargs): """Should be implemented in child classes to setup specific storage.""" - self.storage = BaseStorage() + self.storage: BaseStorage def test_save(self): """Test saving data.""" diff --git a/api/tests/unit_tests/oss/__mock/tencent_cos.py b/api/tests/unit_tests/oss/__mock/tencent_cos.py index 5189b68e87132a..c77c5b08f37d15 100644 --- a/api/tests/unit_tests/oss/__mock/tencent_cos.py +++ b/api/tests/unit_tests/oss/__mock/tencent_cos.py @@ -3,8 +3,8 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from qcloud_cos import CosS3Client -from qcloud_cos.streambody import StreamBody +from qcloud_cos import CosS3Client # type: ignore +from qcloud_cos.streambody import StreamBody # type: ignore from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/__mock/volcengine_tos.py b/api/tests/unit_tests/oss/__mock/volcengine_tos.py index 649d93a20261d3..88df59f91c3071 100644 --- a/api/tests/unit_tests/oss/__mock/volcengine_tos.py +++ b/api/tests/unit_tests/oss/__mock/volcengine_tos.py @@ -4,8 +4,8 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from tos import TosClientV2 -from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput +from tos import TosClientV2 # type: ignore +from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput # type: ignore from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py b/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py index 65d31352bd3437..380134bc46d02e 100644 --- a/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py +++ b/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py @@ -1,7 +1,7 @@ from unittest.mock import MagicMock, patch import pytest -from oss2 import Auth +from oss2 import Auth # type: ignore from extensions.storage.aliyun_oss_storage import AliyunOssStorage from tests.unit_tests.oss.__mock.aliyun_oss import setup_aliyun_oss_mock diff --git a/api/tests/unit_tests/oss/local/test_local_fs.py b/api/tests/unit_tests/oss/local/test_local_fs.py deleted file mode 100644 index 03ce7d2450a911..00000000000000 --- a/api/tests/unit_tests/oss/local/test_local_fs.py +++ /dev/null @@ -1,18 +0,0 @@ -from collections.abc import Generator - -import pytest - -from extensions.storage.local_fs_storage import LocalFsStorage -from tests.unit_tests.oss.__mock.base import ( - BaseStorageTest, - get_example_folder, -) -from tests.unit_tests.oss.__mock.local import setup_local_fs_mock - - -class TestLocalFS(BaseStorageTest): - @pytest.fixture(autouse=True) - def setup_method(self, setup_local_fs_mock): - """Executed before each test method.""" - self.storage = LocalFsStorage() - self.storage.folder = get_example_folder() diff --git a/api/tests/unit_tests/oss/opendal/__init__.py b/api/tests/unit_tests/oss/opendal/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/oss/opendal/test_opendal.py b/api/tests/unit_tests/oss/opendal/test_opendal.py new file mode 100644 index 00000000000000..6acec6e579d2a6 --- /dev/null +++ b/api/tests/unit_tests/oss/opendal/test_opendal.py @@ -0,0 +1,85 @@ +from collections.abc import Generator +from pathlib import Path + +import pytest + +from extensions.storage.opendal_storage import OpenDALStorage +from tests.unit_tests.oss.__mock.base import ( + get_example_data, + get_example_filename, + get_opendal_bucket, +) + + +class TestOpenDAL: + @pytest.fixture(autouse=True) + def setup_method(self, *args, **kwargs): + """Executed before each test method.""" + self.storage = OpenDALStorage( + scheme="fs", + root=get_opendal_bucket(), + ) + + @pytest.fixture(scope="class", autouse=True) + def teardown_class(self, request): + """Clean up after all tests in the class.""" + + def cleanup(): + folder = Path(get_opendal_bucket()) + if folder.exists() and folder.is_dir(): + for item in folder.iterdir(): + if item.is_file(): + item.unlink() + elif item.is_dir(): + item.rmdir() + folder.rmdir() + + return cleanup() + + def test_save_and_exists(self): + """Test saving data and checking existence.""" + filename = get_example_filename() + data = get_example_data() + + assert not self.storage.exists(filename) + self.storage.save(filename, data) + assert self.storage.exists(filename) + + def test_load_once(self): + """Test loading data once.""" + filename = get_example_filename() + data = get_example_data() + + self.storage.save(filename, data) + loaded_data = self.storage.load_once(filename) + assert loaded_data == data + + def test_load_stream(self): + """Test loading data as a stream.""" + filename = get_example_filename() + data = get_example_data() + + self.storage.save(filename, data) + generator = self.storage.load_stream(filename) + assert isinstance(generator, Generator) + assert next(generator) == data + + def test_download(self): + """Test downloading data to a file.""" + filename = get_example_filename() + filepath = str(Path(get_opendal_bucket()) / filename) + data = get_example_data() + + self.storage.save(filename, data) + self.storage.download(filename, filepath) + + def test_delete(self): + """Test deleting a file.""" + filename = get_example_filename() + data = get_example_data() + + self.storage.save(filename, data) + assert self.storage.exists(filename) + + self.storage.delete(filename) + assert not self.storage.exists(filename) diff --git a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py index 303f0493bda42f..d289751800633a 100644 --- a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py +++ b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from qcloud_cos import CosConfig +from qcloud_cos import CosConfig # type: ignore from extensions.storage.tencent_cos_storage import TencentCosStorage from tests.unit_tests.oss.__mock.base import ( diff --git a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py index 5afbc9e8b4cb18..04988e85d85881 100644 --- a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py +++ b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py @@ -1,5 +1,5 @@ import pytest -from tos import TosClientV2 +from tos import TosClientV2 # type: ignore from extensions.storage.volcengine_tos_storage import VolcengineTosStorage from tests.unit_tests.oss.__mock.base import ( diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py index 95b93651d57f80..8d645487278a5f 100644 --- a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py +++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py @@ -1,7 +1,7 @@ from textwrap import dedent import pytest -from yaml import YAMLError +from yaml import YAMLError # type: ignore from core.tools.utils.yaml_utils import load_yaml_file diff --git a/dev/pytest/pytest_config_tests.py b/dev/pytest/pytest_config_tests.py new file mode 100644 index 00000000000000..08adc9ebe999b1 --- /dev/null +++ b/dev/pytest/pytest_config_tests.py @@ -0,0 +1,111 @@ +import yaml # type: ignore +from dotenv import dotenv_values +from pathlib import Path + +BASE_API_AND_DOCKER_CONFIG_SET_DIFF = { + "APP_MAX_EXECUTION_TIME", + "BATCH_UPLOAD_LIMIT", + "CELERY_BEAT_SCHEDULER_TIME", + "CODE_EXECUTION_API_KEY", + "HTTP_REQUEST_MAX_CONNECT_TIMEOUT", + "HTTP_REQUEST_MAX_READ_TIMEOUT", + "HTTP_REQUEST_MAX_WRITE_TIMEOUT", + "KEYWORD_DATA_SOURCE_TYPE", + "LOGIN_LOCKOUT_DURATION", + "LOG_FORMAT", + "OCI_ACCESS_KEY", + "OCI_BUCKET_NAME", + "OCI_ENDPOINT", + "OCI_REGION", + "OCI_SECRET_KEY", + "REDIS_DB", + "RESEND_API_URL", + "RESPECT_XFORWARD_HEADERS_ENABLED", + "SENTRY_DSN", + "SSRF_DEFAULT_CONNECT_TIME_OUT", + "SSRF_DEFAULT_MAX_RETRIES", + "SSRF_DEFAULT_READ_TIME_OUT", + "SSRF_DEFAULT_TIME_OUT", + "SSRF_DEFAULT_WRITE_TIME_OUT", + "UPSTASH_VECTOR_TOKEN", + "UPSTASH_VECTOR_URL", + "USING_UGC_INDEX", + "WEAVIATE_BATCH_SIZE", + "WEAVIATE_GRPC_ENABLED", +} + +BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF = { + "BATCH_UPLOAD_LIMIT", + "CELERY_BEAT_SCHEDULER_TIME", + "HTTP_REQUEST_MAX_CONNECT_TIMEOUT", + "HTTP_REQUEST_MAX_READ_TIMEOUT", + "HTTP_REQUEST_MAX_WRITE_TIMEOUT", + "KEYWORD_DATA_SOURCE_TYPE", + "LOGIN_LOCKOUT_DURATION", + "LOG_FORMAT", + "OPENDAL_FS_ROOT", + "OPENDAL_S3_ACCESS_KEY_ID", + "OPENDAL_S3_BUCKET", + "OPENDAL_S3_ENDPOINT", + "OPENDAL_S3_REGION", + "OPENDAL_S3_ROOT", + "OPENDAL_S3_SECRET_ACCESS_KEY", + "OPENDAL_S3_SERVER_SIDE_ENCRYPTION", + "PGVECTOR_MAX_CONNECTION", + "PGVECTOR_MIN_CONNECTION", + "PGVECTO_RS_DATABASE", + "PGVECTO_RS_HOST", + "PGVECTO_RS_PASSWORD", + "PGVECTO_RS_PORT", + "PGVECTO_RS_USER", + "RESPECT_XFORWARD_HEADERS_ENABLED", + "SCARF_NO_ANALYTICS", + "SSRF_DEFAULT_CONNECT_TIME_OUT", + "SSRF_DEFAULT_MAX_RETRIES", + "SSRF_DEFAULT_READ_TIME_OUT", + "SSRF_DEFAULT_TIME_OUT", + "SSRF_DEFAULT_WRITE_TIME_OUT", + "STORAGE_OPENDAL_SCHEME", + "SUPABASE_API_KEY", + "SUPABASE_BUCKET_NAME", + "SUPABASE_URL", + "USING_UGC_INDEX", + "VIKINGDB_CONNECTION_TIMEOUT", + "VIKINGDB_SOCKET_TIMEOUT", + "WEAVIATE_BATCH_SIZE", + "WEAVIATE_GRPC_ENABLED", +} + +API_CONFIG_SET = set(dotenv_values(Path("api") / Path(".env.example")).keys()) +DOCKER_CONFIG_SET = set(dotenv_values(Path("docker") / Path(".env.example")).keys()) +DOCKER_COMPOSE_CONFIG_SET = set() + +with open(Path("docker") / Path("docker-compose.yaml")) as f: + DOCKER_COMPOSE_CONFIG_SET = set(yaml.safe_load(f.read())["x-shared-env"].keys()) + + +def test_yaml_config(): + # python set == operator is used to compare two sets + DIFF_API_WITH_DOCKER = ( + API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF + ) + if DIFF_API_WITH_DOCKER: + print( + f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}" + ) + raise Exception("API and Docker config sets are different") + DIFF_API_WITH_DOCKER_COMPOSE = ( + API_CONFIG_SET + - DOCKER_COMPOSE_CONFIG_SET + - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF + ) + if DIFF_API_WITH_DOCKER_COMPOSE: + print( + f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}" + ) + raise Exception("API and Docker Compose config sets are different") + print("All tests passed!") + + +if __name__ == "__main__": + test_yaml_config() diff --git a/dev/pytest/pytest_vdb.sh b/dev/pytest/pytest_vdb.sh index 02a9f492797d22..c68a94c79bf800 100755 --- a/dev/pytest/pytest_vdb.sh +++ b/dev/pytest/pytest_vdb.sh @@ -14,3 +14,4 @@ pytest api/tests/integration_tests/vdb/chroma \ api/tests/integration_tests/vdb/upstash \ api/tests/integration_tests/vdb/couchbase \ api/tests/integration_tests/vdb/oceanbase \ + api/tests/integration_tests/vdb/tidb_vector \ diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml index e7a2daf9cdae58..1cff58be7f661d 100644 --- a/docker-legacy/docker-compose.yaml +++ b/docker-legacy/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.13.1 + image: langgenius/dify-api:0.14.2 restart: always environment: # Startup mode, 'api' starts the API server. @@ -227,7 +227,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.13.1 + image: langgenius/dify-api:0.14.2 restart: always environment: CONSOLE_WEB_URL: '' @@ -397,7 +397,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.13.1 + image: langgenius/dify-web:0.14.2 restart: always environment: # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is diff --git a/docker/.env.example b/docker/.env.example index 719a025877e449..43e67a8db41254 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -107,6 +107,7 @@ ACCESS_TOKEN_EXPIRE_MINUTES=60 # The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer. APP_MAX_ACTIVE_REQUESTS=0 +APP_MAX_EXECUTION_TIME=1200 # ------------------------------ # Container Startup Related Configuration @@ -119,15 +120,15 @@ DIFY_BIND_ADDRESS=0.0.0.0 # API service binding port number, default 5001. DIFY_PORT=5001 -# The number of API server workers, i.e., the number of gevent workers. -# Formula: number of cpu cores x 2 + 1 +# The number of API server workers, i.e., the number of workers. +# Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent # Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers SERVER_WORKER_AMOUNT= # Defaults to gevent. If using windows, it can be switched to sync or solo. SERVER_WORKER_CLASS= -# Similar to SERVER_WORKER_CLASS. Default is gevent. +# Similar to SERVER_WORKER_CLASS. # If using windows, it can be switched to sync or solo. CELERY_WORKER_CLASS= @@ -227,6 +228,7 @@ REDIS_PORT=6379 REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false +REDIS_DB=0 # Whether to use Redis Sentinel mode. # If set to true, the application will automatically discover and connect to the master node through Sentinel. @@ -281,44 +283,42 @@ CONSOLE_CORS_ALLOW_ORIGINS=* # ------------------------------ # The type of storage to use for storing user files. -# Supported values are `local` , `s3` , `azure-blob` , `google-storage`, `tencent-cos`, `huawei-obs`, `volcengine-tos`, `baidu-obs`, `supabase` -# Default: `local` -STORAGE_TYPE=local -STORAGE_LOCAL_PATH=storage +STORAGE_TYPE=opendal + +# Apache OpenDAL Configuration +# The configuration for OpenDAL consists of the following format: OPENDAL__. +# You can find all the service configurations (CONFIG_NAME) in the repository at: https://github.com/apache/opendal/tree/main/core/src/services. +# Dify will scan configurations starting with OPENDAL_ and automatically apply them. +# The scheme name for the OpenDAL storage. +OPENDAL_SCHEME=fs +# Configurations for OpenDAL Local File System. +OPENDAL_FS_ROOT=storage # S3 Configuration -# Whether to use AWS managed IAM roles for authenticating with the S3 service. -# If set to false, the access key and secret key must be provided. -S3_USE_AWS_MANAGED_IAM=false -# The endpoint of the S3 service. +# S3_ENDPOINT= -# The region of the S3 service. S3_REGION=us-east-1 -# The name of the S3 bucket to use for storing files. S3_BUCKET_NAME=difyai -# The access key to use for authenticating with the S3 service. S3_ACCESS_KEY= -# The secret key to use for authenticating with the S3 service. S3_SECRET_KEY= +# Whether to use AWS managed IAM roles for authenticating with the S3 service. +# If set to false, the access key and secret key must be provided. +S3_USE_AWS_MANAGED_IAM=false # Azure Blob Configuration -# The name of the Azure Blob Storage account to use for storing files. +# AZURE_BLOB_ACCOUNT_NAME=difyai -# The access key to use for authenticating with the Azure Blob Storage account. AZURE_BLOB_ACCOUNT_KEY=difyai -# The name of the Azure Blob Storage container to use for storing files. AZURE_BLOB_CONTAINER_NAME=difyai-container -# The URL of the Azure Blob Storage account. AZURE_BLOB_ACCOUNT_URL=https://.blob.core.windows.net # Google Storage Configuration -# The name of the Google Storage bucket to use for storing files. +# GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name -# The service account JSON key to use for authenticating with the Google Storage service. GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string # The Alibaba Cloud OSS configurations, -# only available when STORAGE_TYPE is `aliyun-oss` +# ALIYUN_OSS_BUCKET_NAME=your-bucket-name ALIYUN_OSS_ACCESS_KEY=your-access-key ALIYUN_OSS_SECRET_KEY=your-secret-key @@ -329,55 +329,47 @@ ALIYUN_OSS_AUTH_VERSION=v4 ALIYUN_OSS_PATH=your-path # Tencent COS Configuration -# The name of the Tencent COS bucket to use for storing files. +# TENCENT_COS_BUCKET_NAME=your-bucket-name -# The secret key to use for authenticating with the Tencent COS service. TENCENT_COS_SECRET_KEY=your-secret-key -# The secret id to use for authenticating with the Tencent COS service. TENCENT_COS_SECRET_ID=your-secret-id -# The region of the Tencent COS service. TENCENT_COS_REGION=your-region -# The scheme of the Tencent COS service. TENCENT_COS_SCHEME=your-scheme +# Oracle Storage Configuration +# +OCI_ENDPOINT=https://objectstorage.us-ashburn-1.oraclecloud.com +OCI_BUCKET_NAME=your-bucket-name +OCI_ACCESS_KEY=your-access-key +OCI_SECRET_KEY=your-secret-key +OCI_REGION=us-ashburn-1 + # Huawei OBS Configuration -# The name of the Huawei OBS bucket to use for storing files. +# HUAWEI_OBS_BUCKET_NAME=your-bucket-name -# The secret key to use for authenticating with the Huawei OBS service. HUAWEI_OBS_SECRET_KEY=your-secret-key -# The access key to use for authenticating with the Huawei OBS service. HUAWEI_OBS_ACCESS_KEY=your-access-key -# The server url of the HUAWEI OBS service. HUAWEI_OBS_SERVER=your-server-url # Volcengine TOS Configuration -# The name of the Volcengine TOS bucket to use for storing files. +# VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name -# The secret key to use for authenticating with the Volcengine TOS service. VOLCENGINE_TOS_SECRET_KEY=your-secret-key -# The access key to use for authenticating with the Volcengine TOS service. VOLCENGINE_TOS_ACCESS_KEY=your-access-key -# The endpoint of the Volcengine TOS service. VOLCENGINE_TOS_ENDPOINT=your-server-url -# The region of the Volcengine TOS service. VOLCENGINE_TOS_REGION=your-region # Baidu OBS Storage Configuration -# The name of the Baidu OBS bucket to use for storing files. +# BAIDU_OBS_BUCKET_NAME=your-bucket-name -# The secret key to use for authenticating with the Baidu OBS service. BAIDU_OBS_SECRET_KEY=your-secret-key -# The access key to use for authenticating with the Baidu OBS service. BAIDU_OBS_ACCESS_KEY=your-access-key -# The endpoint of the Baidu OBS service. BAIDU_OBS_ENDPOINT=your-server-url # Supabase Storage Configuration -# The name of the Supabase bucket to use for storing files. +# SUPABASE_BUCKET_NAME=your-bucket-name -# The api key to use for authenticating with the Supabase service. SUPABASE_API_KEY=your-access-key -# The project endpoint url of the Supabase service. SUPABASE_URL=your-server-url # ------------------------------ @@ -390,28 +382,20 @@ VECTOR_STORE=weaviate # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. WEAVIATE_ENDPOINT=http://weaviate:8080 -# The Weaviate API key. WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih # The Qdrant endpoint URL. Only available when VECTOR_STORE is `qdrant`. QDRANT_URL=http://qdrant:6333 -# The Qdrant API key. QDRANT_API_KEY=difyai123456 -# The Qdrant client timeout setting. QDRANT_CLIENT_TIMEOUT=20 -# The Qdrant client enable gRPC mode. QDRANT_GRPC_ENABLED=false -# The Qdrant server gRPC mode PORT. QDRANT_GRPC_PORT=6334 # Milvus configuration Only available when VECTOR_STORE is `milvus`. # The milvus uri. MILVUS_URI=http://127.0.0.1:19530 -# The milvus token. MILVUS_TOKEN= -# The milvus username. MILVUS_USER=root -# The milvus password. MILVUS_PASSWORD=Milvus # MyScale configuration, only available when VECTOR_STORE is `myscale` @@ -465,8 +449,8 @@ ANALYTICDB_MAX_CONNECTION=5 # TiDB vector configurations, only available when VECTOR_STORE is `tidb` TIDB_VECTOR_HOST=tidb TIDB_VECTOR_PORT=4000 -TIDB_VECTOR_USER=xxx.root -TIDB_VECTOR_PASSWORD=xxxxxx +TIDB_VECTOR_USER= +TIDB_VECTOR_PASSWORD= TIDB_VECTOR_DATABASE=dify # Tidb on qdrant configuration, only available when VECTOR_STORE is `tidb_on_qdrant` @@ -489,7 +473,7 @@ CHROMA_PORT=8000 CHROMA_TENANT=default_tenant CHROMA_DATABASE=default_database CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthClientProvider -CHROMA_AUTH_CREDENTIALS=xxxxxx +CHROMA_AUTH_CREDENTIALS= # Oracle configuration, only available when VECTOR_STORE is `oracle` ORACLE_HOST=oracle @@ -526,6 +510,7 @@ ELASTICSEARCH_HOST=0.0.0.0 ELASTICSEARCH_PORT=9200 ELASTICSEARCH_USERNAME=elastic ELASTICSEARCH_PASSWORD=elastic +KIBANA_PORT=5601 # baidu vector configurations, only available when VECTOR_STORE is `baidu` BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287 @@ -545,11 +530,10 @@ VIKINGDB_SCHEMA=http VIKINGDB_CONNECTION_TIMEOUT=30 VIKINGDB_SOCKET_TIMEOUT=30 - # Lindorm configuration, only available when VECTOR_STORE is `lindorm` -LINDORM_URL=http://ld-***************-proxy-search-pub.lindorm.aliyuncs.com:30070 -LINDORM_USERNAME=username -LINDORM_PASSWORD=password +LINDORM_URL=http://lindorm:30070 +LINDORM_USERNAME=lindorm +LINDORM_PASSWORD=lindorm # OceanBase Vector configuration, only available when VECTOR_STORE is `oceanbase` OCEANBASE_VECTOR_HOST=oceanbase @@ -557,8 +541,13 @@ OCEANBASE_VECTOR_PORT=2881 OCEANBASE_VECTOR_USER=root@test OCEANBASE_VECTOR_PASSWORD=difyai123456 OCEANBASE_VECTOR_DATABASE=test +OCEANBASE_CLUSTER_NAME=difyai OCEANBASE_MEMORY_LIMIT=6G +# Upstash Vector configuration, only available when VECTOR_STORE is `upstash` +UPSTASH_VECTOR_URL=https://xxx-vector.upstash.io +UPSTASH_VECTOR_TOKEN=dify + # ------------------------------ # Knowledge Configuration # ------------------------------ @@ -601,20 +590,16 @@ CODE_GENERATION_MAX_TOKENS=1024 # Multi-modal Configuration # ------------------------------ -# The format of the image/video sent when the multi-modal model is input, +# The format of the image/video/audio/document sent when the multi-modal model is input, # the default is base64, optional url. # The delay of the call in url mode will be lower than that in base64 mode. # It is generally recommended to use the more compatible base64 mode. -# If configured as url, you need to configure FILES_URL as an externally accessible address so that the multi-modal model can access the image/video. -MULTIMODAL_SEND_IMAGE_FORMAT=base64 -MULTIMODAL_SEND_VIDEO_FORMAT=base64 - +# If configured as url, you need to configure FILES_URL as an externally accessible address so that the multi-modal model can access the image/video/audio/document. +MULTIMODAL_SEND_FORMAT=base64 # Upload image file size limit, default 10M. UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 - # Upload video file size limit, default 100M. UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 - # Upload audio file size limit, default 50M. UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 @@ -622,15 +607,14 @@ UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 # Sentry Configuration # Used for application monitoring and error log tracking. # ------------------------------ +SENTRY_DSN= # API Service Sentry DSN address, default is empty, when empty, # all monitoring information is not reported to Sentry. # If not set, Sentry error reporting will be disabled. API_SENTRY_DSN= - # API Service The reporting ratio of Sentry events, if it is 0.01, it is 1%. API_SENTRY_TRACES_SAMPLE_RATE=1.0 - # API Service The reporting ratio of Sentry profiles, if it is 0.01, it is 1%. API_SENTRY_PROFILES_SAMPLE_RATE=1.0 @@ -668,8 +652,10 @@ MAIL_TYPE=resend MAIL_DEFAULT_SEND_FROM= # API-Key for the Resend email provider, used when MAIL_TYPE is `resend`. +RESEND_API_URL=https://api.resend.com RESEND_API_KEY=your-resend-api-key + # SMTP server configuration, used when MAIL_TYPE is `smtp` SMTP_SERVER= SMTP_PORT=465 @@ -694,24 +680,26 @@ RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 # The sandbox service endpoint. CODE_EXECUTION_ENDPOINT=http://sandbox:8194 +CODE_EXECUTION_API_KEY=dify-sandbox CODE_MAX_NUMBER=9223372036854775807 CODE_MIN_NUMBER=-9223372036854775808 CODE_MAX_DEPTH=5 CODE_MAX_PRECISION=20 CODE_MAX_STRING_LENGTH=80000 -TEMPLATE_TRANSFORM_MAX_LENGTH=80000 CODE_MAX_STRING_ARRAY_LENGTH=30 CODE_MAX_OBJECT_ARRAY_LENGTH=30 CODE_MAX_NUMBER_ARRAY_LENGTH=1000 CODE_EXECUTION_CONNECT_TIMEOUT=10 CODE_EXECUTION_READ_TIMEOUT=60 CODE_EXECUTION_WRITE_TIMEOUT=10 +TEMPLATE_TRANSFORM_MAX_LENGTH=80000 # Workflow runtime configuration WORKFLOW_MAX_EXECUTION_STEPS=500 WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_CALL_MAX_DEPTH=5 MAX_VARIABLE_SIZE=204800 +WORKFLOW_PARALLEL_DEPTH_LIMIT=3 WORKFLOW_FILE_UPLOAD_LIMIT=10 # HTTP request node in workflow configuration @@ -931,3 +919,7 @@ CSP_WHITELIST= # Enable or disable create tidb service job CREATE_TIDB_SERVICE_JOB_ENABLED=false + +# Maximum number of submitted thread count in a ThreadPool for parallel node execution +MAX_SUBMIT_COUNT=100 + diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml new file mode 100644 index 00000000000000..d4e0ba49d0ba9a --- /dev/null +++ b/docker/docker-compose-template.yaml @@ -0,0 +1,576 @@ +x-shared-env: &shared-api-worker-env +services: + # API service + api: + image: langgenius/dify-api:0.14.2 + restart: always + environment: + # Use the shared environment variables. + <<: *shared-api-worker-env + # Startup mode, 'api' starts the API server. + MODE: api + SENTRY_DSN: ${API_SENTRY_DSN:-} + SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0} + SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0} + depends_on: + - db + - redis + volumes: + # Mount the storage directory to the container, for storing user files. + - ./volumes/app/storage:/app/api/storage + networks: + - ssrf_proxy_network + - default + + # worker service + # The Celery worker for processing the queue. + worker: + image: langgenius/dify-api:0.14.2 + restart: always + environment: + # Use the shared environment variables. + <<: *shared-api-worker-env + # Startup mode, 'worker' starts the Celery worker for processing the queue. + MODE: worker + SENTRY_DSN: ${API_SENTRY_DSN:-} + SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0} + SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0} + depends_on: + - db + - redis + volumes: + # Mount the storage directory to the container, for storing user files. + - ./volumes/app/storage:/app/api/storage + networks: + - ssrf_proxy_network + - default + + # Frontend web application. + web: + image: langgenius/dify-web:0.14.2 + restart: always + environment: + CONSOLE_API_URL: ${CONSOLE_API_URL:-} + APP_API_URL: ${APP_API_URL:-} + SENTRY_DSN: ${WEB_SENTRY_DSN:-} + NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} + TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} + CSP_WHITELIST: ${CSP_WHITELIST:-} + + # The postgres database. + db: + image: postgres:15-alpine + restart: always + environment: + PGUSER: ${PGUSER:-postgres} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456} + POSTGRES_DB: ${POSTGRES_DB:-dify} + PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} + command: > + postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}' + -c 'shared_buffers=${POSTGRES_SHARED_BUFFERS:-128MB}' + -c 'work_mem=${POSTGRES_WORK_MEM:-4MB}' + -c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}' + -c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}' + volumes: + - ./volumes/db/data:/var/lib/postgresql/data + healthcheck: + test: ['CMD', 'pg_isready'] + interval: 1s + timeout: 3s + retries: 30 + + # The redis cache. + redis: + image: redis:6-alpine + restart: always + environment: + REDISCLI_AUTH: ${REDIS_PASSWORD:-difyai123456} + volumes: + # Mount the redis data directory to the container. + - ./volumes/redis/data:/data + # Set the redis password when startup redis server. + command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456} + healthcheck: + test: ['CMD', 'redis-cli', 'ping'] + + # The DifySandbox + sandbox: + image: langgenius/dify-sandbox:0.2.10 + restart: always + environment: + # The DifySandbox configurations + # Make sure you are changing this key for your deployment with a strong key. + # You can generate a strong key using `openssl rand -base64 42`. + API_KEY: ${SANDBOX_API_KEY:-dify-sandbox} + GIN_MODE: ${SANDBOX_GIN_MODE:-release} + WORKER_TIMEOUT: ${SANDBOX_WORKER_TIMEOUT:-15} + ENABLE_NETWORK: ${SANDBOX_ENABLE_NETWORK:-true} + HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} + HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} + SANDBOX_PORT: ${SANDBOX_PORT:-8194} + volumes: + - ./volumes/sandbox/dependencies:/dependencies + healthcheck: + test: ['CMD', 'curl', '-f', 'http://localhost:8194/health'] + networks: + - ssrf_proxy_network + + # ssrf_proxy server + # for more information, please refer to + # https://docs.dify.ai/learn-more/faq/install-faq#id-18.-why-is-ssrf_proxy-needed + ssrf_proxy: + image: ubuntu/squid:latest + restart: always + volumes: + - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template + - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh + entrypoint: + [ + 'sh', + '-c', + "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", + ] + environment: + # pls clearly modify the squid env vars to fit your network environment. + HTTP_PORT: ${SSRF_HTTP_PORT:-3128} + COREDUMP_DIR: ${SSRF_COREDUMP_DIR:-/var/spool/squid} + REVERSE_PROXY_PORT: ${SSRF_REVERSE_PROXY_PORT:-8194} + SANDBOX_HOST: ${SSRF_SANDBOX_HOST:-sandbox} + SANDBOX_PORT: ${SANDBOX_PORT:-8194} + networks: + - ssrf_proxy_network + - default + + # Certbot service + # use `docker-compose --profile certbot up` to start the certbot service. + certbot: + image: certbot/certbot + profiles: + - certbot + volumes: + - ./volumes/certbot/conf:/etc/letsencrypt + - ./volumes/certbot/www:/var/www/html + - ./volumes/certbot/logs:/var/log/letsencrypt + - ./volumes/certbot/conf/live:/etc/letsencrypt/live + - ./certbot/update-cert.template.txt:/update-cert.template.txt + - ./certbot/docker-entrypoint.sh:/docker-entrypoint.sh + environment: + - CERTBOT_EMAIL=${CERTBOT_EMAIL} + - CERTBOT_DOMAIN=${CERTBOT_DOMAIN} + - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} + entrypoint: ['/docker-entrypoint.sh'] + command: ['tail', '-f', '/dev/null'] + + # The nginx reverse proxy. + # used for reverse proxying the API service and Web service. + nginx: + image: nginx:latest + restart: always + volumes: + - ./nginx/nginx.conf.template:/etc/nginx/nginx.conf.template + - ./nginx/proxy.conf.template:/etc/nginx/proxy.conf.template + - ./nginx/https.conf.template:/etc/nginx/https.conf.template + - ./nginx/conf.d:/etc/nginx/conf.d + - ./nginx/docker-entrypoint.sh:/docker-entrypoint-mount.sh + - ./nginx/ssl:/etc/ssl # cert dir (legacy) + - ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container) + - ./volumes/certbot/conf:/etc/letsencrypt + - ./volumes/certbot/www:/var/www/html + entrypoint: + [ + 'sh', + '-c', + "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", + ] + environment: + NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_} + NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false} + NGINX_SSL_PORT: ${NGINX_SSL_PORT:-443} + NGINX_PORT: ${NGINX_PORT:-80} + # You're required to add your own SSL certificates/keys to the `./nginx/ssl` directory + # and modify the env vars below in .env if HTTPS_ENABLED is true. + NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt} + NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} + NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3} + NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} + NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-15M} + NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} + NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s} + NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s} + NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false} + CERTBOT_DOMAIN: ${CERTBOT_DOMAIN:-} + depends_on: + - api + - web + ports: + - '${EXPOSE_NGINX_PORT:-80}:${NGINX_PORT:-80}' + - '${EXPOSE_NGINX_SSL_PORT:-443}:${NGINX_SSL_PORT:-443}' + + # The TiDB vector store. + # For production use, please refer to https://github.com/pingcap/tidb-docker-compose + tidb: + image: pingcap/tidb:v8.4.0 + profiles: + - tidb + command: + - --store=unistore + restart: always + + # The Weaviate vector store. + weaviate: + image: semitechnologies/weaviate:1.19.0 + profiles: + - '' + - weaviate + restart: always + volumes: + # Mount the Weaviate data directory to the con tainer. + - ./volumes/weaviate:/var/lib/weaviate + environment: + # The Weaviate configurations + # You can refer to the [Weaviate](https://weaviate.io/developers/weaviate/config-refs/env-vars) documentation for more information. + PERSISTENCE_DATA_PATH: ${WEAVIATE_PERSISTENCE_DATA_PATH:-/var/lib/weaviate} + QUERY_DEFAULTS_LIMIT: ${WEAVIATE_QUERY_DEFAULTS_LIMIT:-25} + AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: ${WEAVIATE_AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED:-false} + DEFAULT_VECTORIZER_MODULE: ${WEAVIATE_DEFAULT_VECTORIZER_MODULE:-none} + CLUSTER_HOSTNAME: ${WEAVIATE_CLUSTER_HOSTNAME:-node1} + AUTHENTICATION_APIKEY_ENABLED: ${WEAVIATE_AUTHENTICATION_APIKEY_ENABLED:-true} + AUTHENTICATION_APIKEY_ALLOWED_KEYS: ${WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih} + AUTHENTICATION_APIKEY_USERS: ${WEAVIATE_AUTHENTICATION_APIKEY_USERS:-hello@dify.ai} + AUTHORIZATION_ADMINLIST_ENABLED: ${WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED:-true} + AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai} + + # Qdrant vector store. + # (if used, you need to set VECTOR_STORE to qdrant in the api & worker service.) + qdrant: + image: langgenius/qdrant:v1.7.3 + profiles: + - qdrant + restart: always + volumes: + - ./volumes/qdrant:/qdrant/storage + environment: + QDRANT_API_KEY: ${QDRANT_API_KEY:-difyai123456} + + # The Couchbase vector store. + couchbase-server: + build: ./couchbase-server + profiles: + - couchbase + restart: always + environment: + - CLUSTER_NAME=dify_search + - COUCHBASE_ADMINISTRATOR_USERNAME=${COUCHBASE_USER:-Administrator} + - COUCHBASE_ADMINISTRATOR_PASSWORD=${COUCHBASE_PASSWORD:-password} + - COUCHBASE_BUCKET=${COUCHBASE_BUCKET_NAME:-Embeddings} + - COUCHBASE_BUCKET_RAMSIZE=512 + - COUCHBASE_RAM_SIZE=2048 + - COUCHBASE_EVENTING_RAM_SIZE=512 + - COUCHBASE_INDEX_RAM_SIZE=512 + - COUCHBASE_FTS_RAM_SIZE=1024 + hostname: couchbase-server + container_name: couchbase-server + working_dir: /opt/couchbase + stdin_open: true + tty: true + entrypoint: [""] + command: sh -c "/opt/couchbase/init/init-cbserver.sh" + volumes: + - ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data + healthcheck: + # ensure bucket was created before proceeding + test: [ "CMD-SHELL", "curl -s -f -u Administrator:password http://localhost:8091/pools/default/buckets | grep -q '\\[{' || exit 1" ] + interval: 10s + retries: 10 + start_period: 30s + timeout: 10s + + # The pgvector vector database. + pgvector: + image: pgvector/pgvector:pg16 + profiles: + - pgvector + restart: always + environment: + PGUSER: ${PGVECTOR_PGUSER:-postgres} + # The password for the default postgres user. + POSTGRES_PASSWORD: ${PGVECTOR_POSTGRES_PASSWORD:-difyai123456} + # The name of the default postgres database. + POSTGRES_DB: ${PGVECTOR_POSTGRES_DB:-dify} + # postgres data directory + PGDATA: ${PGVECTOR_PGDATA:-/var/lib/postgresql/data/pgdata} + volumes: + - ./volumes/pgvector/data:/var/lib/postgresql/data + healthcheck: + test: ['CMD', 'pg_isready'] + interval: 1s + timeout: 3s + retries: 30 + + # pgvecto-rs vector store + pgvecto-rs: + image: tensorchord/pgvecto-rs:pg16-v0.3.0 + profiles: + - pgvecto-rs + restart: always + environment: + PGUSER: ${PGVECTOR_PGUSER:-postgres} + # The password for the default postgres user. + POSTGRES_PASSWORD: ${PGVECTOR_POSTGRES_PASSWORD:-difyai123456} + # The name of the default postgres database. + POSTGRES_DB: ${PGVECTOR_POSTGRES_DB:-dify} + # postgres data directory + PGDATA: ${PGVECTOR_PGDATA:-/var/lib/postgresql/data/pgdata} + volumes: + - ./volumes/pgvecto_rs/data:/var/lib/postgresql/data + healthcheck: + test: ['CMD', 'pg_isready'] + interval: 1s + timeout: 3s + retries: 30 + + # Chroma vector database + chroma: + image: ghcr.io/chroma-core/chroma:0.5.20 + profiles: + - chroma + restart: always + volumes: + - ./volumes/chroma:/chroma/chroma + environment: + CHROMA_SERVER_AUTHN_CREDENTIALS: ${CHROMA_SERVER_AUTHN_CREDENTIALS:-difyai123456} + CHROMA_SERVER_AUTHN_PROVIDER: ${CHROMA_SERVER_AUTHN_PROVIDER:-chromadb.auth.token_authn.TokenAuthenticationServerProvider} + IS_PERSISTENT: ${CHROMA_IS_PERSISTENT:-TRUE} + + # OceanBase vector database + oceanbase: + image: quay.io/oceanbase/oceanbase-ce:4.3.3.0-100000142024101215 + profiles: + - oceanbase + restart: always + volumes: + - ./volumes/oceanbase/data:/root/ob + - ./volumes/oceanbase/conf:/root/.obd/cluster + - ./volumes/oceanbase/init.d:/root/boot/init.d + environment: + OB_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G} + OB_SYS_PASSWORD: ${OCEANBASE_VECTOR_PASSWORD:-difyai123456} + OB_TENANT_PASSWORD: ${OCEANBASE_VECTOR_PASSWORD:-difyai123456} + OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} + OB_SERVER_IP: '127.0.0.1' + + # Oracle vector database + oracle: + image: container-registry.oracle.com/database/free:latest + profiles: + - oracle + restart: always + volumes: + - source: oradata + type: volume + target: /opt/oracle/oradata + - ./startupscripts:/opt/oracle/scripts/startup + environment: + ORACLE_PWD: ${ORACLE_PWD:-Dify123456} + ORACLE_CHARACTERSET: ${ORACLE_CHARACTERSET:-AL32UTF8} + + # Milvus vector database services + etcd: + container_name: milvus-etcd + image: quay.io/coreos/etcd:v3.5.5 + profiles: + - milvus + environment: + ETCD_AUTO_COMPACTION_MODE: ${ETCD_AUTO_COMPACTION_MODE:-revision} + ETCD_AUTO_COMPACTION_RETENTION: ${ETCD_AUTO_COMPACTION_RETENTION:-1000} + ETCD_QUOTA_BACKEND_BYTES: ${ETCD_QUOTA_BACKEND_BYTES:-4294967296} + ETCD_SNAPSHOT_COUNT: ${ETCD_SNAPSHOT_COUNT:-50000} + volumes: + - ./volumes/milvus/etcd:/etcd + command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd + healthcheck: + test: ['CMD', 'etcdctl', 'endpoint', 'health'] + interval: 30s + timeout: 20s + retries: 3 + networks: + - milvus + + minio: + container_name: milvus-minio + image: minio/minio:RELEASE.2023-03-20T20-16-18Z + profiles: + - milvus + environment: + MINIO_ACCESS_KEY: ${MINIO_ACCESS_KEY:-minioadmin} + MINIO_SECRET_KEY: ${MINIO_SECRET_KEY:-minioadmin} + volumes: + - ./volumes/milvus/minio:/minio_data + command: minio server /minio_data --console-address ":9001" + healthcheck: + test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live'] + interval: 30s + timeout: 20s + retries: 3 + networks: + - milvus + + milvus-standalone: + container_name: milvus-standalone + image: milvusdb/milvus:v2.3.1 + profiles: + - milvus + command: ['milvus', 'run', 'standalone'] + environment: + ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} + MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} + common.security.authorizationEnabled: ${MILVUS_AUTHORIZATION_ENABLED:-true} + volumes: + - ./volumes/milvus/milvus:/var/lib/milvus + healthcheck: + test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz'] + interval: 30s + start_period: 90s + timeout: 20s + retries: 3 + depends_on: + - etcd + - minio + ports: + - 19530:19530 + - 9091:9091 + networks: + - milvus + + # Opensearch vector database + opensearch: + container_name: opensearch + image: opensearchproject/opensearch:latest + profiles: + - opensearch + environment: + discovery.type: ${OPENSEARCH_DISCOVERY_TYPE:-single-node} + bootstrap.memory_lock: ${OPENSEARCH_BOOTSTRAP_MEMORY_LOCK:-true} + OPENSEARCH_JAVA_OPTS: -Xms${OPENSEARCH_JAVA_OPTS_MIN:-512m} -Xmx${OPENSEARCH_JAVA_OPTS_MAX:-1024m} + OPENSEARCH_INITIAL_ADMIN_PASSWORD: ${OPENSEARCH_INITIAL_ADMIN_PASSWORD:-Qazwsxedc!@#123} + ulimits: + memlock: + soft: ${OPENSEARCH_MEMLOCK_SOFT:--1} + hard: ${OPENSEARCH_MEMLOCK_HARD:--1} + nofile: + soft: ${OPENSEARCH_NOFILE_SOFT:-65536} + hard: ${OPENSEARCH_NOFILE_HARD:-65536} + volumes: + - ./volumes/opensearch/data:/usr/share/opensearch/data + networks: + - opensearch-net + + opensearch-dashboards: + container_name: opensearch-dashboards + image: opensearchproject/opensearch-dashboards:latest + profiles: + - opensearch + environment: + OPENSEARCH_HOSTS: '["https://opensearch:9200"]' + volumes: + - ./volumes/opensearch/opensearch_dashboards.yml:/usr/share/opensearch-dashboards/config/opensearch_dashboards.yml + networks: + - opensearch-net + depends_on: + - opensearch + + # MyScale vector database + myscale: + container_name: myscale + image: myscale/myscaledb:1.6.4 + profiles: + - myscale + restart: always + tty: true + volumes: + - ./volumes/myscale/data:/var/lib/clickhouse + - ./volumes/myscale/log:/var/log/clickhouse-server + - ./volumes/myscale/config/users.d/custom_users_config.xml:/etc/clickhouse-server/users.d/custom_users_config.xml + ports: + - ${MYSCALE_PORT:-8123}:${MYSCALE_PORT:-8123} + + # https://www.elastic.co/guide/en/elasticsearch/reference/current/settings.html + # https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-prod-prerequisites + elasticsearch: + image: docker.elastic.co/elasticsearch/elasticsearch:8.14.3 + container_name: elasticsearch + profiles: + - elasticsearch + restart: always + volumes: + - dify_es01_data:/usr/share/elasticsearch/data + environment: + ELASTIC_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} + cluster.name: dify-es-cluster + node.name: dify-es0 + discovery.type: single-node + xpack.license.self_generated.type: trial + xpack.security.enabled: 'true' + xpack.security.enrollment.enabled: 'false' + xpack.security.http.ssl.enabled: 'false' + ports: + - ${ELASTICSEARCH_PORT:-9200}:9200 + healthcheck: + test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty'] + interval: 30s + timeout: 10s + retries: 50 + + # https://www.elastic.co/guide/en/kibana/current/docker.html + # https://www.elastic.co/guide/en/kibana/current/settings.html + kibana: + image: docker.elastic.co/kibana/kibana:8.14.3 + container_name: kibana + profiles: + - elasticsearch + depends_on: + - elasticsearch + restart: always + environment: + XPACK_ENCRYPTEDSAVEDOBJECTS_ENCRYPTIONKEY: d1a66dfd-c4d3-4a0a-8290-2abcb83ab3aa + NO_PROXY: localhost,127.0.0.1,elasticsearch,kibana + XPACK_SECURITY_ENABLED: 'true' + XPACK_SECURITY_ENROLLMENT_ENABLED: 'false' + XPACK_SECURITY_HTTP_SSL_ENABLED: 'false' + XPACK_FLEET_ISAIRGAPPED: 'true' + I18N_LOCALE: zh-CN + SERVER_PORT: '5601' + ELASTICSEARCH_HOSTS: http://elasticsearch:9200 + ports: + - ${KIBANA_PORT:-5601}:5601 + healthcheck: + test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1'] + interval: 30s + timeout: 10s + retries: 3 + + # unstructured . + # (if used, you need to set ETL_TYPE to Unstructured in the api & worker service.) + unstructured: + image: downloads.unstructured.io/unstructured-io/unstructured-api:latest + profiles: + - unstructured + restart: always + volumes: + - ./volumes/unstructured:/app/data + +networks: + # create a network between sandbox, api and ssrf_proxy, and can not access outside. + ssrf_proxy_network: + driver: bridge + internal: true + milvus: + driver: bridge + opensearch-net: + driver: bridge + internal: true + +volumes: + oradata: + dify_es01_data: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 063813ad44ef1b..7122f4a6d0f768 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -1,28 +1,34 @@ +# ================================================================== +# WARNING: This file is auto-generated by generate_docker_compose +# Do not modify this file directly. Instead, update the .env.example +# or docker-compose-template.yaml and regenerate this file. +# ================================================================== + x-shared-env: &shared-api-worker-env - WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} + CONSOLE_API_URL: ${CONSOLE_API_URL:-} + CONSOLE_WEB_URL: ${CONSOLE_WEB_URL:-} + SERVICE_API_URL: ${SERVICE_API_URL:-} + APP_API_URL: ${APP_API_URL:-} + APP_WEB_URL: ${APP_WEB_URL:-} + FILES_URL: ${FILES_URL:-} LOG_LEVEL: ${LOG_LEVEL:-INFO} - LOG_FILE: ${LOG_FILE:-} + LOG_FILE: ${LOG_FILE:-/app/logs/server.log} LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20} LOG_FILE_BACKUP_COUNT: ${LOG_FILE_BACKUP_COUNT:-5} - # Log dateformat - LOG_DATEFORMAT: ${LOG_DATEFORMAT:-%Y-%m-%d %H:%M:%S} - # Log Timezone + LOG_DATEFORMAT: ${LOG_DATEFORMAT:-"%Y-%m-%d %H:%M:%S"} LOG_TZ: ${LOG_TZ:-UTC} DEBUG: ${DEBUG:-false} FLASK_DEBUG: ${FLASK_DEBUG:-false} SECRET_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U} INIT_PASSWORD: ${INIT_PASSWORD:-} - CONSOLE_WEB_URL: ${CONSOLE_WEB_URL:-} - CONSOLE_API_URL: ${CONSOLE_API_URL:-} - SERVICE_API_URL: ${SERVICE_API_URL:-} - APP_WEB_URL: ${APP_WEB_URL:-} - CHECK_UPDATE_URL: ${CHECK_UPDATE_URL:-https://updates.dify.ai} - OPENAI_API_BASE: ${OPENAI_API_BASE:-https://api.openai.com/v1} - FILES_URL: ${FILES_URL:-} + DEPLOY_ENV: ${DEPLOY_ENV:-PRODUCTION} + CHECK_UPDATE_URL: ${CHECK_UPDATE_URL:-"https://updates.dify.ai"} + OPENAI_API_BASE: ${OPENAI_API_BASE:-"https://api.openai.com/v1"} + MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true} FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300} + ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0} - MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true} - DEPLOY_ENV: ${DEPLOY_ENV:-PRODUCTION} + APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200} DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0} DIFY_PORT: ${DIFY_PORT:-5001} SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-} @@ -43,6 +49,11 @@ x-shared-env: &shared-api-worker-env SQLALCHEMY_POOL_SIZE: ${SQLALCHEMY_POOL_SIZE:-30} SQLALCHEMY_POOL_RECYCLE: ${SQLALCHEMY_POOL_RECYCLE:-3600} SQLALCHEMY_ECHO: ${SQLALCHEMY_ECHO:-false} + POSTGRES_MAX_CONNECTIONS: ${POSTGRES_MAX_CONNECTIONS:-100} + POSTGRES_SHARED_BUFFERS: ${POSTGRES_SHARED_BUFFERS:-128MB} + POSTGRES_WORK_MEM: ${POSTGRES_WORK_MEM:-4MB} + POSTGRES_MAINTENANCE_WORK_MEM: ${POSTGRES_MAINTENANCE_WORK_MEM:-64MB} + POSTGRES_EFFECTIVE_CACHE_SIZE: ${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB} REDIS_HOST: ${REDIS_HOST:-redis} REDIS_PORT: ${REDIS_PORT:-6379} REDIS_USERNAME: ${REDIS_USERNAME:-} @@ -55,75 +66,73 @@ x-shared-env: &shared-api-worker-env REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-} REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-} REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1} - REDIS_CLUSTERS: ${REDIS_CLUSTERS:-} REDIS_USE_CLUSTERS: ${REDIS_USE_CLUSTERS:-false} + REDIS_CLUSTERS: ${REDIS_CLUSTERS:-} REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-} - ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} - CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} + CELERY_BROKER_URL: ${CELERY_BROKER_URL:-"redis://:difyai123456@redis:6379/1"} BROKER_USE_SSL: ${BROKER_USE_SSL:-false} CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false} CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-} CELERY_SENTINEL_SOCKET_TIMEOUT: ${CELERY_SENTINEL_SOCKET_TIMEOUT:-0.1} WEB_API_CORS_ALLOW_ORIGINS: ${WEB_API_CORS_ALLOW_ORIGINS:-*} CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*} - STORAGE_TYPE: ${STORAGE_TYPE:-local} - STORAGE_LOCAL_PATH: ${STORAGE_LOCAL_PATH:-storage} - S3_USE_AWS_MANAGED_IAM: ${S3_USE_AWS_MANAGED_IAM:-false} + STORAGE_TYPE: ${STORAGE_TYPE:-opendal} + OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs} + OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage} S3_ENDPOINT: ${S3_ENDPOINT:-} - S3_BUCKET_NAME: ${S3_BUCKET_NAME:-} + S3_REGION: ${S3_REGION:-us-east-1} + S3_BUCKET_NAME: ${S3_BUCKET_NAME:-difyai} S3_ACCESS_KEY: ${S3_ACCESS_KEY:-} S3_SECRET_KEY: ${S3_SECRET_KEY:-} - S3_REGION: ${S3_REGION:-us-east-1} - AZURE_BLOB_ACCOUNT_NAME: ${AZURE_BLOB_ACCOUNT_NAME:-} - AZURE_BLOB_ACCOUNT_KEY: ${AZURE_BLOB_ACCOUNT_KEY:-} - AZURE_BLOB_CONTAINER_NAME: ${AZURE_BLOB_CONTAINER_NAME:-} - AZURE_BLOB_ACCOUNT_URL: ${AZURE_BLOB_ACCOUNT_URL:-} - GOOGLE_STORAGE_BUCKET_NAME: ${GOOGLE_STORAGE_BUCKET_NAME:-} - GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: ${GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64:-} - ALIYUN_OSS_BUCKET_NAME: ${ALIYUN_OSS_BUCKET_NAME:-} - ALIYUN_OSS_ACCESS_KEY: ${ALIYUN_OSS_ACCESS_KEY:-} - ALIYUN_OSS_SECRET_KEY: ${ALIYUN_OSS_SECRET_KEY:-} - ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-} - ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-} + S3_USE_AWS_MANAGED_IAM: ${S3_USE_AWS_MANAGED_IAM:-false} + AZURE_BLOB_ACCOUNT_NAME: ${AZURE_BLOB_ACCOUNT_NAME:-difyai} + AZURE_BLOB_ACCOUNT_KEY: ${AZURE_BLOB_ACCOUNT_KEY:-difyai} + AZURE_BLOB_CONTAINER_NAME: ${AZURE_BLOB_CONTAINER_NAME:-difyai-container} + AZURE_BLOB_ACCOUNT_URL: ${AZURE_BLOB_ACCOUNT_URL:-"https://.blob.core.windows.net"} + GOOGLE_STORAGE_BUCKET_NAME: ${GOOGLE_STORAGE_BUCKET_NAME:-your-bucket-name} + GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: ${GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64:-your-google-service-account-json-base64-string} + ALIYUN_OSS_BUCKET_NAME: ${ALIYUN_OSS_BUCKET_NAME:-your-bucket-name} + ALIYUN_OSS_ACCESS_KEY: ${ALIYUN_OSS_ACCESS_KEY:-your-access-key} + ALIYUN_OSS_SECRET_KEY: ${ALIYUN_OSS_SECRET_KEY:-your-secret-key} + ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-"https://oss-ap-southeast-1-internal.aliyuncs.com"} + ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-ap-southeast-1} ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4} - ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-} - TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-} - TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-} - TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-} - TENCENT_COS_REGION: ${TENCENT_COS_REGION:-} - TENCENT_COS_SCHEME: ${TENCENT_COS_SCHEME:-} - HUAWEI_OBS_BUCKET_NAME: ${HUAWEI_OBS_BUCKET_NAME:-} - HUAWEI_OBS_SECRET_KEY: ${HUAWEI_OBS_SECRET_KEY:-} - HUAWEI_OBS_ACCESS_KEY: ${HUAWEI_OBS_ACCESS_KEY:-} - HUAWEI_OBS_SERVER: ${HUAWEI_OBS_SERVER:-} - OCI_ENDPOINT: ${OCI_ENDPOINT:-} - OCI_BUCKET_NAME: ${OCI_BUCKET_NAME:-} - OCI_ACCESS_KEY: ${OCI_ACCESS_KEY:-} - OCI_SECRET_KEY: ${OCI_SECRET_KEY:-} - OCI_REGION: ${OCI_REGION:-} - VOLCENGINE_TOS_BUCKET_NAME: ${VOLCENGINE_TOS_BUCKET_NAME:-} - VOLCENGINE_TOS_SECRET_KEY: ${VOLCENGINE_TOS_SECRET_KEY:-} - VOLCENGINE_TOS_ACCESS_KEY: ${VOLCENGINE_TOS_ACCESS_KEY:-} - VOLCENGINE_TOS_ENDPOINT: ${VOLCENGINE_TOS_ENDPOINT:-} - VOLCENGINE_TOS_REGION: ${VOLCENGINE_TOS_REGION:-} - BAIDU_OBS_BUCKET_NAME: ${BAIDU_OBS_BUCKET_NAME:-} - BAIDU_OBS_SECRET_KEY: ${BAIDU_OBS_SECRET_KEY:-} - BAIDU_OBS_ACCESS_KEY: ${BAIDU_OBS_ACCESS_KEY:-} - BAIDU_OBS_ENDPOINT: ${BAIDU_OBS_ENDPOINT:-} + ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-your-path} + TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-your-bucket-name} + TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-your-secret-key} + TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id} + TENCENT_COS_REGION: ${TENCENT_COS_REGION:-your-region} + TENCENT_COS_SCHEME: ${TENCENT_COS_SCHEME:-your-scheme} + OCI_ENDPOINT: ${OCI_ENDPOINT:-"https://objectstorage.us-ashburn-1.oraclecloud.com"} + OCI_BUCKET_NAME: ${OCI_BUCKET_NAME:-your-bucket-name} + OCI_ACCESS_KEY: ${OCI_ACCESS_KEY:-your-access-key} + OCI_SECRET_KEY: ${OCI_SECRET_KEY:-your-secret-key} + OCI_REGION: ${OCI_REGION:-us-ashburn-1} + HUAWEI_OBS_BUCKET_NAME: ${HUAWEI_OBS_BUCKET_NAME:-your-bucket-name} + HUAWEI_OBS_SECRET_KEY: ${HUAWEI_OBS_SECRET_KEY:-your-secret-key} + HUAWEI_OBS_ACCESS_KEY: ${HUAWEI_OBS_ACCESS_KEY:-your-access-key} + HUAWEI_OBS_SERVER: ${HUAWEI_OBS_SERVER:-your-server-url} + VOLCENGINE_TOS_BUCKET_NAME: ${VOLCENGINE_TOS_BUCKET_NAME:-your-bucket-name} + VOLCENGINE_TOS_SECRET_KEY: ${VOLCENGINE_TOS_SECRET_KEY:-your-secret-key} + VOLCENGINE_TOS_ACCESS_KEY: ${VOLCENGINE_TOS_ACCESS_KEY:-your-access-key} + VOLCENGINE_TOS_ENDPOINT: ${VOLCENGINE_TOS_ENDPOINT:-your-server-url} + VOLCENGINE_TOS_REGION: ${VOLCENGINE_TOS_REGION:-your-region} + BAIDU_OBS_BUCKET_NAME: ${BAIDU_OBS_BUCKET_NAME:-your-bucket-name} + BAIDU_OBS_SECRET_KEY: ${BAIDU_OBS_SECRET_KEY:-your-secret-key} + BAIDU_OBS_ACCESS_KEY: ${BAIDU_OBS_ACCESS_KEY:-your-access-key} + BAIDU_OBS_ENDPOINT: ${BAIDU_OBS_ENDPOINT:-your-server-url} + SUPABASE_BUCKET_NAME: ${SUPABASE_BUCKET_NAME:-your-bucket-name} + SUPABASE_API_KEY: ${SUPABASE_API_KEY:-your-access-key} + SUPABASE_URL: ${SUPABASE_URL:-your-server-url} VECTOR_STORE: ${VECTOR_STORE:-weaviate} - WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-http://weaviate:8080} + WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-"http://weaviate:8080"} WEAVIATE_API_KEY: ${WEAVIATE_API_KEY:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih} - QDRANT_URL: ${QDRANT_URL:-http://qdrant:6333} + QDRANT_URL: ${QDRANT_URL:-"http://qdrant:6333"} QDRANT_API_KEY: ${QDRANT_API_KEY:-difyai123456} QDRANT_CLIENT_TIMEOUT: ${QDRANT_CLIENT_TIMEOUT:-20} QDRANT_GRPC_ENABLED: ${QDRANT_GRPC_ENABLED:-false} QDRANT_GRPC_PORT: ${QDRANT_GRPC_PORT:-6334} - COUCHBASE_CONNECTION_STRING: ${COUCHBASE_CONNECTION_STRING:-'couchbase-server'} - COUCHBASE_USER: ${COUCHBASE_USER:-Administrator} - COUCHBASE_PASSWORD: ${COUCHBASE_PASSWORD:-password} - COUCHBASE_BUCKET_NAME: ${COUCHBASE_BUCKET_NAME:-Embeddings} - COUCHBASE_SCOPE_NAME: ${COUCHBASE_SCOPE_NAME:-_default} - MILVUS_URI: ${MILVUS_URI:-http://127.0.0.1:19530} + MILVUS_URI: ${MILVUS_URI:-"http://127.0.0.1:19530"} MILVUS_TOKEN: ${MILVUS_TOKEN:-} MILVUS_USER: ${MILVUS_USER:-root} MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus} @@ -133,172 +142,264 @@ x-shared-env: &shared-api-worker-env MYSCALE_PASSWORD: ${MYSCALE_PASSWORD:-} MYSCALE_DATABASE: ${MYSCALE_DATABASE:-dify} MYSCALE_FTS_PARAMS: ${MYSCALE_FTS_PARAMS:-} - RELYT_HOST: ${RELYT_HOST:-db} - RELYT_PORT: ${RELYT_PORT:-5432} - RELYT_USER: ${RELYT_USER:-postgres} - RELYT_PASSWORD: ${RELYT_PASSWORD:-difyai123456} - RELYT_DATABASE: ${RELYT_DATABASE:-postgres} + COUCHBASE_CONNECTION_STRING: ${COUCHBASE_CONNECTION_STRING:-"couchbase://couchbase-server"} + COUCHBASE_USER: ${COUCHBASE_USER:-Administrator} + COUCHBASE_PASSWORD: ${COUCHBASE_PASSWORD:-password} + COUCHBASE_BUCKET_NAME: ${COUCHBASE_BUCKET_NAME:-Embeddings} + COUCHBASE_SCOPE_NAME: ${COUCHBASE_SCOPE_NAME:-_default} PGVECTOR_HOST: ${PGVECTOR_HOST:-pgvector} PGVECTOR_PORT: ${PGVECTOR_PORT:-5432} PGVECTOR_USER: ${PGVECTOR_USER:-postgres} PGVECTOR_PASSWORD: ${PGVECTOR_PASSWORD:-difyai123456} PGVECTOR_DATABASE: ${PGVECTOR_DATABASE:-dify} + PGVECTOR_MIN_CONNECTION: ${PGVECTOR_MIN_CONNECTION:-1} + PGVECTOR_MAX_CONNECTION: ${PGVECTOR_MAX_CONNECTION:-5} + PGVECTO_RS_HOST: ${PGVECTO_RS_HOST:-pgvecto-rs} + PGVECTO_RS_PORT: ${PGVECTO_RS_PORT:-5432} + PGVECTO_RS_USER: ${PGVECTO_RS_USER:-postgres} + PGVECTO_RS_PASSWORD: ${PGVECTO_RS_PASSWORD:-difyai123456} + PGVECTO_RS_DATABASE: ${PGVECTO_RS_DATABASE:-dify} + ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-your-ak} + ANALYTICDB_KEY_SECRET: ${ANALYTICDB_KEY_SECRET:-your-sk} + ANALYTICDB_REGION_ID: ${ANALYTICDB_REGION_ID:-cn-hangzhou} + ANALYTICDB_INSTANCE_ID: ${ANALYTICDB_INSTANCE_ID:-gp-ab123456} + ANALYTICDB_ACCOUNT: ${ANALYTICDB_ACCOUNT:-testaccount} + ANALYTICDB_PASSWORD: ${ANALYTICDB_PASSWORD:-testpassword} + ANALYTICDB_NAMESPACE: ${ANALYTICDB_NAMESPACE:-dify} + ANALYTICDB_NAMESPACE_PASSWORD: ${ANALYTICDB_NAMESPACE_PASSWORD:-difypassword} + ANALYTICDB_HOST: ${ANALYTICDB_HOST:-gp-test.aliyuncs.com} + ANALYTICDB_PORT: ${ANALYTICDB_PORT:-5432} + ANALYTICDB_MIN_CONNECTION: ${ANALYTICDB_MIN_CONNECTION:-1} + ANALYTICDB_MAX_CONNECTION: ${ANALYTICDB_MAX_CONNECTION:-5} TIDB_VECTOR_HOST: ${TIDB_VECTOR_HOST:-tidb} TIDB_VECTOR_PORT: ${TIDB_VECTOR_PORT:-4000} TIDB_VECTOR_USER: ${TIDB_VECTOR_USER:-} TIDB_VECTOR_PASSWORD: ${TIDB_VECTOR_PASSWORD:-} TIDB_VECTOR_DATABASE: ${TIDB_VECTOR_DATABASE:-dify} - TIDB_ON_QDRANT_URL: ${TIDB_ON_QDRANT_URL:-http://127.0.0.1} + TIDB_ON_QDRANT_URL: ${TIDB_ON_QDRANT_URL:-"http://127.0.0.1"} TIDB_ON_QDRANT_API_KEY: ${TIDB_ON_QDRANT_API_KEY:-dify} TIDB_ON_QDRANT_CLIENT_TIMEOUT: ${TIDB_ON_QDRANT_CLIENT_TIMEOUT:-20} TIDB_ON_QDRANT_GRPC_ENABLED: ${TIDB_ON_QDRANT_GRPC_ENABLED:-false} TIDB_ON_QDRANT_GRPC_PORT: ${TIDB_ON_QDRANT_GRPC_PORT:-6334} TIDB_PUBLIC_KEY: ${TIDB_PUBLIC_KEY:-dify} TIDB_PRIVATE_KEY: ${TIDB_PRIVATE_KEY:-dify} - TIDB_API_URL: ${TIDB_API_URL:-http://127.0.0.1} - TIDB_IAM_API_URL: ${TIDB_IAM_API_URL:-http://127.0.0.1} + TIDB_API_URL: ${TIDB_API_URL:-"http://127.0.0.1"} + TIDB_IAM_API_URL: ${TIDB_IAM_API_URL:-"http://127.0.0.1"} TIDB_REGION: ${TIDB_REGION:-regions/aws-us-east-1} TIDB_PROJECT_ID: ${TIDB_PROJECT_ID:-dify} TIDB_SPEND_LIMIT: ${TIDB_SPEND_LIMIT:-100} - ORACLE_HOST: ${ORACLE_HOST:-oracle} - ORACLE_PORT: ${ORACLE_PORT:-1521} - ORACLE_USER: ${ORACLE_USER:-dify} - ORACLE_PASSWORD: ${ORACLE_PASSWORD:-dify} - ORACLE_DATABASE: ${ORACLE_DATABASE:-FREEPDB1} CHROMA_HOST: ${CHROMA_HOST:-127.0.0.1} CHROMA_PORT: ${CHROMA_PORT:-8000} CHROMA_TENANT: ${CHROMA_TENANT:-default_tenant} CHROMA_DATABASE: ${CHROMA_DATABASE:-default_database} CHROMA_AUTH_PROVIDER: ${CHROMA_AUTH_PROVIDER:-chromadb.auth.token_authn.TokenAuthClientProvider} CHROMA_AUTH_CREDENTIALS: ${CHROMA_AUTH_CREDENTIALS:-} - ELASTICSEARCH_HOST: ${ELASTICSEARCH_HOST:-0.0.0.0} - ELASTICSEARCH_PORT: ${ELASTICSEARCH_PORT:-9200} - ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic} - ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} - LINDORM_URL: ${LINDORM_URL:-http://lindorm:30070} - LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm} - LINDORM_PASSWORD: ${LINDORM_PASSWORD:-lindorm } - KIBANA_PORT: ${KIBANA_PORT:-5601} - # AnalyticDB configuration - ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-} - ANALYTICDB_KEY_SECRET: ${ANALYTICDB_KEY_SECRET:-} - ANALYTICDB_REGION_ID: ${ANALYTICDB_REGION_ID:-} - ANALYTICDB_INSTANCE_ID: ${ANALYTICDB_INSTANCE_ID:-} - ANALYTICDB_ACCOUNT: ${ANALYTICDB_ACCOUNT:-} - ANALYTICDB_PASSWORD: ${ANALYTICDB_PASSWORD:-} - ANALYTICDB_NAMESPACE: ${ANALYTICDB_NAMESPACE:-dify} - ANALYTICDB_NAMESPACE_PASSWORD: ${ANALYTICDB_NAMESPACE_PASSWORD:-} - ANALYTICDB_HOST: ${ANALYTICDB_HOST:-} - ANALYTICDB_PORT: ${ANALYTICDB_PORT:-5432} - ANALYTICDB_MIN_CONNECTION: ${ANALYTICDB_MIN_CONNECTION:-1} - ANALYTICDB_MAX_CONNECTION: ${ANALYTICDB_MAX_CONNECTION:-5} + ORACLE_HOST: ${ORACLE_HOST:-oracle} + ORACLE_PORT: ${ORACLE_PORT:-1521} + ORACLE_USER: ${ORACLE_USER:-dify} + ORACLE_PASSWORD: ${ORACLE_PASSWORD:-dify} + ORACLE_DATABASE: ${ORACLE_DATABASE:-FREEPDB1} + RELYT_HOST: ${RELYT_HOST:-db} + RELYT_PORT: ${RELYT_PORT:-5432} + RELYT_USER: ${RELYT_USER:-postgres} + RELYT_PASSWORD: ${RELYT_PASSWORD:-difyai123456} + RELYT_DATABASE: ${RELYT_DATABASE:-postgres} OPENSEARCH_HOST: ${OPENSEARCH_HOST:-opensearch} OPENSEARCH_PORT: ${OPENSEARCH_PORT:-9200} OPENSEARCH_USER: ${OPENSEARCH_USER:-admin} OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin} OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true} - TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1} + TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-"http://127.0.0.1"} TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify} TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30} TENCENT_VECTOR_DB_USERNAME: ${TENCENT_VECTOR_DB_USERNAME:-dify} TENCENT_VECTOR_DB_DATABASE: ${TENCENT_VECTOR_DB_DATABASE:-dify} TENCENT_VECTOR_DB_SHARD: ${TENCENT_VECTOR_DB_SHARD:-1} TENCENT_VECTOR_DB_REPLICAS: ${TENCENT_VECTOR_DB_REPLICAS:-2} - BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-http://127.0.0.1:5287} + ELASTICSEARCH_HOST: ${ELASTICSEARCH_HOST:-0.0.0.0} + ELASTICSEARCH_PORT: ${ELASTICSEARCH_PORT:-9200} + ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic} + ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} + KIBANA_PORT: ${KIBANA_PORT:-5601} + BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-"http://127.0.0.1:5287"} BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: ${BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS:-30000} BAIDU_VECTOR_DB_ACCOUNT: ${BAIDU_VECTOR_DB_ACCOUNT:-root} BAIDU_VECTOR_DB_API_KEY: ${BAIDU_VECTOR_DB_API_KEY:-dify} BAIDU_VECTOR_DB_DATABASE: ${BAIDU_VECTOR_DB_DATABASE:-dify} BAIDU_VECTOR_DB_SHARD: ${BAIDU_VECTOR_DB_SHARD:-1} BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3} - VIKINGDB_ACCESS_KEY: ${VIKINGDB_ACCESS_KEY:-dify} - VIKINGDB_SECRET_KEY: ${VIKINGDB_SECRET_KEY:-dify} + VIKINGDB_ACCESS_KEY: ${VIKINGDB_ACCESS_KEY:-your-ak} + VIKINGDB_SECRET_KEY: ${VIKINGDB_SECRET_KEY:-your-sk} VIKINGDB_REGION: ${VIKINGDB_REGION:-cn-shanghai} VIKINGDB_HOST: ${VIKINGDB_HOST:-api-vikingdb.xxx.volces.com} VIKINGDB_SCHEMA: ${VIKINGDB_SCHEMA:-http} - UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-https://xxx-vector.upstash.io} + VIKINGDB_CONNECTION_TIMEOUT: ${VIKINGDB_CONNECTION_TIMEOUT:-30} + VIKINGDB_SOCKET_TIMEOUT: ${VIKINGDB_SOCKET_TIMEOUT:-30} + LINDORM_URL: ${LINDORM_URL:-"http://lindorm:30070"} + LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm} + LINDORM_PASSWORD: ${LINDORM_PASSWORD:-lindorm} + OCEANBASE_VECTOR_HOST: ${OCEANBASE_VECTOR_HOST:-oceanbase} + OCEANBASE_VECTOR_PORT: ${OCEANBASE_VECTOR_PORT:-2881} + OCEANBASE_VECTOR_USER: ${OCEANBASE_VECTOR_USER:-root@test} + OCEANBASE_VECTOR_PASSWORD: ${OCEANBASE_VECTOR_PASSWORD:-difyai123456} + OCEANBASE_VECTOR_DATABASE: ${OCEANBASE_VECTOR_DATABASE:-test} + OCEANBASE_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} + OCEANBASE_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G} + UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-"https://xxx-vector.upstash.io"} UPSTASH_VECTOR_TOKEN: ${UPSTASH_VECTOR_TOKEN:-dify} UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} ETL_TYPE: ${ETL_TYPE:-dify} UNSTRUCTURED_API_URL: ${UNSTRUCTURED_API_URL:-} UNSTRUCTURED_API_KEY: ${UNSTRUCTURED_API_KEY:-} + SCARF_NO_ANALYTICS: ${SCARF_NO_ANALYTICS:-true} PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512} CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024} - MULTIMODAL_SEND_IMAGE_FORMAT: ${MULTIMODAL_SEND_IMAGE_FORMAT:-base64} - MULTIMODAL_SEND_VIDEO_FORMAT: ${MULTIMODAL_SEND_VIDEO_FORMAT:-base64} + MULTIMODAL_SEND_FORMAT: ${MULTIMODAL_SEND_FORMAT:-base64} UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10} UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100} UPLOAD_AUDIO_FILE_SIZE_LIMIT: ${UPLOAD_AUDIO_FILE_SIZE_LIMIT:-50} - SENTRY_DSN: ${API_SENTRY_DSN:-} - SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0} - SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0} + SENTRY_DSN: ${SENTRY_DSN:-} + API_SENTRY_DSN: ${API_SENTRY_DSN:-} + API_SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0} + API_SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0} + WEB_SENTRY_DSN: ${WEB_SENTRY_DSN:-} NOTION_INTEGRATION_TYPE: ${NOTION_INTEGRATION_TYPE:-public} NOTION_CLIENT_SECRET: ${NOTION_CLIENT_SECRET:-} NOTION_CLIENT_ID: ${NOTION_CLIENT_ID:-} NOTION_INTERNAL_SECRET: ${NOTION_INTERNAL_SECRET:-} MAIL_TYPE: ${MAIL_TYPE:-resend} MAIL_DEFAULT_SEND_FROM: ${MAIL_DEFAULT_SEND_FROM:-} + RESEND_API_URL: ${RESEND_API_URL:-"https://api.resend.com"} + RESEND_API_KEY: ${RESEND_API_KEY:-your-resend-api-key} SMTP_SERVER: ${SMTP_SERVER:-} SMTP_PORT: ${SMTP_PORT:-465} SMTP_USERNAME: ${SMTP_USERNAME:-} SMTP_PASSWORD: ${SMTP_PASSWORD:-} SMTP_USE_TLS: ${SMTP_USE_TLS:-true} SMTP_OPPORTUNISTIC_TLS: ${SMTP_OPPORTUNISTIC_TLS:-false} - RESEND_API_KEY: ${RESEND_API_KEY:-your-resend-api-key} - RESEND_API_URL: ${RESEND_API_URL:-https://api.resend.com} INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000} INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72} RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5} - CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-http://sandbox:8194} - CODE_EXECUTION_API_KEY: ${SANDBOX_API_KEY:-dify-sandbox} - CODE_EXECUTION_CONNECT_TIMEOUT: ${CODE_EXECUTION_CONNECT_TIMEOUT:-10} - CODE_EXECUTION_READ_TIMEOUT: ${CODE_EXECUTION_READ_TIMEOUT:-60} - CODE_EXECUTION_WRITE_TIMEOUT: ${CODE_EXECUTION_WRITE_TIMEOUT:-10} + CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-"http://sandbox:8194"} + CODE_EXECUTION_API_KEY: ${CODE_EXECUTION_API_KEY:-dify-sandbox} CODE_MAX_NUMBER: ${CODE_MAX_NUMBER:-9223372036854775807} CODE_MIN_NUMBER: ${CODE_MIN_NUMBER:--9223372036854775808} CODE_MAX_DEPTH: ${CODE_MAX_DEPTH:-5} CODE_MAX_PRECISION: ${CODE_MAX_PRECISION:-20} CODE_MAX_STRING_LENGTH: ${CODE_MAX_STRING_LENGTH:-80000} - TEMPLATE_TRANSFORM_MAX_LENGTH: ${TEMPLATE_TRANSFORM_MAX_LENGTH:-80000} CODE_MAX_STRING_ARRAY_LENGTH: ${CODE_MAX_STRING_ARRAY_LENGTH:-30} CODE_MAX_OBJECT_ARRAY_LENGTH: ${CODE_MAX_OBJECT_ARRAY_LENGTH:-30} CODE_MAX_NUMBER_ARRAY_LENGTH: ${CODE_MAX_NUMBER_ARRAY_LENGTH:-1000} + CODE_EXECUTION_CONNECT_TIMEOUT: ${CODE_EXECUTION_CONNECT_TIMEOUT:-10} + CODE_EXECUTION_READ_TIMEOUT: ${CODE_EXECUTION_READ_TIMEOUT:-60} + CODE_EXECUTION_WRITE_TIMEOUT: ${CODE_EXECUTION_WRITE_TIMEOUT:-10} + TEMPLATE_TRANSFORM_MAX_LENGTH: ${TEMPLATE_TRANSFORM_MAX_LENGTH:-80000} WORKFLOW_MAX_EXECUTION_STEPS: ${WORKFLOW_MAX_EXECUTION_STEPS:-500} WORKFLOW_MAX_EXECUTION_TIME: ${WORKFLOW_MAX_EXECUTION_TIME:-1200} WORKFLOW_CALL_MAX_DEPTH: ${WORKFLOW_CALL_MAX_DEPTH:-5} - SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-http://ssrf_proxy:3128} - SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-http://ssrf_proxy:3128} + MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800} + WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3} + WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} - APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-12000} + SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-"http://ssrf_proxy:3128"} + SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-"http://ssrf_proxy:3128"} + TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} + PGUSER: ${PGUSER:-${DB_USERNAME}} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}} + POSTGRES_DB: ${POSTGRES_DB:-${DB_DATABASE}} + PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} + SANDBOX_API_KEY: ${SANDBOX_API_KEY:-dify-sandbox} + SANDBOX_GIN_MODE: ${SANDBOX_GIN_MODE:-release} + SANDBOX_WORKER_TIMEOUT: ${SANDBOX_WORKER_TIMEOUT:-15} + SANDBOX_ENABLE_NETWORK: ${SANDBOX_ENABLE_NETWORK:-true} + SANDBOX_HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-"http://ssrf_proxy:3128"} + SANDBOX_HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-"http://ssrf_proxy:3128"} + SANDBOX_PORT: ${SANDBOX_PORT:-8194} + WEAVIATE_PERSISTENCE_DATA_PATH: ${WEAVIATE_PERSISTENCE_DATA_PATH:-/var/lib/weaviate} + WEAVIATE_QUERY_DEFAULTS_LIMIT: ${WEAVIATE_QUERY_DEFAULTS_LIMIT:-25} + WEAVIATE_AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: ${WEAVIATE_AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED:-true} + WEAVIATE_DEFAULT_VECTORIZER_MODULE: ${WEAVIATE_DEFAULT_VECTORIZER_MODULE:-none} + WEAVIATE_CLUSTER_HOSTNAME: ${WEAVIATE_CLUSTER_HOSTNAME:-node1} + WEAVIATE_AUTHENTICATION_APIKEY_ENABLED: ${WEAVIATE_AUTHENTICATION_APIKEY_ENABLED:-true} + WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS: ${WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih} + WEAVIATE_AUTHENTICATION_APIKEY_USERS: ${WEAVIATE_AUTHENTICATION_APIKEY_USERS:-hello@dify.ai} + WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED: ${WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED:-true} + WEAVIATE_AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai} + CHROMA_SERVER_AUTHN_CREDENTIALS: ${CHROMA_SERVER_AUTHN_CREDENTIALS:-difyai123456} + CHROMA_SERVER_AUTHN_PROVIDER: ${CHROMA_SERVER_AUTHN_PROVIDER:-chromadb.auth.token_authn.TokenAuthenticationServerProvider} + CHROMA_IS_PERSISTENT: ${CHROMA_IS_PERSISTENT:-TRUE} + ORACLE_PWD: ${ORACLE_PWD:-Dify123456} + ORACLE_CHARACTERSET: ${ORACLE_CHARACTERSET:-AL32UTF8} + ETCD_AUTO_COMPACTION_MODE: ${ETCD_AUTO_COMPACTION_MODE:-revision} + ETCD_AUTO_COMPACTION_RETENTION: ${ETCD_AUTO_COMPACTION_RETENTION:-1000} + ETCD_QUOTA_BACKEND_BYTES: ${ETCD_QUOTA_BACKEND_BYTES:-4294967296} + ETCD_SNAPSHOT_COUNT: ${ETCD_SNAPSHOT_COUNT:-50000} + MINIO_ACCESS_KEY: ${MINIO_ACCESS_KEY:-minioadmin} + MINIO_SECRET_KEY: ${MINIO_SECRET_KEY:-minioadmin} + ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-"etcd:2379"} + MINIO_ADDRESS: ${MINIO_ADDRESS:-"minio:9000"} + MILVUS_AUTHORIZATION_ENABLED: ${MILVUS_AUTHORIZATION_ENABLED:-true} + PGVECTOR_PGUSER: ${PGVECTOR_PGUSER:-postgres} + PGVECTOR_POSTGRES_PASSWORD: ${PGVECTOR_POSTGRES_PASSWORD:-difyai123456} + PGVECTOR_POSTGRES_DB: ${PGVECTOR_POSTGRES_DB:-dify} + PGVECTOR_PGDATA: ${PGVECTOR_PGDATA:-/var/lib/postgresql/data/pgdata} + OPENSEARCH_DISCOVERY_TYPE: ${OPENSEARCH_DISCOVERY_TYPE:-single-node} + OPENSEARCH_BOOTSTRAP_MEMORY_LOCK: ${OPENSEARCH_BOOTSTRAP_MEMORY_LOCK:-true} + OPENSEARCH_JAVA_OPTS_MIN: ${OPENSEARCH_JAVA_OPTS_MIN:-512m} + OPENSEARCH_JAVA_OPTS_MAX: ${OPENSEARCH_JAVA_OPTS_MAX:-1024m} + OPENSEARCH_INITIAL_ADMIN_PASSWORD: ${OPENSEARCH_INITIAL_ADMIN_PASSWORD:-Qazwsxedc!@#123} + OPENSEARCH_MEMLOCK_SOFT: ${OPENSEARCH_MEMLOCK_SOFT:--1} + OPENSEARCH_MEMLOCK_HARD: ${OPENSEARCH_MEMLOCK_HARD:--1} + OPENSEARCH_NOFILE_SOFT: ${OPENSEARCH_NOFILE_SOFT:-65536} + OPENSEARCH_NOFILE_HARD: ${OPENSEARCH_NOFILE_HARD:-65536} + NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_} + NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false} + NGINX_PORT: ${NGINX_PORT:-80} + NGINX_SSL_PORT: ${NGINX_SSL_PORT:-443} + NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt} + NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} + NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-"TLSv1.1 TLSv1.2 TLSv1.3"} + NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} + NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-15M} + NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} + NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s} + NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s} + NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false} + CERTBOT_EMAIL: ${CERTBOT_EMAIL:-your_email@example.com} + CERTBOT_DOMAIN: ${CERTBOT_DOMAIN:-your_domain.com} + CERTBOT_OPTIONS: ${CERTBOT_OPTIONS:-} + SSRF_HTTP_PORT: ${SSRF_HTTP_PORT:-3128} + SSRF_COREDUMP_DIR: ${SSRF_COREDUMP_DIR:-/var/spool/squid} + SSRF_REVERSE_PROXY_PORT: ${SSRF_REVERSE_PROXY_PORT:-8194} + SSRF_SANDBOX_HOST: ${SSRF_SANDBOX_HOST:-sandbox} + COMPOSE_PROFILES: ${COMPOSE_PROFILES:-"${VECTOR_STORE:-weaviate}"} + EXPOSE_NGINX_PORT: ${EXPOSE_NGINX_PORT:-80} + EXPOSE_NGINX_SSL_PORT: ${EXPOSE_NGINX_SSL_PORT:-443} POSITION_TOOL_PINS: ${POSITION_TOOL_PINS:-} POSITION_TOOL_INCLUDES: ${POSITION_TOOL_INCLUDES:-} POSITION_TOOL_EXCLUDES: ${POSITION_TOOL_EXCLUDES:-} POSITION_PROVIDER_PINS: ${POSITION_PROVIDER_PINS:-} POSITION_PROVIDER_INCLUDES: ${POSITION_PROVIDER_INCLUDES:-} POSITION_PROVIDER_EXCLUDES: ${POSITION_PROVIDER_EXCLUDES:-} - MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800} - OCEANBASE_VECTOR_HOST: ${OCEANBASE_VECTOR_HOST:-http://oceanbase-vector} - OCEANBASE_VECTOR_PORT: ${OCEANBASE_VECTOR_PORT:-2881} - OCEANBASE_VECTOR_USER: ${OCEANBASE_VECTOR_USER:-root@test} - OCEANBASE_VECTOR_PASSWORD: ${OCEANBASE_VECTOR_PASSWORD:-difyai123456} - OCEANBASE_VECTOR_DATABASE: ${OCEANBASE_VECTOR_DATABASE:-test} - OCEANBASE_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} - OCEANBASE_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G} + CSP_WHITELIST: ${CSP_WHITELIST:-} CREATE_TIDB_SERVICE_JOB_ENABLED: ${CREATE_TIDB_SERVICE_JOB_ENABLED:-false} - RETRIEVAL_TOP_N: ${RETRIEVAL_TOP_N:-0} + MAX_SUBMIT_COUNT: ${MAX_SUBMIT_COUNT:-100} services: # API service api: - image: langgenius/dify-api:0.13.1 + image: langgenius/dify-api:0.14.2 restart: always environment: # Use the shared environment variables. <<: *shared-api-worker-env # Startup mode, 'api' starts the API server. MODE: api + SENTRY_DSN: ${API_SENTRY_DSN:-} + SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0} + SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0} depends_on: - db - redis @@ -312,13 +413,16 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.13.1 + image: langgenius/dify-api:0.14.2 restart: always environment: # Use the shared environment variables. <<: *shared-api-worker-env # Startup mode, 'worker' starts the Celery worker for processing the queue. MODE: worker + SENTRY_DSN: ${API_SENTRY_DSN:-} + SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0} + SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0} depends_on: - db - redis @@ -331,7 +435,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.13.1 + image: langgenius/dify-web:0.14.2 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -491,6 +595,16 @@ services: - '${EXPOSE_NGINX_PORT:-80}:${NGINX_PORT:-80}' - '${EXPOSE_NGINX_SSL_PORT:-443}:${NGINX_SSL_PORT:-443}' + # The TiDB vector store. + # For production use, please refer to https://github.com/pingcap/tidb-docker-compose + tidb: + image: pingcap/tidb:v8.4.0 + profiles: + - tidb + command: + - --store=unistore + restart: always + # The Weaviate vector store. weaviate: image: semitechnologies/weaviate:1.19.0 diff --git a/docker/generate_docker_compose b/docker/generate_docker_compose new file mode 100755 index 00000000000000..54b6d55217f8ba --- /dev/null +++ b/docker/generate_docker_compose @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +import os +import re +import sys + + +def parse_env_example(file_path): + """ + Parses the .env.example file and returns a dictionary with variable names as keys and default values as values. + """ + env_vars = {} + with open(file_path, "r") as f: + for line_number, line in enumerate(f, 1): + line = line.strip() + # Ignore empty lines and comments + if not line or line.startswith("#"): + continue + # Use regex to parse KEY=VALUE + match = re.match(r"^([^=]+)=(.*)$", line) + if match: + key = match.group(1).strip() + value = match.group(2).strip() + # Remove possible quotes around the value + if (value.startswith('"') and value.endswith('"')) or ( + value.startswith("'") and value.endswith("'") + ): + value = value[1:-1] + env_vars[key] = value + else: + print(f"Warning: Unable to parse line {line_number}: {line}") + return env_vars + + +def generate_shared_env_block(env_vars, anchor_name="shared-api-worker-env"): + """ + Generates a shared environment variables block as a YAML string. + """ + lines = [f"x-shared-env: &{anchor_name}"] + for key, default in env_vars.items(): + # If default value is empty, use ${KEY:-} + if default == "": + lines.append(f" {key}: ${{{key}:-}}") + else: + # If default value contains special characters, wrap it in quotes + if re.search(r"[:\s]", default): + default = f'"{default}"' + lines.append(f" {key}: ${{{key}:-{default}}}") + return "\n".join(lines) + + +def insert_shared_env(template_path, output_path, shared_env_block, header_comments): + """ + Inserts the shared environment variables block and header comments into the template file, + removing any existing x-shared-env anchors, and generates the final docker-compose.yaml file. + """ + with open(template_path, "r") as f: + template_content = f.read() + + # Remove existing x-shared-env: &shared-api-worker-env lines + template_content = re.sub( + r"^x-shared-env: &shared-api-worker-env\s*\n?", + "", + template_content, + flags=re.MULTILINE, + ) + + # Prepare the final content with header comments and shared env block + final_content = f"{header_comments}\n{shared_env_block}\n\n{template_content}" + + with open(output_path, "w") as f: + f.write(final_content) + print(f"Generated {output_path}") + + +def main(): + env_example_path = ".env.example" + template_path = "docker-compose-template.yaml" + output_path = "docker-compose.yaml" + anchor_name = "shared-api-worker-env" # Can be modified as needed + + # Define header comments to be added at the top of docker-compose.yaml + header_comments = ( + "# ==================================================================\n" + "# WARNING: This file is auto-generated by generate_docker_compose\n" + "# Do not modify this file directly. Instead, update the .env.example\n" + "# or docker-compose-template.yaml and regenerate this file.\n" + "# ==================================================================\n" + ) + + # Check if required files exist + for path in [env_example_path, template_path]: + if not os.path.isfile(path): + print(f"Error: File {path} does not exist.") + sys.exit(1) + + # Parse .env.example file + env_vars = parse_env_example(env_example_path) + + if not env_vars: + print("Warning: No environment variables found in .env.example.") + + # Generate shared environment variables block + shared_env_block = generate_shared_env_block(env_vars, anchor_name) + + # Insert shared environment variables block and header comments into the template + insert_shared_env(template_path, output_path, shared_env_block, header_comments) + + +if __name__ == "__main__": + main() diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py index e6644883018769..ee1b5c57e1d1d0 100644 --- a/sdks/python-client/dify_client/client.py +++ b/sdks/python-client/dify_client/client.py @@ -160,7 +160,10 @@ def get_result(self, workflow_run_id): class KnowledgeBaseClient(DifyClient): def __init__( - self, api_key, base_url: str = "https://api.dify.ai/v1", dataset_id: str = None + self, + api_key, + base_url: str = "https://api.dify.ai/v1", + dataset_id: str | None = None, ): """ Construct a KnowledgeBaseClient object. @@ -187,7 +190,9 @@ def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): "GET", f"/datasets?page={page}&limit={page_size}", **kwargs ) - def create_document_by_text(self, name, text, extra_params: dict = None, **kwargs): + def create_document_by_text( + self, name, text, extra_params: dict | None = None, **kwargs + ): """ Create a document by text. @@ -225,7 +230,7 @@ def create_document_by_text(self, name, text, extra_params: dict = None, **kwarg return self._send_request("POST", url, json=data, **kwargs) def update_document_by_text( - self, document_id, name, text, extra_params: dict = None, **kwargs + self, document_id, name, text, extra_params: dict | None = None, **kwargs ): """ Update a document by text. @@ -262,7 +267,7 @@ def update_document_by_text( return self._send_request("POST", url, json=data, **kwargs) def create_document_by_file( - self, file_path, original_document_id=None, extra_params: dict = None + self, file_path, original_document_id=None, extra_params: dict | None = None ): """ Create a document by file. @@ -304,7 +309,7 @@ def create_document_by_file( ) def update_document_by_file( - self, document_id, file_path, extra_params: dict = None + self, document_id, file_path, extra_params: dict | None = None ): """ Update a document by file. @@ -372,7 +377,11 @@ def delete_document(self, document_id): return self._send_request("DELETE", url) def list_documents( - self, page: int = None, page_size: int = None, keyword: str = None, **kwargs + self, + page: int | None = None, + page_size: int | None = None, + keyword: str | None = None, + **kwargs, ): """ Get a list of documents in this dataset. @@ -402,7 +411,11 @@ def add_segments(self, document_id, segments, **kwargs): return self._send_request("POST", url, json=data, **kwargs) def query_segments( - self, document_id, keyword: str = None, status: str = None, **kwargs + self, + document_id, + keyword: str | None = None, + status: str | None = None, + **kwargs, ): """ Query segments in this document. diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx index 96ee874d53caff..1d963203098fa5 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx @@ -25,6 +25,7 @@ import { fetchAppDetail, fetchAppSSO } from '@/service/apps' import AppContext, { useAppContext } from '@/context/app-context' import Loading from '@/app/components/base/loading' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' +import type { App } from '@/types/app' export type IAppDetailLayoutProps = { children: React.ReactNode @@ -41,12 +42,14 @@ const AppDetailLayout: FC = (props) => { const pathname = usePathname() const media = useBreakpoints() const isMobile = media === MediaType.mobile - const { isCurrentWorkspaceEditor } = useAppContext() + const { isCurrentWorkspaceEditor, isLoadingCurrentWorkspace } = useAppContext() const { appDetail, setAppDetail, setAppSiderbarExpand } = useStore(useShallow(state => ({ appDetail: state.appDetail, setAppDetail: state.setAppDetail, setAppSiderbarExpand: state.setAppSiderbarExpand, }))) + const [isLoadingAppDetail, setIsLoadingAppDetail] = useState(false) + const [appDetailRes, setAppDetailRes] = useState(null) const [navigation, setNavigation] = useState = (props) => { useEffect(() => { setAppDetail() + setIsLoadingAppDetail(true) fetchAppDetail({ url: '/apps', id: appId }).then((res) => { - // redirection - const canIEditApp = isCurrentWorkspaceEditor - if (!canIEditApp && (pathname.endsWith('configuration') || pathname.endsWith('workflow') || pathname.endsWith('logs'))) { - router.replace(`/app/${appId}/overview`) - return - } - if ((res.mode === 'workflow' || res.mode === 'advanced-chat') && (pathname).endsWith('configuration')) { - router.replace(`/app/${appId}/workflow`) - } - else if ((res.mode !== 'workflow' && res.mode !== 'advanced-chat') && (pathname).endsWith('workflow')) { - router.replace(`/app/${appId}/configuration`) - } - else { - setAppDetail({ ...res, enable_sso: false }) - setNavigation(getNavigations(appId, isCurrentWorkspaceEditor, res.mode)) - if (systemFeatures.enable_web_sso_switch_component && canIEditApp) { - fetchAppSSO({ appId }).then((ssoRes) => { - setAppDetail({ ...res, enable_sso: ssoRes.enabled }) - }) - } - } + setAppDetailRes(res) }).catch((e: any) => { if (e.status === 404) router.replace('/apps') + }).finally(() => { + setIsLoadingAppDetail(false) }) - }, [appId, isCurrentWorkspaceEditor, systemFeatures, getNavigations, pathname, router, setAppDetail]) + }, [appId, router, setAppDetail]) + + useEffect(() => { + if (!appDetailRes || isLoadingCurrentWorkspace || isLoadingAppDetail) + return + const res = appDetailRes + // redirection + const canIEditApp = isCurrentWorkspaceEditor + if (!canIEditApp && (pathname.endsWith('configuration') || pathname.endsWith('workflow') || pathname.endsWith('logs'))) { + router.replace(`/app/${appId}/overview`) + return + } + if ((res.mode === 'workflow' || res.mode === 'advanced-chat') && (pathname).endsWith('configuration')) { + router.replace(`/app/${appId}/workflow`) + } + else if ((res.mode !== 'workflow' && res.mode !== 'advanced-chat') && (pathname).endsWith('workflow')) { + router.replace(`/app/${appId}/configuration`) + } + else { + setAppDetail({ ...res, enable_sso: false }) + setNavigation(getNavigations(appId, isCurrentWorkspaceEditor, res.mode)) + if (systemFeatures.enable_web_sso_switch_component && canIEditApp) { + fetchAppSSO({ appId }).then((ssoRes) => { + setAppDetail({ ...res, enable_sso: ssoRes.enabled }) + }) + } + } + }, [appDetailRes, appId, getNavigations, isCurrentWorkspaceEditor, isLoadingAppDetail, isLoadingCurrentWorkspace, pathname, router, setAppDetail, systemFeatures.enable_web_sso_switch_component]) useUnmount(() => { setAppDetail() @@ -141,7 +154,7 @@ const AppDetailLayout: FC = (props) => { if (!appDetail) { return ( -
+
) @@ -152,7 +165,7 @@ const AppDetailLayout: FC = (props) => { {appDetail && ( )} -
+
{children}
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx index b5d3462dfacf47..bb1e4fd95bd10e 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx @@ -28,7 +28,7 @@ export default function ChartView({ appId }: IChartViewProps) { const [period, setPeriod] = useState({ name: t('appLog.filter.period.last7days'), query: { start: today.subtract(7, 'day').startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } }) const onSelect = (item: Item) => { - if (item.value === '-1') { + if (item.value === -1) { setPeriod({ name: item.name, query: undefined }) } else if (item.value === 0) { diff --git a/web/app/(commonLayout)/apps/AppCard.tsx b/web/app/(commonLayout)/apps/AppCard.tsx index 1ffb132cf8c186..dabe75ee625a7e 100644 --- a/web/app/(commonLayout)/apps/AppCard.tsx +++ b/web/app/(commonLayout)/apps/AppCard.tsx @@ -9,7 +9,7 @@ import s from './style.module.css' import cn from '@/utils/classnames' import type { App } from '@/types/app' import Confirm from '@/app/components/base/confirm' -import { ToastContext } from '@/app/components/base/toast' +import Toast, { ToastContext } from '@/app/components/base/toast' import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps' import DuplicateAppModal from '@/app/components/app/duplicate-modal' import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal' @@ -21,8 +21,6 @@ import Divider from '@/app/components/base/divider' import { getRedirection } from '@/utils/app-redirection' import { useProviderContext } from '@/context/provider-context' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' -import { AiText, ChatBot, CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication' -import { Route } from '@/app/components/base/icons/src/vender/solid/mapsAndTravel' import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import EditAppModal from '@/app/components/explore/create-app-modal' import SwitchAppModal from '@/app/components/app/switch-app-modal' @@ -31,6 +29,8 @@ import TagSelector from '@/app/components/base/tag-management/selector' import type { EnvironmentVariable } from '@/app/components/workflow/types' import DSLExportConfirmModal from '@/app/components/workflow/dsl-export-confirm-modal' import { fetchWorkflowDraft } from '@/service/workflow' +import { fetchInstalledAppList } from '@/service/explore' +import { AppTypeIcon } from '@/app/components/app/type-selector' export type AppCardProps = { app: App @@ -209,6 +209,21 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { e.preventDefault() setShowConfirmDelete(true) } + const onClickInstalledApp = async (e: React.MouseEvent) => { + e.stopPropagation() + props.onClick?.() + e.preventDefault() + try { + const { installed_apps }: any = await fetchInstalledAppList(app.id) || {} + if (installed_apps?.length > 0) + window.open(`/explore/installed/${installed_apps[0].id}`, '_blank') + else + throw new Error('No app found in Explore') + } + catch (e: any) { + Toast.notify({ type: 'error', message: `${e.message || e}` }) + } + } return (
+
{ e.preventDefault() getRedirection(isCurrentWorkspaceEditor, app, push) }} - className='relative group col-span-1 bg-white border-2 border-solid border-transparent rounded-xl shadow-sm flex flex-col transition-all duration-200 ease-in-out cursor-pointer hover:shadow-lg' + className='relative h-[160px] group col-span-1 bg-components-card-bg border-[1px] border-solid border-components-card-border rounded-xl shadow-sm inline-flex flex-col transition-all duration-200 ease-in-out cursor-pointer hover:shadow-lg' >
@@ -268,30 +287,14 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { background={app.icon_background} imageUrl={app.icon_url} /> - - {app.mode === 'advanced-chat' && ( - - )} - {app.mode === 'agent-chat' && ( - - )} - {app.mode === 'chat' && ( - - )} - {app.mode === 'completion' && ( - - )} - {app.mode === 'workflow' && ( - - )} - +
-
+
{app.name}
-
- {app.mode === 'advanced-chat' &&
{t('app.types.chatbot').toUpperCase()}
} +
+ {app.mode === 'advanced-chat' &&
{t('app.types.advanced').toUpperCase()}
} {app.mode === 'chat' &&
{t('app.types.chatbot').toUpperCase()}
} {app.mode === 'agent-chat' &&
{t('app.types.agent').toUpperCase()}
} {app.mode === 'workflow' &&
{t('app.types.workflow').toUpperCase()}
} @@ -299,7 +302,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
-
+
{ />
-
+
} @@ -342,7 +345,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
- +
} btnClassName={open => @@ -353,10 +356,10 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { } popupClassName={ (app.mode === 'completion' || app.mode === 'chat') - ? '!w-[238px] translate-x-[-110px]' - : '' + ? '!w-[256px] translate-x-[-224px]' + : '!w-[160px] translate-x-[-128px]' } - className={'!w-[128px] h-fit !z-20'} + className={'h-fit !z-20'} />
diff --git a/web/app/(commonLayout)/apps/Apps.tsx b/web/app/(commonLayout)/apps/Apps.tsx index 9d6345aa6c3de1..5269571c210aa2 100644 --- a/web/app/(commonLayout)/apps/Apps.tsx +++ b/web/app/(commonLayout)/apps/Apps.tsx @@ -125,7 +125,7 @@ const Apps = () => { return ( <> -
+
{ />
- + {(data && data[0].total > 0) + ?
+ {isCurrentWorkspaceEditor + && } + {data.map(({ data: apps }) => apps.map(app => ( + + )))} +
+ :
+ {isCurrentWorkspaceEditor + && } + +
} +
{showTagManagementModal && ( @@ -160,3 +166,21 @@ const Apps = () => { } export default Apps + +function NoAppsFound() { + const { t } = useTranslation() + function renderDefaultCard() { + const defaultCards = Array.from({ length: 36 }, (_, index) => ( +
+ )) + return defaultCards + } + return ( + <> + {renderDefaultCard()} +
+ {t('app.newApp.noAppsFound')} +
+ + ) +} diff --git a/web/app/(commonLayout)/apps/NewAppCard.tsx b/web/app/(commonLayout)/apps/NewAppCard.tsx index c0dffa99abe411..a90af4ea85caf4 100644 --- a/web/app/(commonLayout)/apps/NewAppCard.tsx +++ b/web/app/(commonLayout)/apps/NewAppCard.tsx @@ -11,13 +11,14 @@ import CreateAppModal from '@/app/components/app/create-app-modal' import CreateFromDSLModal, { CreateFromDSLModalTab } from '@/app/components/app/create-from-dsl-modal' import { useProviderContext } from '@/context/provider-context' import { FileArrow01, FilePlus01, FilePlus02 } from '@/app/components/base/icons/src/vender/line/files' +import cn from '@/utils/classnames' export type CreateAppCardProps = { + className?: string onSuccess?: () => void } -// eslint-disable-next-line react/display-name -const CreateAppCard = forwardRef(({ onSuccess }, ref) => { +const CreateAppCard = forwardRef(({ className, onSuccess }, ref) => { const { t } = useTranslation() const { onPlanInfoChanged } = useProviderContext() const searchParams = useSearchParams() @@ -36,30 +37,28 @@ const CreateAppCard = forwardRef(({ onSuc }, [dslUrl]) return ( -
-
{t('app.createApp')}
-
setShowNewAppModal(true)}> +
{t('app.createApp')}
+
-
setShowNewAppTemplateDialog(true)}> + +
-
-
setShowCreateFromDSLModal(true)} - > -
+ +
+
+ setShowNewAppModal(false)} @@ -68,6 +67,10 @@ const CreateAppCard = forwardRef(({ onSuc if (onSuccess) onSuccess() }} + onCreateFromTemplate={() => { + setShowNewAppTemplateDialog(true) + setShowNewAppModal(false) + }} /> (({ onSuc if (onSuccess) onSuccess() }} + onCreateFromBlank={() => { + setShowNewAppModal(true) + setShowNewAppTemplateDialog(false) + }} /> (({ onSuc onSuccess() }} /> -
+
) }) +CreateAppCard.displayName = 'CreateAppCard' export default CreateAppCard +export { CreateAppCard } diff --git a/web/app/(commonLayout)/apps/page.tsx b/web/app/(commonLayout)/apps/page.tsx index ab9852e46275af..972aabc8bc5989 100644 --- a/web/app/(commonLayout)/apps/page.tsx +++ b/web/app/(commonLayout)/apps/page.tsx @@ -1,9 +1,10 @@ 'use client' import { useContextSelector } from 'use-context-selector' import { useTranslation } from 'react-i18next' +import { RiDiscordFill, RiGithubFill } from '@remixicon/react' +import Link from 'next/link' import style from '../list.module.css' import Apps from './Apps' -import classNames from '@/utils/classnames' import AppContext from '@/context/app-context' import { LicenseStatus } from '@/types/feature' @@ -12,14 +13,18 @@ const AppList = () => { const systemFeatures = useContextSelector(AppContext, v => v.systemFeatures) return ( -
+
{systemFeatures.license.status === LicenseStatus.NONE &&

{t('app.join')}

-

{t('app.communityIntro')}

+

{t('app.communityIntro')}

- - + + + + + +
}
diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx index a58027bcd12e02..b416659a6a1cfa 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx @@ -166,7 +166,7 @@ const ExtraInfo = ({ isMobile, relatedApps }: IExtraInfoProps) => { className='inline-flex items-center text-xs text-primary-600 mt-2 cursor-pointer' href={ locale === LanguagesSupported[1] - ? 'https://docs.dify.ai/v/zh-hans/guides/knowledge-base/integrate_knowledge_within_application' + ? 'https://docs.dify.ai/v/zh-hans/guides/knowledge-base/integrate-knowledge-within-application' : 'https://docs.dify.ai/guides/knowledge-base/integrate-knowledge-within-application' } target='_blank' rel='noopener noreferrer' diff --git a/web/app/(commonLayout)/list.module.css b/web/app/(commonLayout)/list.module.css index bb2aa8606c38d8..2fc6469a6dd982 100644 --- a/web/app/(commonLayout)/list.module.css +++ b/web/app/(commonLayout)/list.module.css @@ -201,14 +201,6 @@ @apply block w-6 h-6 bg-center bg-contain; } -.githubIcon { - background-image: url("./apps/assets/github.svg"); -} - -.discordIcon { - background-image: url("./apps/assets/discord.svg"); -} - /* #region new app dialog */ .newItemCaption { @apply inline-flex items-center mb-2 text-sm font-medium; diff --git a/web/app/account/account-page/index.tsx b/web/app/account/account-page/index.tsx index 71540ce3b1265a..c7af05793f296b 100644 --- a/web/app/account/account-page/index.tsx +++ b/web/app/account/account-page/index.tsx @@ -18,10 +18,10 @@ import { IS_CE_EDITION } from '@/config' import Input from '@/app/components/base/input' const titleClassName = ` - text-sm font-medium text-gray-900 + system-sm-semibold text-text-secondary ` const descriptionClassName = ` - mt-1 text-xs font-normal text-gray-500 + mt-1 body-xs-regular text-text-tertiary ` const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/ @@ -122,7 +122,7 @@ export default function AccountPage() {
-
{item.name}
+
{item.name}
) } @@ -130,7 +130,7 @@ export default function AccountPage() { return ( <>
-

{t('common.account.myAccount')}

+

{t('common.account.myAccount')}

@@ -142,10 +142,10 @@ export default function AccountPage() {
{t('common.account.name')}
-
+
{userProfile.name}
-
+
{t('common.operation.edit')}
@@ -153,7 +153,7 @@ export default function AccountPage() {
{t('common.account.email')}
-
+
{userProfile.email}
@@ -162,14 +162,14 @@ export default function AccountPage() { systemFeatures.enable_email_password_login && (
-
{t('common.account.password')}
-
{t('common.account.passwordTip')}
+
{t('common.account.password')}
+
{t('common.account.passwordTip')}
) } -
+
{t('common.account.langGeniusAccount')}
{t('common.account.langGeniusAccountTip')}
@@ -181,7 +181,7 @@ export default function AccountPage() { wrapperClassName='mt-2' /> )} - {!IS_CE_EDITION && } + {!IS_CE_EDITION && }
{ editNameModalVisible && ( @@ -190,7 +190,7 @@ export default function AccountPage() { onClose={() => setEditNameModalVisible(false)} className={s.modal} > -
{t('common.account.editName')}
+
{t('common.account.editName')}
{t('common.account.name')}
-
{userProfile.is_password_set ? t('common.account.resetPassword') : t('common.account.setPassword')}
+
{userProfile.is_password_set ? t('common.account.resetPassword') : t('common.account.setPassword')}
{userProfile.is_password_set && ( <>
{t('common.account.currentPassword')}
@@ -242,7 +242,7 @@ export default function AccountPage() {
)} -
+
{userProfile.is_password_set ? t('common.account.newPassword') : t('common.account.password')}
@@ -261,7 +261,7 @@ export default function AccountPage() {
-
{t('common.account.confirmPassword')}
+
{t('common.account.confirmPassword')}
-
+
{t('common.account.deleteTip')}
{t('common.account.deleteConfirmTip')}
-
{`${t('common.account.delete')}: ${userProfile.email}`}
+
{`${t('common.account.delete')}: ${userProfile.email}`}
} confirmText={t('common.operation.ok') as string} diff --git a/web/app/account/avatar.tsx b/web/app/account/avatar.tsx index 544e43ab27f99f..8fdecc07bf867b 100644 --- a/web/app/account/avatar.tsx +++ b/web/app/account/avatar.tsx @@ -40,9 +40,9 @@ export default function AppSelector() { className={` inline-flex items-center rounded-[20px] p-1x text-sm - text-gray-700 hover:bg-gray-200 + text-text-primary mobile:px-1 - ${open && 'bg-gray-200'} + ${open && 'bg-components-panel-bg-blur'} `} > @@ -60,7 +60,7 @@ export default function AppSelector() { @@ -78,10 +78,10 @@ export default function AppSelector() {
handleLogout()}>
- -
{t('common.userProfile.logout')}
+ +
{t('common.userProfile.logout')}
diff --git a/web/app/account/layout.tsx b/web/app/account/layout.tsx index 5aa8b05cbfd07b..11a6abeab40782 100644 --- a/web/app/account/layout.tsx +++ b/web/app/account/layout.tsx @@ -21,7 +21,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
-
+
{children}
diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index 12fe5cba468df3..12f9c59cd16251 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -237,7 +237,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => { {appDetail.mode === 'advanced-chat' && ( <>
{t('app.types.chatbot').toUpperCase()}
-
{t('app.newApp.advanced').toUpperCase()}
+
{t('app.types.advanced').toUpperCase()}
)} {appDetail.mode === 'agent-chat' && ( @@ -246,13 +246,13 @@ const AppInfo = ({ expand }: IAppInfoProps) => { {appDetail.mode === 'chat' && ( <>
{t('app.types.chatbot').toUpperCase()}
-
{(t('app.newApp.basic').toUpperCase())}
+
{(t('app.types.basic').toUpperCase())}
)} {appDetail.mode === 'completion' && ( <>
{t('app.types.completion').toUpperCase()}
-
{(t('app.newApp.basic').toUpperCase())}
+
{(t('app.types.basic').toUpperCase())}
)} {appDetail.mode === 'workflow' && ( @@ -299,7 +299,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => { {appDetail.mode === 'advanced-chat' && ( <>
{t('app.types.chatbot').toUpperCase()}
-
{t('app.newApp.advanced').toUpperCase()}
+
{t('app.types.advanced').toUpperCase()}
)} {appDetail.mode === 'agent-chat' && ( @@ -308,13 +308,13 @@ const AppInfo = ({ expand }: IAppInfoProps) => { {appDetail.mode === 'chat' && ( <>
{t('app.types.chatbot').toUpperCase()}
-
{(t('app.newApp.basic').toUpperCase())}
+
{(t('app.types.basic').toUpperCase())}
)} {appDetail.mode === 'completion' && ( <>
{t('app.types.completion').toUpperCase()}
-
{(t('app.newApp.basic').toUpperCase())}
+
{(t('app.types.basic').toUpperCase())}
)} {appDetail.mode === 'workflow' && ( @@ -398,7 +398,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => { )} />
- {showSwitchTip === 'chat' ? t('app.newApp.advanced') : t('app.types.workflow')} + {showSwitchTip === 'chat' ? t('app.types.advanced') : t('app.types.workflow')} BETA
{t('app.newApp.advancedFor').toLocaleUpperCase()}
diff --git a/web/app/components/app/annotation/add-annotation-modal/edit-item/index.tsx b/web/app/components/app/annotation/add-annotation-modal/edit-item/index.tsx index 4da6b7cac4d0b4..032e4b83576adf 100644 --- a/web/app/components/app/annotation/add-annotation-modal/edit-item/index.tsx +++ b/web/app/components/app/annotation/add-annotation-modal/edit-item/index.tsx @@ -2,7 +2,7 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import Textarea from 'rc-textarea' +import Textarea from '@/app/components/base/textarea' import { Robot, User } from '@/app/components/base/icons/src/public/avatar' export enum EditItemType { @@ -31,12 +31,10 @@ const EditItem: FC = ({ {avatar}
-
{name}
+
{name}