From ea584e94bd5fc30aed78a506261c7ae987e2e9f1 Mon Sep 17 00:00:00 2001 From: Likename Haojie Date: Fri, 11 Oct 2024 22:46:44 +0800 Subject: [PATCH 01/25] fix: dialog box cannot correctly display LaTeX formulas (#9242) --- web/app/components/base/markdown.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/app/components/base/markdown.tsx b/web/app/components/base/markdown.tsx index 39a399cc9f73c9..dbe408788253b7 100644 --- a/web/app/components/base/markdown.tsx +++ b/web/app/components/base/markdown.tsx @@ -245,7 +245,7 @@ export function Markdown(props: { content: string; className?: string }) { return (
Date: Fri, 11 Oct 2024 22:48:57 +0800 Subject: [PATCH 02/25] feat: add supabase object storage (#9229) --- api/.env.example | 7 +- api/configs/middleware/__init__.py | 2 + .../storage/supabase_storage_config.py | 24 ++ api/extensions/ext_storage.py | 4 + api/extensions/storage/storage_type.py | 1 + api/extensions/storage/supabase_storage.py | 60 ++++ api/poetry.lock | 324 ++++++++++++------ api/pyproject.toml | 1 + 8 files changed, 318 insertions(+), 105 deletions(-) create mode 100644 api/configs/middleware/storage/supabase_storage_config.py create mode 100644 api/extensions/storage/supabase_storage.py diff --git a/api/.env.example b/api/.env.example index 71f0e5db8f8b9b..7b5e4950c82ddd 100644 --- a/api/.env.example +++ b/api/.env.example @@ -39,7 +39,7 @@ DB_DATABASE=dify # Storage configuration # use for store upload files, private keys... -# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos, baidu-obs +# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos, baidu-obs, supabase STORAGE_TYPE=local STORAGE_LOCAL_PATH=storage S3_USE_AWS_MANAGED_IAM=false @@ -99,6 +99,11 @@ VOLCENGINE_TOS_ACCESS_KEY=your-access-key VOLCENGINE_TOS_SECRET_KEY=your-secret-key VOLCENGINE_TOS_REGION=your-region +# Supabase Storage Configuration +SUPABASE_BUCKET_NAME=your-bucket-name +SUPABASE_API_KEY=your-access-key +SUPABASE_URL=your-server-url + # CORS configuration WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 5fec991d6e2d97..25f3df6dde41d7 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -12,6 +12,7 @@ from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig from 
configs.middleware.storage.oci_storage_config import OCIStorageConfig +from configs.middleware.storage.supabase_storage_config import SupabaseStorageConfig from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig @@ -222,6 +223,7 @@ class MiddlewareConfig( HuaweiCloudOBSStorageConfig, OCIStorageConfig, S3StorageConfig, + SupabaseStorageConfig, TencentCloudCOSStorageConfig, VolcengineTOSStorageConfig, # configs of vdb and vdb providers diff --git a/api/configs/middleware/storage/supabase_storage_config.py b/api/configs/middleware/storage/supabase_storage_config.py new file mode 100644 index 00000000000000..a3e905b21c63e9 --- /dev/null +++ b/api/configs/middleware/storage/supabase_storage_config.py @@ -0,0 +1,24 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class SupabaseStorageConfig(BaseModel): + """ + Configuration settings for Supabase Object Storage Service + """ + + SUPABASE_BUCKET_NAME: Optional[str] = Field( + description="Name of the Supabase bucket to store and retrieve objects (e.g., 'dify-bucket')", + default=None, + ) + + SUPABASE_API_KEY: Optional[str] = Field( + description="API KEY for authenticating with Supabase", + default=None, + ) + + SUPABASE_URL: Optional[str] = Field( + description="URL of the Supabase", + default=None, + ) diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index be57b633bed4a3..f90629262d89da 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -56,6 +56,10 @@ def get_storage_factory(storage_type: str) -> type[BaseStorage]: from extensions.storage.volcengine_tos_storage import VolcengineTosStorage return VolcengineTosStorage + case StorageType.SUPBASE: + from extensions.storage.supabase_storage import SupabaseStorage + + return 
SupabaseStorage case StorageType.LOCAL | _: from extensions.storage.local_fs_storage import LocalFsStorage diff --git a/api/extensions/storage/storage_type.py b/api/extensions/storage/storage_type.py index e494a520a20b01..415bf251f6e280 100644 --- a/api/extensions/storage/storage_type.py +++ b/api/extensions/storage/storage_type.py @@ -12,3 +12,4 @@ class StorageType(str, Enum): S3 = "s3" TENCENT_COS = "tencent-cos" VOLCENGINE_TOS = "volcengine-tos" + SUPBASE = "supabase" diff --git a/api/extensions/storage/supabase_storage.py b/api/extensions/storage/supabase_storage.py new file mode 100644 index 00000000000000..1e399f87c8ba37 --- /dev/null +++ b/api/extensions/storage/supabase_storage.py @@ -0,0 +1,60 @@ +import io +from collections.abc import Generator +from pathlib import Path + +from flask import Flask +from supabase import Client + +from extensions.storage.base_storage import BaseStorage + + +class SupabaseStorage(BaseStorage): + """Implementation for supabase obs storage.""" + + def __init__(self, app: Flask): + super().__init__(app) + app_config = self.app.config + self.bucket_name = app_config.get("SUPABASE_BUCKET_NAME") + self.client = Client( + supabase_url=app_config.get("SUPABASE_URL"), supabase_key=app_config.get("SUPABASE_API_KEY") + ) + self.create_bucket( + id=app_config.get("SUPABASE_BUCKET_NAME"), bucket_name=app_config.get("SUPABASE_BUCKET_NAME") + ) + + def create_bucket(self, id, bucket_name): + if not self.bucket_exists(): + self.client.storage.create_bucket(id=id, name=bucket_name) + + def save(self, filename, data): + self.client.storage.from_(self.bucket_name).upload(filename, data) + + def load_once(self, filename: str) -> bytes: + content = self.client.storage.from_(self.bucket_name).download(filename) + return content + + def load_stream(self, filename: str) -> Generator: + def generate(filename: str = filename) -> Generator: + result = self.client.storage.from_(self.bucket_name).download(filename) + byte_stream = io.BytesIO(result) + 
while chunk := byte_stream.read(4096): # Read in chunks of 4KB + yield chunk + + return generate() + + def download(self, filename, target_filepath): + result = self.client.storage.from_(self.bucket_name).download(filename) + Path(result).write_bytes(result) + + def exists(self, filename): + result = self.client.storage.from_(self.bucket_name).list(filename) + if result.count() > 0: + return True + return False + + def delete(self, filename): + self.client.storage.from_(self.bucket_name).remove(filename) + + def bucket_exists(self): + buckets = self.client.storage.list_buckets() + return any(bucket.name == self.bucket_name for bucket in buckets) diff --git a/api/poetry.lock b/api/poetry.lock index b7421ca566929a..fd6b98191df02c 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -844,13 +844,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.35.37" +version = "1.35.38" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.35.37-py3-none-any.whl", hash = "sha256:64f965d4ba7adb8d79ce044c3aef7356e05dd74753cf7e9115b80f477845d920"}, - {file = "botocore-1.35.37.tar.gz", hash = "sha256:b2b4d29bafd95b698344f2f0577bb67064adbf1735d8a0e3c7473daa59c23ba6"}, + {file = "botocore-1.35.38-py3-none-any.whl", hash = "sha256:2eb17d32fa2d3bb5d475132a83564d28e3acc2161534f24b75a54418a1d51359"}, + {file = "botocore-1.35.38.tar.gz", hash = "sha256:55d9305c44e5ba29476df456120fa4fb919f03f066afa82f2ae400485e7465f4"}, ] [package.dependencies] @@ -2066,6 +2066,20 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "deprecation" +version = "2.1.0" +description = "A library to handle automated deprecations" +optional = false +python-versions = "*" +files = [ + {file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"}, + 
{file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"}, +] + +[package.dependencies] +packaging = "*" + [[package]] name = "dill" version = "0.3.9" @@ -3366,6 +3380,21 @@ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4 [package.extras] grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] +[[package]] +name = "gotrue" +version = "2.9.2" +description = "Python Client Library for Supabase Auth" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "gotrue-2.9.2-py3-none-any.whl", hash = "sha256:fcd5279e8f1cc630f3ac35af5485fe39f8030b23906776920d2c32a4e308cff4"}, + {file = "gotrue-2.9.2.tar.gz", hash = "sha256:57b3245e916c5efbf19a21b1181011a903c1276bb1df2d847558f2f24f29abb2"}, +] + +[package.dependencies] +httpx = {version = ">=0.26,<0.28", extras = ["http2"]} +pydantic = ">=1.10,<3" + [[package]] name = "greenlet" version = "3.1.1" @@ -4408,13 +4437,13 @@ openai = ["openai (>=0.27.8)"] [[package]] name = "langsmith" -version = "0.1.133" +version = "0.1.134" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." 
optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.133-py3-none-any.whl", hash = "sha256:82e837a6039c483beadbe19c2ba7ebafbd402d3e8105234f5ef334425cff7b45"}, - {file = "langsmith-0.1.133.tar.gz", hash = "sha256:7bfd8bef166b9a64ee540a11bee4aa7bf43b1d9229f95b0fc19086454955185d"}, + {file = "langsmith-0.1.134-py3-none-any.whl", hash = "sha256:ada98ad80ef38807725f32441a472da3dd28394010877751f48f458d3289da04"}, + {file = "langsmith-0.1.134.tar.gz", hash = "sha256:23abee3b508875a0e63c602afafffc02442a19cfd88f9daae05b3e9054fd6b61"}, ] [package.dependencies] @@ -6395,6 +6424,23 @@ docs = ["sphinx (>=1.7.1)"] redis = ["redis"] tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "pytest-timeout (>=2.1.0)", "redis", "sphinx (>=6.0.0)", "types-redis"] +[[package]] +name = "postgrest" +version = "0.17.1" +description = "PostgREST client for Python. This library provides an ORM interface to PostgREST." +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "postgrest-0.17.1-py3-none-any.whl", hash = "sha256:ec1d00dc8532fe5ffb342cfc7c4e610a1e0e2272eb14f78f9b2b61094f9be510"}, + {file = "postgrest-0.17.1.tar.gz", hash = "sha256:e31d9977dbb80dc5f9fdd4d444014686606692dc4ddb9adc85639e56c6d54c92"}, +] + +[package.dependencies] +deprecation = ">=2.1.0,<3.0.0" +httpx = {version = ">=0.26,<0.28", extras = ["http2"]} +pydantic = ">=1.9,<3.0" +strenum = ">=0.4.9,<0.5.0" + [[package]] name = "posthog" version = "3.7.0" @@ -7695,6 +7741,23 @@ dev = ["coveralls", "m2r", "pycodestyle", "pyflakes", "pylint", "pytest", "pytes docs = ["m2r", "sphinx"] test = ["coveralls", "pycodestyle", "pyflakes", "pylint", "pytest", "pytest-benchmark", "pytest-cov"] +[[package]] +name = "realtime" +version = "2.0.2" +description = "" +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "realtime-2.0.2-py3-none-any.whl", hash = "sha256:2634c915bc38807f2013f21e8bcc4d2f79870dfd81460ddb9393883d0489928a"}, + {file = 
"realtime-2.0.2.tar.gz", hash = "sha256:519da9325b3b8102139d51785013d592f6b2403d81fa21d838a0b0234723ed7d"}, +] + +[package.dependencies] +aiohttp = ">=3.10.2,<4.0.0" +python-dateutil = ">=2.8.1,<3.0.0" +typing-extensions = ">=4.12.2,<5.0.0" +websockets = ">=11,<13" + [[package]] name = "redis" version = "5.0.8" @@ -8578,19 +8641,20 @@ files = [ [[package]] name = "simple-websocket" -version = "1.0.0" +version = "1.1.0" description = "Simple WebSocket server and client for Python" optional = false python-versions = ">=3.6" files = [ - {file = "simple-websocket-1.0.0.tar.gz", hash = "sha256:17d2c72f4a2bd85174a97e3e4c88b01c40c3f81b7b648b0cc3ce1305968928c8"}, - {file = "simple_websocket-1.0.0-py3-none-any.whl", hash = "sha256:1d5bf585e415eaa2083e2bcf02a3ecf91f9712e7b3e6b9fa0b461ad04e0837bc"}, + {file = "simple_websocket-1.1.0-py3-none-any.whl", hash = "sha256:4af6069630a38ed6c561010f0e11a5bc0d4ca569b36306eb257cd9a192497c8c"}, + {file = "simple_websocket-1.1.0.tar.gz", hash = "sha256:7939234e7aa067c534abdab3a9ed933ec9ce4691b0713c78acb195560aa52ae4"}, ] [package.dependencies] wsproto = "*" [package.extras] +dev = ["flake8", "pytest", "pytest-cov", "tox"] docs = ["sphinx"] [[package]] @@ -8767,6 +8831,38 @@ anyio = ">=3.4.0,<5" [package.extras] full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"] +[[package]] +name = "storage3" +version = "0.8.1" +description = "Supabase Storage client for Python." 
+optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "storage3-0.8.1-py3-none-any.whl", hash = "sha256:0b21205f43eaf0d1dd33bde6c6d0612f88524b7865f017d2ae9827e3f63d9cdc"}, + {file = "storage3-0.8.1.tar.gz", hash = "sha256:ea60b68b2221b3868ccc1a7f1294d57d0d9c51642cdc639d8115fe5d0adc8892"}, +] + +[package.dependencies] +httpx = {version = ">=0.26,<0.28", extras = ["http2"]} +python-dateutil = ">=2.8.2,<3.0.0" +typing-extensions = ">=4.2.0,<5.0.0" + +[[package]] +name = "strenum" +version = "0.4.15" +description = "An Enum that inherits from str." +optional = false +python-versions = "*" +files = [ + {file = "StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659"}, + {file = "StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff"}, +] + +[package.extras] +docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"] +release = ["twine"] +test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"] + [[package]] name = "strictyaml" version = "1.7.3" @@ -8781,6 +8877,40 @@ files = [ [package.dependencies] python-dateutil = ">=2.6.0" +[[package]] +name = "supabase" +version = "2.8.1" +description = "Supabase client for Python." 
+optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "supabase-2.8.1-py3-none-any.whl", hash = "sha256:dfa8bef89b54129093521d5bba2136ff765baf67cd76d8ad0aa4984d61a7815c"}, + {file = "supabase-2.8.1.tar.gz", hash = "sha256:711c70e6acd9e2ff48ca0dc0b1bb70c01c25378cc5189ec9f5ed9655b30bc41d"}, +] + +[package.dependencies] +gotrue = ">=2.7.0,<3.0.0" +httpx = ">=0.24,<0.28" +postgrest = ">=0.17.0,<0.18.0" +realtime = ">=2.0.0,<3.0.0" +storage3 = ">=0.8.0,<0.9.0" +supafunc = ">=0.6.0,<0.7.0" +typing-extensions = ">=4.12.2,<5.0.0" + +[[package]] +name = "supafunc" +version = "0.6.1" +description = "Library for Supabase Functions" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "supafunc-0.6.1-py3-none-any.whl", hash = "sha256:01aeeeb4bf429977664454a32c86418345140faf6d2e6eb0636d52e4547c5fbb"}, + {file = "supafunc-0.6.1.tar.gz", hash = "sha256:3c8761e3999336ccdb7550498a395fd08afc8469382f55ea56f7f640e5a909aa"}, +] + +[package.dependencies] +httpx = {version = ">=0.26,<0.28", extras = ["http2"]} + [[package]] name = "sympy" version = "1.13.3" @@ -8855,13 +8985,13 @@ test = ["pytest", "tornado (>=4.5)", "typeguard"] [[package]] name = "tencentcloud-sdk-python-common" -version = "3.0.1246" +version = "3.0.1247" description = "Tencent Cloud Common SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-common-3.0.1246.tar.gz", hash = "sha256:a724d53d5dd6c68beff9ae1c2eb7737889d67f7724819ec5189e878e5322940b"}, - {file = "tencentcloud_sdk_python_common-3.0.1246-py2.py3-none-any.whl", hash = "sha256:e82a103cc4f1a8d07a83600604eba89e46899b005328b921dbcba9f73b3e819b"}, + {file = "tencentcloud-sdk-python-common-3.0.1247.tar.gz", hash = "sha256:1467ac3eaaa5b5d299570ba781903debc4be32dbb3f0f39929a357531ab89170"}, + {file = "tencentcloud_sdk_python_common-3.0.1247-py2.py3-none-any.whl", hash = "sha256:9829d2299c85a2494d6d816247345e98abd2f936cd309e1f67847243f5235091"}, ] [package.dependencies] @@ -8869,17 
+8999,17 @@ requests = ">=2.16.0" [[package]] name = "tencentcloud-sdk-python-hunyuan" -version = "3.0.1246" +version = "3.0.1247" description = "Tencent Cloud Hunyuan SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-hunyuan-3.0.1246.tar.gz", hash = "sha256:099db4a80e7788e1ca09b8be45758e47a6c5976e304857f26621f958cca5b39b"}, - {file = "tencentcloud_sdk_python_hunyuan-3.0.1246-py2.py3-none-any.whl", hash = "sha256:14ad0a6787bb579a9270e6fe6ee81131a286c2f1b56cac8d7de0056983b6548f"}, + {file = "tencentcloud-sdk-python-hunyuan-3.0.1247.tar.gz", hash = "sha256:85b7332ec55f891a3b4d776e6b30ee2a44cc08c70b689615805aadff6e424fdd"}, + {file = "tencentcloud_sdk_python_hunyuan-3.0.1247-py2.py3-none-any.whl", hash = "sha256:69fdb886616e53ce02e848e5a1a8b36922db731457b07365f230ffb0aa472b5b"}, ] [package.dependencies] -tencentcloud-sdk-python-common = "3.0.1246" +tencentcloud-sdk-python-common = "3.0.1247" [[package]] name = "threadpoolctl" @@ -9855,97 +9985,83 @@ test = ["websockets"] [[package]] name = "websockets" -version = "13.1" +version = "12.0" description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" optional = false python-versions = ">=3.8" files = [ - {file = "websockets-13.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f48c749857f8fb598fb890a75f540e3221d0976ed0bf879cf3c7eef34151acee"}, - {file = "websockets-13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c7e72ce6bda6fb9409cc1e8164dd41d7c91466fb599eb047cfda72fe758a34a7"}, - {file = "websockets-13.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f779498eeec470295a2b1a5d97aa1bc9814ecd25e1eb637bd9d1c73a327387f6"}, - {file = "websockets-13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4676df3fe46956fbb0437d8800cd5f2b6d41143b6e7e842e60554398432cf29b"}, - {file = "websockets-13.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:a7affedeb43a70351bb811dadf49493c9cfd1ed94c9c70095fd177e9cc1541fa"}, - {file = "websockets-13.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1971e62d2caa443e57588e1d82d15f663b29ff9dfe7446d9964a4b6f12c1e700"}, - {file = "websockets-13.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5f2e75431f8dc4a47f31565a6e1355fb4f2ecaa99d6b89737527ea917066e26c"}, - {file = "websockets-13.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:58cf7e75dbf7e566088b07e36ea2e3e2bd5676e22216e4cad108d4df4a7402a0"}, - {file = "websockets-13.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c90d6dec6be2c7d03378a574de87af9b1efea77d0c52a8301dd831ece938452f"}, - {file = "websockets-13.1-cp310-cp310-win32.whl", hash = "sha256:730f42125ccb14602f455155084f978bd9e8e57e89b569b4d7f0f0c17a448ffe"}, - {file = "websockets-13.1-cp310-cp310-win_amd64.whl", hash = "sha256:5993260f483d05a9737073be197371940c01b257cc45ae3f1d5d7adb371b266a"}, - {file = "websockets-13.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:61fc0dfcda609cda0fc9fe7977694c0c59cf9d749fbb17f4e9483929e3c48a19"}, - {file = "websockets-13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ceec59f59d092c5007e815def4ebb80c2de330e9588e101cf8bd94c143ec78a5"}, - {file = "websockets-13.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c1dca61c6db1166c48b95198c0b7d9c990b30c756fc2923cc66f68d17dc558fd"}, - {file = "websockets-13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:308e20f22c2c77f3f39caca508e765f8725020b84aa963474e18c59accbf4c02"}, - {file = "websockets-13.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62d516c325e6540e8a57b94abefc3459d7dab8ce52ac75c96cad5549e187e3a7"}, - {file = "websockets-13.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:87c6e35319b46b99e168eb98472d6c7d8634ee37750d7693656dc766395df096"}, - {file = "websockets-13.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5f9fee94ebafbc3117c30be1844ed01a3b177bb6e39088bc6b2fa1dc15572084"}, - {file = "websockets-13.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:7c1e90228c2f5cdde263253fa5db63e6653f1c00e7ec64108065a0b9713fa1b3"}, - {file = "websockets-13.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6548f29b0e401eea2b967b2fdc1c7c7b5ebb3eeb470ed23a54cd45ef078a0db9"}, - {file = "websockets-13.1-cp311-cp311-win32.whl", hash = "sha256:c11d4d16e133f6df8916cc5b7e3e96ee4c44c936717d684a94f48f82edb7c92f"}, - {file = "websockets-13.1-cp311-cp311-win_amd64.whl", hash = "sha256:d04f13a1d75cb2b8382bdc16ae6fa58c97337253826dfe136195b7f89f661557"}, - {file = "websockets-13.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9d75baf00138f80b48f1eac72ad1535aac0b6461265a0bcad391fc5aba875cfc"}, - {file = "websockets-13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9b6f347deb3dcfbfde1c20baa21c2ac0751afaa73e64e5b693bb2b848efeaa49"}, - {file = "websockets-13.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:de58647e3f9c42f13f90ac7e5f58900c80a39019848c5547bc691693098ae1bd"}, - {file = "websockets-13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1b54689e38d1279a51d11e3467dd2f3a50f5f2e879012ce8f2d6943f00e83f0"}, - {file = "websockets-13.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf1781ef73c073e6b0f90af841aaf98501f975d306bbf6221683dd594ccc52b6"}, - {file = "websockets-13.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d23b88b9388ed85c6faf0e74d8dec4f4d3baf3ecf20a65a47b836d56260d4b9"}, - {file = "websockets-13.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3c78383585f47ccb0fcf186dcb8a43f5438bd7d8f47d69e0b56f71bf431a0a68"}, - {file = 
"websockets-13.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:d6d300f8ec35c24025ceb9b9019ae9040c1ab2f01cddc2bcc0b518af31c75c14"}, - {file = "websockets-13.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a9dcaf8b0cc72a392760bb8755922c03e17a5a54e08cca58e8b74f6902b433cf"}, - {file = "websockets-13.1-cp312-cp312-win32.whl", hash = "sha256:2f85cf4f2a1ba8f602298a853cec8526c2ca42a9a4b947ec236eaedb8f2dc80c"}, - {file = "websockets-13.1-cp312-cp312-win_amd64.whl", hash = "sha256:38377f8b0cdeee97c552d20cf1865695fcd56aba155ad1b4ca8779a5b6ef4ac3"}, - {file = "websockets-13.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a9ab1e71d3d2e54a0aa646ab6d4eebfaa5f416fe78dfe4da2839525dc5d765c6"}, - {file = "websockets-13.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b9d7439d7fab4dce00570bb906875734df13d9faa4b48e261c440a5fec6d9708"}, - {file = "websockets-13.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:327b74e915cf13c5931334c61e1a41040e365d380f812513a255aa804b183418"}, - {file = "websockets-13.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:325b1ccdbf5e5725fdcb1b0e9ad4d2545056479d0eee392c291c1bf76206435a"}, - {file = "websockets-13.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:346bee67a65f189e0e33f520f253d5147ab76ae42493804319b5716e46dddf0f"}, - {file = "websockets-13.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91a0fa841646320ec0d3accdff5b757b06e2e5c86ba32af2e0815c96c7a603c5"}, - {file = "websockets-13.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:18503d2c5f3943e93819238bf20df71982d193f73dcecd26c94514f417f6b135"}, - {file = "websockets-13.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:a9cd1af7e18e5221d2878378fbc287a14cd527fdd5939ed56a18df8a31136bb2"}, - {file = "websockets-13.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = 
"sha256:70c5be9f416aa72aab7a2a76c90ae0a4fe2755c1816c153c1a2bcc3333ce4ce6"}, - {file = "websockets-13.1-cp313-cp313-win32.whl", hash = "sha256:624459daabeb310d3815b276c1adef475b3e6804abaf2d9d2c061c319f7f187d"}, - {file = "websockets-13.1-cp313-cp313-win_amd64.whl", hash = "sha256:c518e84bb59c2baae725accd355c8dc517b4a3ed8db88b4bc93c78dae2974bf2"}, - {file = "websockets-13.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c7934fd0e920e70468e676fe7f1b7261c1efa0d6c037c6722278ca0228ad9d0d"}, - {file = "websockets-13.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:149e622dc48c10ccc3d2760e5f36753db9cacf3ad7bc7bbbfd7d9c819e286f23"}, - {file = "websockets-13.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a569eb1b05d72f9bce2ebd28a1ce2054311b66677fcd46cf36204ad23acead8c"}, - {file = "websockets-13.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95df24ca1e1bd93bbca51d94dd049a984609687cb2fb08a7f2c56ac84e9816ea"}, - {file = "websockets-13.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8dbb1bf0c0a4ae8b40bdc9be7f644e2f3fb4e8a9aca7145bfa510d4a374eeb7"}, - {file = "websockets-13.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:035233b7531fb92a76beefcbf479504db8c72eb3bff41da55aecce3a0f729e54"}, - {file = "websockets-13.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:e4450fc83a3df53dec45922b576e91e94f5578d06436871dce3a6be38e40f5db"}, - {file = "websockets-13.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:463e1c6ec853202dd3657f156123d6b4dad0c546ea2e2e38be2b3f7c5b8e7295"}, - {file = "websockets-13.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6d6855bbe70119872c05107e38fbc7f96b1d8cb047d95c2c50869a46c65a8e96"}, - {file = "websockets-13.1-cp38-cp38-win32.whl", hash = "sha256:204e5107f43095012b00f1451374693267adbb832d29966a01ecc4ce1db26faf"}, - {file = "websockets-13.1-cp38-cp38-win_amd64.whl", hash = 
"sha256:485307243237328c022bc908b90e4457d0daa8b5cf4b3723fd3c4a8012fce4c6"}, - {file = "websockets-13.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9b37c184f8b976f0c0a231a5f3d6efe10807d41ccbe4488df8c74174805eea7d"}, - {file = "websockets-13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:163e7277e1a0bd9fb3c8842a71661ad19c6aa7bb3d6678dc7f89b17fbcc4aeb7"}, - {file = "websockets-13.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4b889dbd1342820cc210ba44307cf75ae5f2f96226c0038094455a96e64fb07a"}, - {file = "websockets-13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:586a356928692c1fed0eca68b4d1c2cbbd1ca2acf2ac7e7ebd3b9052582deefa"}, - {file = "websockets-13.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7bd6abf1e070a6b72bfeb71049d6ad286852e285f146682bf30d0296f5fbadfa"}, - {file = "websockets-13.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2aad13a200e5934f5a6767492fb07151e1de1d6079c003ab31e1823733ae79"}, - {file = "websockets-13.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:df01aea34b6e9e33572c35cd16bae5a47785e7d5c8cb2b54b2acdb9678315a17"}, - {file = "websockets-13.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:e54affdeb21026329fb0744ad187cf812f7d3c2aa702a5edb562b325191fcab6"}, - {file = "websockets-13.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9ef8aa8bdbac47f4968a5d66462a2a0935d044bf35c0e5a8af152d58516dbeb5"}, - {file = "websockets-13.1-cp39-cp39-win32.whl", hash = "sha256:deeb929efe52bed518f6eb2ddc00cc496366a14c726005726ad62c2dd9017a3c"}, - {file = "websockets-13.1-cp39-cp39-win_amd64.whl", hash = "sha256:7c65ffa900e7cc958cd088b9a9157a8141c991f8c53d11087e6fb7277a03f81d"}, - {file = "websockets-13.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:5dd6da9bec02735931fccec99d97c29f47cc61f644264eb995ad6c0c27667238"}, - {file = 
"websockets-13.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:2510c09d8e8df777177ee3d40cd35450dc169a81e747455cc4197e63f7e7bfe5"}, - {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1c3cf67185543730888b20682fb186fc8d0fa6f07ccc3ef4390831ab4b388d9"}, - {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bcc03c8b72267e97b49149e4863d57c2d77f13fae12066622dc78fe322490fe6"}, - {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:004280a140f220c812e65f36944a9ca92d766b6cc4560be652a0a3883a79ed8a"}, - {file = "websockets-13.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e2620453c075abeb0daa949a292e19f56de518988e079c36478bacf9546ced23"}, - {file = "websockets-13.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9156c45750b37337f7b0b00e6248991a047be4aa44554c9886fe6bdd605aab3b"}, - {file = "websockets-13.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:80c421e07973a89fbdd93e6f2003c17d20b69010458d3a8e37fb47874bd67d51"}, - {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82d0ba76371769d6a4e56f7e83bb8e81846d17a6190971e38b5de108bde9b0d7"}, - {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9875a0143f07d74dc5e1ded1c4581f0d9f7ab86c78994e2ed9e95050073c94d"}, - {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a11e38ad8922c7961447f35c7b17bffa15de4d17c70abd07bfbe12d6faa3e027"}, - {file = "websockets-13.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:4059f790b6ae8768471cddb65d3c4fe4792b0ab48e154c9f0a04cefaabcd5978"}, - {file = 
"websockets-13.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:25c35bf84bf7c7369d247f0b8cfa157f989862c49104c5cf85cb5436a641d93e"}, - {file = "websockets-13.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:83f91d8a9bb404b8c2c41a707ac7f7f75b9442a0a876df295de27251a856ad09"}, - {file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a43cfdcddd07f4ca2b1afb459824dd3c6d53a51410636a2c7fc97b9a8cf4842"}, - {file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48a2ef1381632a2f0cb4efeff34efa97901c9fbc118e01951ad7cfc10601a9bb"}, - {file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:459bf774c754c35dbb487360b12c5727adab887f1622b8aed5755880a21c4a20"}, - {file = "websockets-13.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:95858ca14a9f6fa8413d29e0a585b31b278388aa775b8a81fa24830123874678"}, - {file = "websockets-13.1-py3-none-any.whl", hash = "sha256:a9a396a6ad26130cdae92ae10c36af09d9bfe6cafe69670fd3b6da9b07b4044f"}, - {file = "websockets-13.1.tar.gz", hash = "sha256:a3b3366087c1bc0a2795111edcadddb8b3b59509d5db5d7ea3fdd69f954a8878"}, + {file = "websockets-12.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d554236b2a2006e0ce16315c16eaa0d628dab009c33b63ea03f41c6107958374"}, + {file = "websockets-12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2d225bb6886591b1746b17c0573e29804619c8f755b5598d875bb4235ea639be"}, + {file = "websockets-12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:eb809e816916a3b210bed3c82fb88eaf16e8afcf9c115ebb2bacede1797d2547"}, + {file = "websockets-12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c588f6abc13f78a67044c6b1273a99e1cf31038ad51815b3b016ce699f0d75c2"}, + {file = 
"websockets-12.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5aa9348186d79a5f232115ed3fa9020eab66d6c3437d72f9d2c8ac0c6858c558"}, + {file = "websockets-12.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6350b14a40c95ddd53e775dbdbbbc59b124a5c8ecd6fbb09c2e52029f7a9f480"}, + {file = "websockets-12.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:70ec754cc2a769bcd218ed8d7209055667b30860ffecb8633a834dde27d6307c"}, + {file = "websockets-12.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6e96f5ed1b83a8ddb07909b45bd94833b0710f738115751cdaa9da1fb0cb66e8"}, + {file = "websockets-12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4d87be612cbef86f994178d5186add3d94e9f31cc3cb499a0482b866ec477603"}, + {file = "websockets-12.0-cp310-cp310-win32.whl", hash = "sha256:befe90632d66caaf72e8b2ed4d7f02b348913813c8b0a32fae1cc5fe3730902f"}, + {file = "websockets-12.0-cp310-cp310-win_amd64.whl", hash = "sha256:363f57ca8bc8576195d0540c648aa58ac18cf85b76ad5202b9f976918f4219cf"}, + {file = "websockets-12.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5d873c7de42dea355d73f170be0f23788cf3fa9f7bed718fd2830eefedce01b4"}, + {file = "websockets-12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3f61726cae9f65b872502ff3c1496abc93ffbe31b278455c418492016e2afc8f"}, + {file = "websockets-12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ed2fcf7a07334c77fc8a230755c2209223a7cc44fc27597729b8ef5425aa61a3"}, + {file = "websockets-12.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e332c210b14b57904869ca9f9bf4ca32f5427a03eeb625da9b616c85a3a506c"}, + {file = "websockets-12.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5693ef74233122f8ebab026817b1b37fe25c411ecfca084b29bc7d6efc548f45"}, + {file = 
"websockets-12.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e9e7db18b4539a29cc5ad8c8b252738a30e2b13f033c2d6e9d0549b45841c04"}, + {file = "websockets-12.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6e2df67b8014767d0f785baa98393725739287684b9f8d8a1001eb2839031447"}, + {file = "websockets-12.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:bea88d71630c5900690fcb03161ab18f8f244805c59e2e0dc4ffadae0a7ee0ca"}, + {file = "websockets-12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dff6cdf35e31d1315790149fee351f9e52978130cef6c87c4b6c9b3baf78bc53"}, + {file = "websockets-12.0-cp311-cp311-win32.whl", hash = "sha256:3e3aa8c468af01d70332a382350ee95f6986db479ce7af14d5e81ec52aa2b402"}, + {file = "websockets-12.0-cp311-cp311-win_amd64.whl", hash = "sha256:25eb766c8ad27da0f79420b2af4b85d29914ba0edf69f547cc4f06ca6f1d403b"}, + {file = "websockets-12.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0e6e2711d5a8e6e482cacb927a49a3d432345dfe7dea8ace7b5790df5932e4df"}, + {file = "websockets-12.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:dbcf72a37f0b3316e993e13ecf32f10c0e1259c28ffd0a85cee26e8549595fbc"}, + {file = "websockets-12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:12743ab88ab2af1d17dd4acb4645677cb7063ef4db93abffbf164218a5d54c6b"}, + {file = "websockets-12.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b645f491f3c48d3f8a00d1fce07445fab7347fec54a3e65f0725d730d5b99cb"}, + {file = "websockets-12.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9893d1aa45a7f8b3bc4510f6ccf8db8c3b62120917af15e3de247f0780294b92"}, + {file = "websockets-12.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f38a7b376117ef7aff996e737583172bdf535932c9ca021746573bce40165ed"}, + {file = 
"websockets-12.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:f764ba54e33daf20e167915edc443b6f88956f37fb606449b4a5b10ba42235a5"}, + {file = "websockets-12.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1e4b3f8ea6a9cfa8be8484c9221ec0257508e3a1ec43c36acdefb2a9c3b00aa2"}, + {file = "websockets-12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9fdf06fd06c32205a07e47328ab49c40fc1407cdec801d698a7c41167ea45113"}, + {file = "websockets-12.0-cp312-cp312-win32.whl", hash = "sha256:baa386875b70cbd81798fa9f71be689c1bf484f65fd6fb08d051a0ee4e79924d"}, + {file = "websockets-12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ae0a5da8f35a5be197f328d4727dbcfafa53d1824fac3d96cdd3a642fe09394f"}, + {file = "websockets-12.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5f6ffe2c6598f7f7207eef9a1228b6f5c818f9f4d53ee920aacd35cec8110438"}, + {file = "websockets-12.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9edf3fc590cc2ec20dc9d7a45108b5bbaf21c0d89f9fd3fd1685e223771dc0b2"}, + {file = "websockets-12.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8572132c7be52632201a35f5e08348137f658e5ffd21f51f94572ca6c05ea81d"}, + {file = "websockets-12.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:604428d1b87edbf02b233e2c207d7d528460fa978f9e391bd8aaf9c8311de137"}, + {file = "websockets-12.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a9d160fd080c6285e202327aba140fc9a0d910b09e423afff4ae5cbbf1c7205"}, + {file = "websockets-12.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87b4aafed34653e465eb77b7c93ef058516cb5acf3eb21e42f33928616172def"}, + {file = "websockets-12.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b2ee7288b85959797970114deae81ab41b731f19ebcd3bd499ae9ca0e3f1d2c8"}, + {file = "websockets-12.0-cp38-cp38-musllinux_1_1_i686.whl", hash = 
"sha256:7fa3d25e81bfe6a89718e9791128398a50dec6d57faf23770787ff441d851967"}, + {file = "websockets-12.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a571f035a47212288e3b3519944f6bf4ac7bc7553243e41eac50dd48552b6df7"}, + {file = "websockets-12.0-cp38-cp38-win32.whl", hash = "sha256:3c6cc1360c10c17463aadd29dd3af332d4a1adaa8796f6b0e9f9df1fdb0bad62"}, + {file = "websockets-12.0-cp38-cp38-win_amd64.whl", hash = "sha256:1bf386089178ea69d720f8db6199a0504a406209a0fc23e603b27b300fdd6892"}, + {file = "websockets-12.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ab3d732ad50a4fbd04a4490ef08acd0517b6ae6b77eb967251f4c263011a990d"}, + {file = "websockets-12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a1d9697f3337a89691e3bd8dc56dea45a6f6d975f92e7d5f773bc715c15dde28"}, + {file = "websockets-12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1df2fbd2c8a98d38a66f5238484405b8d1d16f929bb7a33ed73e4801222a6f53"}, + {file = "websockets-12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23509452b3bc38e3a057382c2e941d5ac2e01e251acce7adc74011d7d8de434c"}, + {file = "websockets-12.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e5fc14ec6ea568200ea4ef46545073da81900a2b67b3e666f04adf53ad452ec"}, + {file = "websockets-12.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46e71dbbd12850224243f5d2aeec90f0aaa0f2dde5aeeb8fc8df21e04d99eff9"}, + {file = "websockets-12.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b81f90dcc6c85a9b7f29873beb56c94c85d6f0dac2ea8b60d995bd18bf3e2aae"}, + {file = "websockets-12.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:a02413bc474feda2849c59ed2dfb2cddb4cd3d2f03a2fedec51d6e959d9b608b"}, + {file = "websockets-12.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:bbe6013f9f791944ed31ca08b077e26249309639313fff132bfbf3ba105673b9"}, + {file = "websockets-12.0-cp39-cp39-win32.whl", 
hash = "sha256:cbe83a6bbdf207ff0541de01e11904827540aa069293696dd528a6640bd6a5f6"}, + {file = "websockets-12.0-cp39-cp39-win_amd64.whl", hash = "sha256:fc4e7fa5414512b481a2483775a8e8be7803a35b30ca805afa4998a84f9fd9e8"}, + {file = "websockets-12.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:248d8e2446e13c1d4326e0a6a4e9629cb13a11195051a73acf414812700badbd"}, + {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f44069528d45a933997a6fef143030d8ca8042f0dfaad753e2906398290e2870"}, + {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c4e37d36f0d19f0a4413d3e18c0d03d0c268ada2061868c1e6f5ab1a6d575077"}, + {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d829f975fc2e527a3ef2f9c8f25e553eb7bc779c6665e8e1d52aa22800bb38b"}, + {file = "websockets-12.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2c71bd45a777433dd9113847af751aae36e448bc6b8c361a566cb043eda6ec30"}, + {file = "websockets-12.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0bee75f400895aef54157b36ed6d3b308fcab62e5260703add87f44cee9c82a6"}, + {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:423fc1ed29f7512fceb727e2d2aecb952c46aa34895e9ed96071821309951123"}, + {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:27a5e9964ef509016759f2ef3f2c1e13f403725a5e6a1775555994966a66e931"}, + {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3181df4583c4d3994d31fb235dc681d2aaad744fbdbf94c4802485ececdecf2"}, + {file = "websockets-12.0-pp38-pypy38_pp73-win_amd64.whl", hash = 
"sha256:b067cb952ce8bf40115f6c19f478dc71c5e719b7fbaa511359795dfd9d1a6468"}, + {file = "websockets-12.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:00700340c6c7ab788f176d118775202aadea7602c5cc6be6ae127761c16d6b0b"}, + {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e469d01137942849cff40517c97a30a93ae79917752b34029f0ec72df6b46399"}, + {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffefa1374cd508d633646d51a8e9277763a9b78ae71324183693959cf94635a7"}, + {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba0cab91b3956dfa9f512147860783a1829a8d905ee218a9837c18f683239611"}, + {file = "websockets-12.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2cb388a5bfb56df4d9a406783b7f9dbefb888c09b71629351cc6b036e9259370"}, + {file = "websockets-12.0-py3-none-any.whl", hash = "sha256:dc284bbc8d7c78a6c69e0c7325ab46ee5e40bb4d50e494d8131a07ef47500e9e"}, + {file = "websockets-12.0.tar.gz", hash = "sha256:81df9cbcbb6c260de1e007e58c011bfebe2dafc8435107b0537f393dd38c8b1b"}, ] [[package]] @@ -10493,4 +10609,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "d324192116c4b243e504d57f4605b79c46592a976201d903b16a910b71d84b57" +content-hash = "b1152e5ef8d1980cf4ac7fd1ffee60123c582ab7bfddf8c2e281baa70e61c2d5" diff --git a/api/pyproject.toml b/api/pyproject.toml index 11bcc255d7cd98..da09932e8934ee 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -225,6 +225,7 @@ cos-python-sdk-v5 = "1.9.30" esdk-obs-python = "3.24.6.1" google-cloud-storage = "2.16.0" oss2 = "2.18.5" +supabase = "~2.8.1" tos = "~2.7.1" ############################################################ From 93af87a9e0c7010e81e22ea8488fd0fa2bd33913 Mon Sep 17 00:00:00 2001 From: crazywoola 
<100913391+crazywoola@users.noreply.github.com> Date: Sat, 12 Oct 2024 09:28:45 +0800 Subject: [PATCH 03/25] fix: move exception to debug mode (#9258) --- api/core/app/task_pipeline/message_cycle_manage.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 3a1d1b227f62f4..236eebf0b85ff6 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -1,8 +1,10 @@ +import logging from threading import Thread from typing import Optional, Union from flask import Flask, current_app +from configs import dify_config from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, @@ -83,7 +85,9 @@ def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query) conversation.name = name except Exception as e: - logging.exception(f"generate conversation name failed: {e}") + if dify_config.DEBUG: + logging.exception(f"generate conversation name failed: {e}") + pass db.session.merge(conversation) db.session.commit() From 1206b1eb96d108c516f1171e938d0e5dc3fc1a80 Mon Sep 17 00:00:00 2001 From: NFish Date: Sat, 12 Oct 2024 11:32:40 +0800 Subject: [PATCH 04/25] fix: add new domain to whitelist (#9265) --- web/middleware.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/middleware.ts b/web/middleware.ts index 0c5817445fd1a1..e0f8f3782f78d5 100644 --- a/web/middleware.ts +++ b/web/middleware.ts @@ -1,7 +1,7 @@ import type { NextRequest } from 'next/server' import { NextResponse } from 'next/server' -const NECESSARY_DOMAIN = '*.sentry.io http://localhost:* http://127.0.0.1:* https://analytics.google.com https://googletagmanager.com https://api.github.com' +const NECESSARY_DOMAIN = '*.sentry.io http://localhost:* http://127.0.0.1:* 
https://analytics.google.com googletagmanager.com *.googletagmanager.com https://www.google-analytics.com https://api.github.com' export function middleware(request: NextRequest) { const isWhiteListEnabled = !!process.env.NEXT_PUBLIC_CSP_WHITELIST && process.env.NODE_ENV === 'production' From d9773c963fceed9af750d46a2d04ddd88bc26f37 Mon Sep 17 00:00:00 2001 From: zhuhao <37029601+hwzhuhao@users.noreply.github.com> Date: Sat, 12 Oct 2024 17:37:01 +0800 Subject: [PATCH 05/25] chore: fix the misclassification of the opensearch-py package (#9266) --- api/poetry.lock | 32 +++++++++++++++++++------------- api/pyproject.toml | 2 +- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/api/poetry.lock b/api/poetry.lock index fd6b98191df02c..52e0f3031151de 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -732,7 +732,7 @@ name = "bce-python-sdk" version = "0.9.23" description = "BCE SDK for python" optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,<4,>=2.7" +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, <4" files = [ {file = "bce_python_sdk-0.9.23-py3-none-any.whl", hash = "sha256:8debe21a040e00060f6044877d594765ed7b18bc765c6bf16b878bca864140a3"}, {file = "bce_python_sdk-0.9.23.tar.gz", hash = "sha256:19739fed5cd0725356fc5ffa2acbdd8fb23f2a81edb91db21a03174551d0cf41"}, @@ -847,7 +847,7 @@ name = "botocore" version = "1.35.38" description = "Low-level, data-driven core of boto 3." 
optional = false -python-versions = ">=3.8" +python-versions = ">= 3.8" files = [ {file = "botocore-1.35.38-py3-none-any.whl", hash = "sha256:2eb17d32fa2d3bb5d475132a83564d28e3acc2161534f24b75a54418a1d51359"}, {file = "botocore-1.35.38.tar.gz", hash = "sha256:55d9305c44e5ba29476df456120fa4fb919f03f066afa82f2ae400485e7465f4"}, @@ -1068,7 +1068,7 @@ name = "build" version = "1.2.2.post1" description = "A simple, correct Python build frontend" optional = false -python-versions = ">=3.8" +python-versions = ">= 3.8" files = [ {file = "build-1.2.2.post1-py3-none-any.whl", hash = "sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5"}, {file = "build-1.2.2.post1.tar.gz", hash = "sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7"}, @@ -3385,7 +3385,7 @@ name = "gotrue" version = "2.9.2" description = "Python Client Library for Supabase Auth" optional = false -python-versions = "<4.0,>=3.8" +python-versions = ">=3.8,<4.0" files = [ {file = "gotrue-2.9.2-py3-none-any.whl", hash = "sha256:fcd5279e8f1cc630f3ac35af5485fe39f8030b23906776920d2c32a4e308cff4"}, {file = "gotrue-2.9.2.tar.gz", hash = "sha256:57b3245e916c5efbf19a21b1181011a903c1276bb1df2d847558f2f24f29abb2"}, @@ -4415,7 +4415,7 @@ name = "langfuse" version = "2.51.5" description = "A client library for accessing langfuse" optional = false -python-versions = "<4.0,>=3.8.1" +python-versions = ">=3.8.1,<4.0" files = [ {file = "langfuse-2.51.5-py3-none-any.whl", hash = "sha256:b95401ca710ef94b521afa6541933b6f93d7cfd4a97523c8fc75bca4d6d219fb"}, {file = "langfuse-2.51.5.tar.gz", hash = "sha256:55bc37b5c5d3ae133c1a95db09117cfb3117add110ba02ebbf2ce45ac4395c5b"}, @@ -4440,7 +4440,7 @@ name = "langsmith" version = "0.1.134" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." 
optional = false -python-versions = "<4.0,>=3.8.1" +python-versions = ">=3.8.1,<4.0" files = [ {file = "langsmith-0.1.134-py3-none-any.whl", hash = "sha256:ada98ad80ef38807725f32441a472da3dd28394010877751f48f458d3289da04"}, {file = "langsmith-0.1.134.tar.gz", hash = "sha256:23abee3b508875a0e63c602afafffc02442a19cfd88f9daae05b3e9054fd6b61"}, @@ -6429,7 +6429,7 @@ name = "postgrest" version = "0.17.1" description = "PostgREST client for Python. This library provides an ORM interface to PostgREST." optional = false -python-versions = "<4.0,>=3.8" +python-versions = ">=3.8,<4.0" files = [ {file = "postgrest-0.17.1-py3-none-any.whl", hash = "sha256:ec1d00dc8532fe5ffb342cfc7c4e610a1e0e2272eb14f78f9b2b61094f9be510"}, {file = "postgrest-0.17.1.tar.gz", hash = "sha256:e31d9977dbb80dc5f9fdd4d444014686606692dc4ddb9adc85639e56c6d54c92"}, @@ -7746,7 +7746,7 @@ name = "realtime" version = "2.0.2" description = "" optional = false -python-versions = "<4.0,>=3.9" +python-versions = ">=3.9,<4.0" files = [ {file = "realtime-2.0.2-py3-none-any.whl", hash = "sha256:2634c915bc38807f2013f21e8bcc4d2f79870dfd81460ddb9393883d0489928a"}, {file = "realtime-2.0.2.tar.gz", hash = "sha256:519da9325b3b8102139d51785013d592f6b2403d81fa21d838a0b0234723ed7d"}, @@ -8173,7 +8173,7 @@ name = "s3transfer" version = "0.10.3" description = "An Amazon S3 Transfer Manager" optional = false -python-versions = ">=3.8" +python-versions = ">= 3.8" files = [ {file = "s3transfer-0.10.3-py3-none-any.whl", hash = "sha256:263ed587a5803c6c708d3ce44dc4dfedaab4c1a32e8329bab818933d79ddcf5d"}, {file = "s3transfer-0.10.3.tar.gz", hash = "sha256:4f50ed74ab84d474ce614475e0b8d5047ff080810aac5d01ea25231cfc944b0c"}, @@ -8836,7 +8836,7 @@ name = "storage3" version = "0.8.1" description = "Supabase Storage client for Python." 
optional = false -python-versions = "<4.0,>=3.8" +python-versions = ">=3.8,<4.0" files = [ {file = "storage3-0.8.1-py3-none-any.whl", hash = "sha256:0b21205f43eaf0d1dd33bde6c6d0612f88524b7865f017d2ae9827e3f63d9cdc"}, {file = "storage3-0.8.1.tar.gz", hash = "sha256:ea60b68b2221b3868ccc1a7f1294d57d0d9c51642cdc639d8115fe5d0adc8892"}, @@ -8882,7 +8882,7 @@ name = "supabase" version = "2.8.1" description = "Supabase client for Python." optional = false -python-versions = "<4.0,>=3.9" +python-versions = ">=3.9,<4.0" files = [ {file = "supabase-2.8.1-py3-none-any.whl", hash = "sha256:dfa8bef89b54129093521d5bba2136ff765baf67cd76d8ad0aa4984d61a7815c"}, {file = "supabase-2.8.1.tar.gz", hash = "sha256:711c70e6acd9e2ff48ca0dc0b1bb70c01c25378cc5189ec9f5ed9655b30bc41d"}, @@ -8902,7 +8902,7 @@ name = "supafunc" version = "0.6.1" description = "Library for Supabase Functions" optional = false -python-versions = "<4.0,>=3.8" +python-versions = ">=3.8,<4.0" files = [ {file = "supafunc-0.6.1-py3-none-any.whl", hash = "sha256:01aeeeb4bf429977664454a32c86418345140faf6d2e6eb0636d52e4547c5fbb"}, {file = "supafunc-0.6.1.tar.gz", hash = "sha256:3c8761e3999336ccdb7550498a395fd08afc8469382f55ea56f7f640e5a909aa"}, @@ -10455,31 +10455,37 @@ python-versions = ">=3.8" files = [ {file = "zope.interface-7.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2bd9e9f366a5df08ebbdc159f8224904c1c5ce63893984abb76954e6fbe4381a"}, {file = "zope.interface-7.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:661d5df403cd3c5b8699ac480fa7f58047a3253b029db690efa0c3cf209993ef"}, + {file = "zope.interface-7.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91b6c30689cfd87c8f264acb2fc16ad6b3c72caba2aec1bf189314cf1a84ca33"}, {file = "zope.interface-7.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b6a4924f5bad9fe21d99f66a07da60d75696a136162427951ec3cb223a5570d"}, {file = 
"zope.interface-7.1.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80a3c00b35f6170be5454b45abe2719ea65919a2f09e8a6e7b1362312a872cd3"}, {file = "zope.interface-7.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:b936d61dbe29572fd2cfe13e30b925e5383bed1aba867692670f5a2a2eb7b4e9"}, {file = "zope.interface-7.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0ac20581fc6cd7c754f6dff0ae06fedb060fa0e9ea6309d8be8b2701d9ea51c4"}, {file = "zope.interface-7.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:848b6fa92d7c8143646e64124ed46818a0049a24ecc517958c520081fd147685"}, + {file = "zope.interface-7.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec1ef1fdb6f014d5886b97e52b16d0f852364f447d2ab0f0c6027765777b6667"}, {file = "zope.interface-7.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bcff5c09d0215f42ba64b49205a278e44413d9bf9fa688fd9e42bfe472b5f4f"}, {file = "zope.interface-7.1.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07add15de0cc7e69917f7d286b64d54125c950aeb43efed7a5ea7172f000fbc1"}, {file = "zope.interface-7.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:9940d5bc441f887c5f375ec62bcf7e7e495a2d5b1da97de1184a88fb567f06af"}, {file = "zope.interface-7.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f245d039f72e6f802902375755846f5de1ee1e14c3e8736c078565599bcab621"}, {file = "zope.interface-7.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6159e767d224d8f18deff634a1d3722e68d27488c357f62ebeb5f3e2f5288b1f"}, + {file = "zope.interface-7.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e956b1fd7f3448dd5e00f273072e73e50dfafcb35e4227e6d5af208075593c9"}, {file = "zope.interface-7.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:ff115ef91c0eeac69cd92daeba36a9d8e14daee445b504eeea2b1c0b55821984"}, {file = "zope.interface-7.1.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bec001798ab62c3fc5447162bf48496ae9fba02edc295a9e10a0b0c639a6452e"}, {file = "zope.interface-7.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:124149e2d42067b9c6597f4dafdc7a0983d0163868f897b7bb5dc850b14f9a87"}, {file = "zope.interface-7.1.0-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:9733a9a0f94ef53d7aa64661811b20875b5bc6039034c6e42fb9732170130573"}, {file = "zope.interface-7.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5fcf379b875c610b5a41bc8a891841533f98de0520287d7f85e25386cd10d3e9"}, + {file = "zope.interface-7.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0a45b5af9f72c805ee668d1479480ca85169312211bed6ed18c343e39307d5f"}, {file = "zope.interface-7.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4af4a12b459a273b0b34679a5c3dc5e34c1847c3dd14a628aa0668e19e638ea2"}, {file = "zope.interface-7.1.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a735f82d2e3ed47ca01a20dfc4c779b966b16352650a8036ab3955aad151ed8a"}, {file = "zope.interface-7.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:5501e772aff595e3c54266bc1bfc5858e8f38974ce413a8f1044aae0f32a83a3"}, {file = "zope.interface-7.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ec59fe53db7d32abb96c6d4efeed84aab4a7c38c62d7a901a9b20c09dd936e7a"}, {file = "zope.interface-7.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e53c291debef523b09e1fe3dffe5f35dde164f1c603d77f770b88a1da34b7ed6"}, + {file = "zope.interface-7.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:711eebc77f2092c6a8b304bad0b81a6ce3cf5490b25574e7309fbc07d881e3af"}, {file = 
"zope.interface-7.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a00ead2e24c76436e1b457a5132d87f83858330f6c923640b7ef82d668525d1"}, {file = "zope.interface-7.1.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e28ea0bc4b084fc93a483877653a033062435317082cdc6388dec3438309faf"}, {file = "zope.interface-7.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:27cfb5205d68b12682b6e55ab8424662d96e8ead19550aad0796b08dd2c9a45e"}, {file = "zope.interface-7.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9e3e48f3dea21c147e1b10c132016cb79af1159facca9736d231694ef5a740a8"}, {file = "zope.interface-7.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a99240b1d02dc469f6afbe7da1bf617645e60290c272968f4e53feec18d7dce8"}, + {file = "zope.interface-7.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc8a318162123eddbdf22fcc7b751288ce52e4ad096d3766ff1799244352449d"}, {file = "zope.interface-7.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b7b25db127db3e6b597c5f74af60309c4ad65acd826f89609662f0dc33a54728"}, {file = "zope.interface-7.1.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a29ac607e970b5576547f0e3589ec156e04de17af42839eedcf478450687317"}, {file = "zope.interface-7.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:a14c9decf0eb61e0892631271d500c1e306c7b6901c998c7035e194d9150fdd1"}, @@ -10609,4 +10615,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "b1152e5ef8d1980cf4ac7fd1ffee60123c582ab7bfddf8c2e281baa70e61c2d5" +content-hash = "cc10ee218369eb5576d1e5ac8aeeb72e8927bbcb8bd1ac1594167c45aa9d9a21" diff --git a/api/pyproject.toml b/api/pyproject.toml index da09932e8934ee..277d1690c72702 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ 
-207,7 +207,6 @@ matplotlib = "~3.8.2" newspaper3k = "0.2.8" nltk = "3.8.1" numexpr = "~2.9.0" -opensearch-py = "2.4.0" qrcode = "~7.4.2" twilio = "~9.0.4" vanna = { version = "0.7.3", extras = ["postgres", "mysql", "clickhouse", "duckdb"] } @@ -238,6 +237,7 @@ alibabacloud_tea_openapi = "~0.3.9" chromadb = "0.5.1" clickhouse-connect = "~0.7.16" elasticsearch = "8.14.0" +opensearch-py = "2.4.0" oracledb = "~2.2.1" pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] } pgvector = "0.2.5" From 29188e0562384d3c8e6cfbd0cf5c4ffd5d57e8aa Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 12 Oct 2024 17:48:59 +0800 Subject: [PATCH 06/25] chore: use cache instead of re-querying node record during workflow execution (#9280) --- .../advanced_chat/generate_task_pipeline.py | 3 +++ .../apps/workflow/generate_task_pipeline.py | 3 +++ .../task_pipeline/workflow_cycle_manage.py | 21 +++++++------------ 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 6bf684f8e40a3e..fd63c7787fa631 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -56,6 +56,7 @@ from models.model import Conversation, EndUser, Message from models.workflow import ( Workflow, + WorkflowNodeExecution, WorkflowRunStatus, ) @@ -72,6 +73,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _workflow: Workflow _user: Union[Account, EndUser] _workflow_system_variables: dict[SystemVariableKey, Any] + _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] def __init__( self, @@ -115,6 +117,7 @@ def __init__( } self._task_state = WorkflowTaskState() + self._wip_workflow_node_executions = {} self._conversation_name_generate_thread = None diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py 
index 3afc5053673e3b..7c53556e43bc48 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -52,6 +52,7 @@ Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, + WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus, ) @@ -69,6 +70,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity _workflow_system_variables: dict[SystemVariableKey, Any] + _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] def __init__( self, @@ -103,6 +105,7 @@ def __init__( } self._task_state = WorkflowTaskState() + self._wip_workflow_node_executions = {} def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 4fc587db77c145..f48ae9c01e6de4 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -57,6 +57,7 @@ class WorkflowCycleManage: _user: Union[Account, EndUser] _task_state: WorkflowTaskState _workflow_system_variables: dict[SystemVariableKey, Any] + _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] def _handle_workflow_run_start(self) -> WorkflowRun: max_sequence = ( @@ -251,6 +252,8 @@ def _handle_node_execution_start( db.session.refresh(workflow_node_execution) db.session.close() + self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution + return workflow_node_execution def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: @@ -275,9 +278,10 @@ def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - 
event.start_at).total_seconds() db.session.commit() - db.session.refresh(workflow_node_execution) db.session.close() + self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) + return workflow_node_execution def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution: @@ -300,9 +304,10 @@ def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds() db.session.commit() - db.session.refresh(workflow_node_execution) db.session.close() + self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) + return workflow_node_execution ################################################# @@ -678,17 +683,7 @@ def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNo :param node_execution_id: workflow node execution id :return: """ - workflow_node_execution = ( - db.session.query(WorkflowNodeExecution) - .filter( - WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id, - WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id, - WorkflowNodeExecution.workflow_id == self._workflow.id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.node_execution_id == node_execution_id, - ) - .first() - ) + workflow_node_execution = self._wip_workflow_node_executions.get(node_execution_id) if not workflow_node_execution: raise Exception(f"Workflow node execution not found: {node_execution_id}") From 23ce1fb1ba0baa523fe3928bf8fc2fe0dacf3f5f Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 12 Oct 2024 18:30:46 +0800 Subject: [PATCH 07/25] chore: optimize the trace ops slow queries on node executions. 
(#9282) --- .../task_pipeline/workflow_cycle_manage.py | 55 ++++++++++++++----- api/core/ops/langfuse_trace/langfuse_trace.py | 41 ++++++++------ .../ops/langsmith_trace/langsmith_trace.py | 41 ++++++++------ 3 files changed, 92 insertions(+), 45 deletions(-) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index f48ae9c01e6de4..b8f5ac260340e5 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -266,20 +266,35 @@ def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent inputs = WorkflowEntry.handle_special_values(event.inputs) outputs = WorkflowEntry.handle_special_values(event.outputs) - - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None - workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = ( + execution_metadata = ( json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None ) - workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) - workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds() + finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + elapsed_time = (finished_at - event.start_at).total_seconds() + + db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( + { + WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value, + WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, + WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None, + WorkflowNodeExecution.outputs: 
json.dumps(outputs) if outputs else None, + WorkflowNodeExecution.execution_metadata: execution_metadata, + WorkflowNodeExecution.finished_at: finished_at, + WorkflowNodeExecution.elapsed_time: elapsed_time, + } + ) db.session.commit() db.session.close() + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + workflow_node_execution.inputs = json.dumps(inputs) if inputs else None + workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None + workflow_node_execution.outputs = json.dumps(outputs) if outputs else None + workflow_node_execution.execution_metadata = execution_metadata + workflow_node_execution.finished_at = finished_at + workflow_node_execution.elapsed_time = elapsed_time + self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) return workflow_node_execution @@ -294,17 +309,31 @@ def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> inputs = WorkflowEntry.handle_special_values(event.inputs) outputs = WorkflowEntry.handle_special_values(event.outputs) + finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + elapsed_time = (finished_at - event.start_at).total_seconds() + + db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( + { + WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value, + WorkflowNodeExecution.error: event.error, + WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, + WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None, + WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, + WorkflowNodeExecution.finished_at: finished_at, + WorkflowNodeExecution.elapsed_time: elapsed_time, + } + ) + + db.session.commit() + db.session.close() workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.error = event.error - 
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_node_execution.inputs = json.dumps(inputs) if inputs else None workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds() - - db.session.commit() - db.session.close() + workflow_node_execution.finished_at = finished_at + workflow_node_execution.elapsed_time = elapsed_time self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 171e34f8cb48f0..0cba40c51a0d19 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -110,26 +110,35 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): self.add_trace(langfuse_trace_data=trace_data) # through workflow_run_id get all_nodes_execution - workflow_nodes_executions = ( - db.session.query( - WorkflowNodeExecution.id, - WorkflowNodeExecution.tenant_id, - WorkflowNodeExecution.app_id, - WorkflowNodeExecution.title, - WorkflowNodeExecution.node_type, - WorkflowNodeExecution.status, - WorkflowNodeExecution.inputs, - WorkflowNodeExecution.outputs, - WorkflowNodeExecution.created_at, - WorkflowNodeExecution.elapsed_time, - WorkflowNodeExecution.process_data, - WorkflowNodeExecution.execution_metadata, - ) + workflow_nodes_execution_id_records = ( + db.session.query(WorkflowNodeExecution.id) .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) .all() ) - for node_execution in workflow_nodes_executions: + for node_execution_id_record in workflow_nodes_execution_id_records: + node_execution = ( + db.session.query( + WorkflowNodeExecution.id, + WorkflowNodeExecution.tenant_id, + 
WorkflowNodeExecution.app_id, + WorkflowNodeExecution.title, + WorkflowNodeExecution.node_type, + WorkflowNodeExecution.status, + WorkflowNodeExecution.inputs, + WorkflowNodeExecution.outputs, + WorkflowNodeExecution.created_at, + WorkflowNodeExecution.elapsed_time, + WorkflowNodeExecution.process_data, + WorkflowNodeExecution.execution_metadata, + ) + .filter(WorkflowNodeExecution.id == node_execution_id_record.id) + .first() + ) + + if not node_execution: + continue + node_execution_id = node_execution.id tenant_id = node_execution.tenant_id app_id = node_execution.app_id diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 37cbea13fd8733..ad450504057bef 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -100,26 +100,35 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): self.add_run(langsmith_run) # through workflow_run_id get all_nodes_execution - workflow_nodes_executions = ( - db.session.query( - WorkflowNodeExecution.id, - WorkflowNodeExecution.tenant_id, - WorkflowNodeExecution.app_id, - WorkflowNodeExecution.title, - WorkflowNodeExecution.node_type, - WorkflowNodeExecution.status, - WorkflowNodeExecution.inputs, - WorkflowNodeExecution.outputs, - WorkflowNodeExecution.created_at, - WorkflowNodeExecution.elapsed_time, - WorkflowNodeExecution.process_data, - WorkflowNodeExecution.execution_metadata, - ) + workflow_nodes_execution_id_records = ( + db.session.query(WorkflowNodeExecution.id) .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) .all() ) - for node_execution in workflow_nodes_executions: + for node_execution_id_record in workflow_nodes_execution_id_records: + node_execution = ( + db.session.query( + WorkflowNodeExecution.id, + WorkflowNodeExecution.tenant_id, + WorkflowNodeExecution.app_id, + WorkflowNodeExecution.title, + WorkflowNodeExecution.node_type, + WorkflowNodeExecution.status, 
+ WorkflowNodeExecution.inputs, + WorkflowNodeExecution.outputs, + WorkflowNodeExecution.created_at, + WorkflowNodeExecution.elapsed_time, + WorkflowNodeExecution.process_data, + WorkflowNodeExecution.execution_metadata, + ) + .filter(WorkflowNodeExecution.id == node_execution_id_record.id) + .first() + ) + + if not node_execution: + continue + node_execution_id = node_execution.id tenant_id = node_execution.tenant_id app_id = node_execution.app_id From c6b74daa0a87bd02bdd6e82c2cc00e45772a793f Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Sat, 12 Oct 2024 18:47:59 +0800 Subject: [PATCH 08/25] Fix/s3 iam add region name (#7819) --- api/extensions/storage/aws_s3_storage.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/api/extensions/storage/aws_s3_storage.py b/api/extensions/storage/aws_s3_storage.py index fede683aa7bb13..38f823763fa4f7 100644 --- a/api/extensions/storage/aws_s3_storage.py +++ b/api/extensions/storage/aws_s3_storage.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Generator from contextlib import closing @@ -8,6 +9,8 @@ from extensions.storage.base_storage import BaseStorage +logger = logging.getLogger(__name__) + class AwsS3Storage(BaseStorage): """Implementation for Amazon Web Services S3 storage.""" @@ -17,9 +20,14 @@ def __init__(self, app: Flask): app_config = self.app.config self.bucket_name = app_config.get("S3_BUCKET_NAME") if app_config.get("S3_USE_AWS_MANAGED_IAM"): + logger.info("Using AWS managed IAM role for S3") + session = boto3.Session() - self.client = session.client("s3") + region_name = app_config.get("S3_REGION") + self.client = session.client(service_name="s3", region_name=region_name) else: + logger.info("Using ak and sk for S3") + self.client = boto3.client( "s3", aws_secret_access_key=app_config.get("S3_SECRET_KEY"), From 793205afc547264151ef6ac428f485422adf6daf Mon Sep 17 00:00:00 2001 From: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Date: Sat, 12 Oct 2024 21:24:43 
+0800 Subject: [PATCH 09/25] Feat: rerank model verification in front end (#9271) --- .../params-config/config-content.tsx | 49 +++++++++++++---- .../common/retrieval-param-config/index.tsx | 49 +++++++++++++---- .../workflow/hooks/use-workflow-start-run.tsx | 55 +++++++++++++++++++ .../components/workflow/hooks/use-workflow.ts | 28 ++++++++++ web/i18n/en-US/workflow.ts | 1 + web/i18n/zh-Hans/workflow.ts | 1 + 6 files changed, 159 insertions(+), 24 deletions(-) diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 7f83a14d58accd..f5561215180738 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -1,6 +1,6 @@ 'use client' -import { memo, useEffect, useMemo } from 'react' +import { memo, useCallback, useEffect, useMemo } from 'react' import type { FC } from 'react' import { useTranslation } from 'react-i18next' import WeightedScore from './weighted-score' @@ -11,7 +11,7 @@ import type { DatasetConfigs, } from '@/models/debug' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' -import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import type { ModelConfig } from '@/app/components/workflow/types' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' import Tooltip from '@/app/components/base/tooltip' @@ -23,6 +23,7 @@ import { RerankingModeEnum } from '@/models/datasets' import cn from '@/utils/classnames' import { useSelectedDatasetsMode 
} from '@/app/components/workflow/nodes/knowledge-retrieval/hooks' import Switch from '@/app/components/base/switch' +import Toast from '@/app/components/base/toast' type Props = { datasetConfigs: DatasetConfigs @@ -60,6 +61,24 @@ const ConfigContent: FC = ({ modelList: rerankModelList, defaultModel: rerankDefaultModel, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) + + const { + currentModel, + } = useCurrentProviderAndModel( + rerankModelList, + rerankDefaultModel + ? { + ...rerankDefaultModel, + provider: rerankDefaultModel.provider.provider, + } + : undefined, + ) + + const handleDisabledSwitchClick = useCallback(() => { + if (!currentModel) + Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) + }, [currentModel, rerankDefaultModel, t]) + const rerankModel = (() => { if (datasetConfigs.reranking_model?.reranking_provider_name) { return { @@ -231,16 +250,22 @@ const ConfigContent: FC = ({
{ selectedDatasetsMode.allEconomic && ( - { - onChange({ - ...datasetConfigs, - reranking_enable: v, - }) - }} - /> +
+ { + onChange({ + ...datasetConfigs, + reranking_enable: v, + }) + }} + /> +
) }
{t('common.modelProvider.rerankModel.key')}
diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx index 323e47f3b4ae24..9d48d56a8dc511 100644 --- a/web/app/components/datasets/common/retrieval-param-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx @@ -1,6 +1,6 @@ 'use client' import type { FC } from 'react' -import React from 'react' +import React, { useCallback } from 'react' import { useTranslation } from 'react-i18next' import cn from '@/utils/classnames' @@ -11,7 +11,7 @@ import Switch from '@/app/components/base/switch' import Tooltip from '@/app/components/base/tooltip' import type { RetrievalConfig } from '@/types/app' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' -import { useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useCurrentProviderAndModel, useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { DEFAULT_WEIGHTED_SCORE, @@ -19,6 +19,7 @@ import { WeightedScoreEnum, } from '@/models/datasets' import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score' +import Toast from '@/app/components/base/toast' type Props = { type: RETRIEVE_METHOD @@ -38,6 +39,24 @@ const RetrievalParamConfig: FC = ({ defaultModel: rerankDefaultModel, modelList: rerankModelList, } = useModelListAndDefaultModel(ModelTypeEnum.rerank) + + const { + currentModel, + } = useCurrentProviderAndModel( + rerankModelList, + rerankDefaultModel + ? 
{ + ...rerankDefaultModel, + provider: rerankDefaultModel.provider.provider, + } + : undefined, + ) + + const handleDisabledSwitchClick = useCallback(() => { + if (!currentModel) + Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) + }, [currentModel, rerankDefaultModel, t]) + const isHybridSearch = type === RETRIEVE_METHOD.hybrid const rerankModel = (() => { @@ -99,16 +118,22 @@ const RetrievalParamConfig: FC = ({
{canToggleRerankModalEnable && ( - { - onChange({ - ...value, - reranking_enable: v, - }) - }} - /> +
+ { + onChange({ + ...value, + reranking_enable: v, + }) + }} + disabled={!currentModel} + /> +
)}
{t('common.modelProvider.rerankModel.key')} diff --git a/web/app/components/workflow/hooks/use-workflow-start-run.tsx b/web/app/components/workflow/hooks/use-workflow-start-run.tsx index b2b1c69975658a..77e959b573ba54 100644 --- a/web/app/components/workflow/hooks/use-workflow-start-run.tsx +++ b/web/app/components/workflow/hooks/use-workflow-start-run.tsx @@ -1,17 +1,25 @@ import { useCallback } from 'react' import { useStoreApi } from 'reactflow' +import { useTranslation } from 'react-i18next' import { useWorkflowStore } from '../store' import { BlockEnum, WorkflowRunningStatus, } from '../types' +import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types' +import type { Node } from '../types' +import { useWorkflow } from './use-workflow' import { useIsChatMode, useNodesSyncDraft, useWorkflowInteractions, useWorkflowRun, } from './index' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { useFeaturesStore } from '@/app/components/base/features/hooks' +import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default' +import Toast from '@/app/components/base/toast' export const useWorkflowStartRun = () => { const store = useStoreApi() @@ -20,7 +28,26 @@ export const useWorkflowStartRun = () => { const isChatMode = useIsChatMode() const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions() const { handleRun } = useWorkflowRun() + const { isFromStartNode } = useWorkflow() const { doSyncWorkflowDraft } = useNodesSyncDraft() + const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault + const { t } = useTranslation() + const { + modelList: rerankModelList, + defaultModel: rerankDefaultModel, + } = 
useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) + + const { + currentModel, + } = useCurrentProviderAndModel( + rerankModelList, + rerankDefaultModel + ? { + ...rerankDefaultModel, + provider: rerankDefaultModel.provider.provider, + } + : undefined, + ) const handleWorkflowStartRunInWorkflow = useCallback(async () => { const { @@ -33,6 +60,9 @@ export const useWorkflowStartRun = () => { const { getNodes } = store.getState() const nodes = getNodes() const startNode = nodes.find(node => node.data.type === BlockEnum.Start) + const knowledgeRetrievalNodes = nodes.filter((node: Node) => + node.data.type === BlockEnum.KnowledgeRetrieval, + ) const startVariables = startNode?.data.variables || [] const fileSettings = featuresStore!.getState().features.file const { @@ -42,6 +72,31 @@ export const useWorkflowStartRun = () => { setShowEnvPanel, } = workflowStore.getState() + if (knowledgeRetrievalNodes.length > 0) { + for (const node of knowledgeRetrievalNodes) { + if (isFromStartNode(node.id)) { + const res = checkKnowledgeRetrievalValid(node.data, t) + if (!res.isValid || !currentModel || !rerankDefaultModel) { + const errorMessage = res.errorMessage + if (errorMessage) { + Toast.notify({ + type: 'error', + message: errorMessage, + }) + return false + } + else { + Toast.notify({ + type: 'error', + message: t('appDebug.datasetConfig.rerankModelRequired'), + }) + return false + } + } + } + } + } + setShowEnvPanel(false) if (showDebugAndPreviewPanel) { diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts index b201b28b88d14b..ec7ce66e5fdce6 100644 --- a/web/app/components/workflow/hooks/use-workflow.ts +++ b/web/app/components/workflow/hooks/use-workflow.ts @@ -235,6 +235,33 @@ export const useWorkflow = () => { return nodes.filter(node => node.parentId === nodeId) }, [store]) + const isFromStartNode = useCallback((nodeId: string) => { + const { getNodes } = store.getState() + const 
nodes = getNodes() + const currentNode = nodes.find(node => node.id === nodeId) + + if (!currentNode) + return false + + if (currentNode.data.type === BlockEnum.Start) + return true + + const checkPreviousNodes = (node: Node) => { + const previousNodes = getBeforeNodeById(node.id) + + for (const prevNode of previousNodes) { + if (prevNode.data.type === BlockEnum.Start) + return true + if (checkPreviousNodes(prevNode)) + return true + } + + return false + } + + return checkPreviousNodes(currentNode) + }, [store, getBeforeNodeById]) + const handleOutVarRenameChange = useCallback((nodeId: string, oldValeSelector: ValueSelector, newVarSelector: ValueSelector) => { const { getNodes, setNodes } = store.getState() const afterNodes = getAfterNodesInSameBranch(nodeId) @@ -389,6 +416,7 @@ export const useWorkflow = () => { checkParallelLimit, checkNestedParallelLimit, isValidConnection, + isFromStartNode, formatTimeFromNow, getNode, getBeforeNodeById, diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts index a7e768911ffecd..d5ab6eb72894f3 100644 --- a/web/i18n/en-US/workflow.ts +++ b/web/i18n/en-US/workflow.ts @@ -172,6 +172,7 @@ const translation = { }, errorMsg: { fieldRequired: '{{field}} is required', + rerankModelRequired: 'Before turning on the Rerank Model, please confirm that the model has been successfully configured in the settings.', authRequired: 'Authorization is required', invalidJson: '{{field}} is invalid JSON', fields: { diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts index 3579ec5df3440e..4959a87be7db77 100644 --- a/web/i18n/zh-Hans/workflow.ts +++ b/web/i18n/zh-Hans/workflow.ts @@ -172,6 +172,7 @@ const translation = { }, errorMsg: { fieldRequired: '{{field}} 不能为空', + rerankModelRequired: '开启 Rerank 模型前,请务必确认模型已在设置中成功配置。', authRequired: '请先授权', invalidJson: '{{field}} 是非法的 JSON', fields: { From 2ec6ffe478984539209725b754631565515da8bd Mon Sep 17 00:00:00 2001 From: Shili Cao Date: Sat, 12 Oct 2024 23:24:17 +0800 
Subject: [PATCH 10/25] feat:support baidu vector db (#9185) --- api/.env.example | 9 + api/commands.py | 8 + .../middleware/vdb/baidu_vector_config.py | 45 +++ api/controllers/console/datasets/datasets.py | 2 + api/core/rag/datasource/vdb/baidu/__init__.py | 0 .../rag/datasource/vdb/baidu/baidu_vector.py | 272 ++++++++++++++++++ api/core/rag/datasource/vdb/vector_factory.py | 4 + api/core/rag/datasource/vdb/vector_type.py | 1 + api/poetry.lock | 47 ++- api/pyproject.toml | 1 + .../vdb/__mock/baiduvectordb.py | 154 ++++++++++ .../integration_tests/vdb/baidu/__init__.py | 0 .../integration_tests/vdb/baidu/test_baidu.py | 36 +++ docker/.env.example | 9 + docker/docker-compose.yaml | 7 + 15 files changed, 582 insertions(+), 13 deletions(-) create mode 100644 api/configs/middleware/vdb/baidu_vector_config.py create mode 100644 api/core/rag/datasource/vdb/baidu/__init__.py create mode 100644 api/core/rag/datasource/vdb/baidu/baidu_vector.py create mode 100644 api/tests/integration_tests/vdb/__mock/baiduvectordb.py create mode 100644 api/tests/integration_tests/vdb/baidu/__init__.py create mode 100644 api/tests/integration_tests/vdb/baidu/test_baidu.py diff --git a/api/.env.example b/api/.env.example index 7b5e4950c82ddd..3f88fb3cdf5b09 100644 --- a/api/.env.example +++ b/api/.env.example @@ -208,6 +208,15 @@ OPENSEARCH_USER=admin OPENSEARCH_PASSWORD=admin OPENSEARCH_SECURE=true +# Baidu configuration +BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287 +BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000 +BAIDU_VECTOR_DB_ACCOUNT=root +BAIDU_VECTOR_DB_API_KEY=dify +BAIDU_VECTOR_DB_DATABASE=dify +BAIDU_VECTOR_DB_SHARD=1 +BAIDU_VECTOR_DB_REPLICAS=3 + # Upload configuration UPLOAD_FILE_SIZE_LIMIT=15 UPLOAD_FILE_BATCH_LIMIT=5 diff --git a/api/commands.py b/api/commands.py index 7ef4aed7f77664..dbcd8a744d3a45 100644 --- a/api/commands.py +++ b/api/commands.py @@ -347,6 +347,14 @@ def migrate_knowledge_vector_database(): index_name = Dataset.gen_collection_name_by_id(dataset_id) 
index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}} dataset.index_struct = json.dumps(index_struct_dict) + elif vector_type == VectorType.BAIDU: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + index_struct_dict = { + "type": VectorType.BAIDU, + "vector_store": {"class_prefix": collection_name}, + } + dataset.index_struct = json.dumps(index_struct_dict) else: raise ValueError(f"Vector store {vector_type} is not supported.") diff --git a/api/configs/middleware/vdb/baidu_vector_config.py b/api/configs/middleware/vdb/baidu_vector_config.py new file mode 100644 index 00000000000000..44742c2e2f4349 --- /dev/null +++ b/api/configs/middleware/vdb/baidu_vector_config.py @@ -0,0 +1,45 @@ +from typing import Optional + +from pydantic import Field, NonNegativeInt, PositiveInt +from pydantic_settings import BaseSettings + + +class BaiduVectorDBConfig(BaseSettings): + """ + Configuration settings for Baidu Vector Database + """ + + BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field( + description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')", + default=None, + ) + + BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field( + description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)", + default=30000, + ) + + BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field( + description="Account for authenticating with the Baidu Vector Database", + default=None, + ) + + BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field( + description="API key for authenticating with the Baidu Vector Database service", + default=None, + ) + + BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field( + description="Name of the specific Baidu Vector Database to connect to", + default=None, + ) + + BAIDU_VECTOR_DB_SHARD: PositiveInt = Field( + description="Number of shards for the Baidu Vector Database (default is 1)", + default=1, + ) + + BAIDU_VECTOR_DB_REPLICAS: 
NonNegativeInt = Field( + description="Number of replicas for the Baidu Vector Database (default is 3)", + default=3, + ) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 9561fd8b70e4b9..102089bf071ac2 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -617,6 +617,7 @@ def get(self): | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS + | VectorType.BAIDU ): return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} case ( @@ -653,6 +654,7 @@ def get(self, vector_type): | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS + | VectorType.BAIDU ): return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} case ( diff --git a/api/core/rag/datasource/vdb/baidu/__init__.py b/api/core/rag/datasource/vdb/baidu/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py new file mode 100644 index 00000000000000..543cfa67b35409 --- /dev/null +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -0,0 +1,272 @@ +import json +import time +import uuid +from typing import Any + +from pydantic import BaseModel, model_validator +from pymochow import MochowClient +from pymochow.auth.bce_credentials import BceCredentials +from pymochow.configuration import Configuration +from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState +from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex +from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row + +from configs import dify_config +from core.rag.datasource.entity.embedding import Embeddings +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import 
VectorType +from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset + + +class BaiduConfig(BaseModel): + endpoint: str + connection_timeout_in_mills: int = 30 * 1000 + account: str + api_key: str + database: str + index_type: str = "HNSW" + metric_type: str = "L2" + shard: int = 1 + replicas: int = 3 + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + if not values["endpoint"]: + raise ValueError("config BAIDU_VECTOR_DB_ENDPOINT is required") + if not values["account"]: + raise ValueError("config BAIDU_VECTOR_DB_ACCOUNT is required") + if not values["api_key"]: + raise ValueError("config BAIDU_VECTOR_DB_API_KEY is required") + if not values["database"]: + raise ValueError("config BAIDU_VECTOR_DB_DATABASE is required") + return values + + +class BaiduVector(BaseVector): + field_id: str = "id" + field_vector: str = "vector" + field_text: str = "text" + field_metadata: str = "metadata" + field_app_id: str = "app_id" + field_annotation_id: str = "annotation_id" + index_vector: str = "vector_idx" + + def __init__(self, collection_name: str, config: BaiduConfig): + super().__init__(collection_name) + self._client_config = config + self._client = self._init_client(config) + self._db = self._init_database() + + def get_type(self) -> str: + return VectorType.BAIDU + + def to_index_struct(self) -> dict: + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + self._create_table(len(embeddings[0])) + self.add_texts(texts, embeddings) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + texts = [doc.page_content for doc in documents] + metadatas = [doc.metadata for doc in documents] + total_count = len(documents) + batch_size = 1000 + + # upsert texts and embeddings batch by batch + table = 
self._db.table(self._collection_name) + for start in range(0, total_count, batch_size): + end = min(start + batch_size, total_count) + rows = [] + for i in range(start, end, 1): + row = Row( + id=metadatas[i].get("doc_id", str(uuid.uuid4())), + vector=embeddings[i], + text=texts[i], + metadata=json.dumps(metadatas[i]), + app_id=metadatas[i].get("app_id", ""), + annotation_id=metadatas[i].get("annotation_id", ""), + ) + rows.append(row) + table.upsert(rows=rows) + + # rebuild vector index after upsert finished + table.rebuild_index(self.index_vector) + while True: + time.sleep(1) + index = table.describe_index(self.index_vector) + if index.state == IndexState.NORMAL: + break + + def text_exists(self, id: str) -> bool: + res = self._db.table(self._collection_name).query(primary_key={self.field_id: id}) + if res and res.code == 0: + return True + return False + + def delete_by_ids(self, ids: list[str]) -> None: + quoted_ids = [f"'{id}'" for id in ids] + self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})") + + def delete_by_metadata_field(self, key: str, value: str) -> None: + self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'") + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + anns = AnnSearch( + vector_field=self.field_vector, + vector_floats=query_vector, + params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)), + ) + res = self._db.table(self._collection_name).search( + anns=anns, + projections=[self.field_id, self.field_text, self.field_metadata], + retrieve_vector=True, + ) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + return self._get_search_res(res, score_threshold) + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + # baidu vector database doesn't support bm25 search on current version + return [] + + def _get_search_res(self, res, score_threshold): + docs = [] + for row in res.rows: + 
row_data = row.get("row", {}) + meta = row_data.get(self.field_metadata) + if meta is not None: + meta = json.loads(meta) + score = row.get("score", 0.0) + if score > score_threshold: + meta["score"] = score + doc = Document(page_content=row_data.get(self.field_text), metadata=meta) + docs.append(doc) + + return docs + + def delete(self) -> None: + self._db.drop_table(table_name=self._collection_name) + + def _init_client(self, config) -> MochowClient: + config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint) + client = MochowClient(config) + return client + + def _init_database(self): + exists = False + for db in self._client.list_databases(): + if db.database_name == self._client_config.database: + exists = True + break + # Create database if not existed + if exists: + return self._client.database(self._client_config.database) + else: + return self._client.create_database(database_name=self._client_config.database) + + def _table_existed(self) -> bool: + tables = self._db.list_table() + return any(table.table_name == self._collection_name for table in tables) + + def _create_table(self, dimension: int) -> None: + # Try to grab distributed lock and create table + lock_name = "vector_indexing_lock_{}".format(self._collection_name) + with redis_client.lock(lock_name, timeout=20): + table_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + if redis_client.get(table_exist_cache_key): + return + + if self._table_existed(): + return + + self.delete() + + # check IndexType and MetricType + index_type = None + for k, v in IndexType.__members__.items(): + if k == self._client_config.index_type: + index_type = v + if index_type is None: + raise ValueError("unsupported index_type") + metric_type = None + for k, v in MetricType.__members__.items(): + if k == self._client_config.metric_type: + metric_type = v + if metric_type is None: + raise ValueError("unsupported metric_type") + + # Construct field 
schema + fields = [] + fields.append( + Field( + self.field_id, + FieldType.STRING, + primary_key=True, + partition_key=True, + auto_increment=False, + not_null=True, + ) + ) + fields.append(Field(self.field_metadata, FieldType.STRING, not_null=True)) + fields.append(Field(self.field_app_id, FieldType.STRING)) + fields.append(Field(self.field_annotation_id, FieldType.STRING)) + fields.append(Field(self.field_text, FieldType.TEXT, not_null=True)) + fields.append(Field(self.field_vector, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension)) + + # Construct vector index params + indexes = [] + indexes.append( + VectorIndex( + index_name="vector_idx", + index_type=index_type, + field="vector", + metric_type=metric_type, + params=HNSWParams(m=16, efconstruction=200), + ) + ) + + # Create table + self._db.create_table( + table_name=self._collection_name, + replication=self._client_config.replicas, + partition=Partition(partition_num=self._client_config.shard), + schema=Schema(fields=fields, indexes=indexes), + description="Table for Dify", + ) + + redis_client.set(table_exist_cache_key, 1, ex=3600) + + # Wait for table created + while True: + time.sleep(1) + table = self._db.describe_table(self._collection_name) + if table.state == TableState.NORMAL: + break + + +class BaiduVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix.lower() + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.BAIDU, collection_name)) + + return BaiduVector( + collection_name=collection_name, + config=BaiduConfig( + endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT, + 
connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS, + account=dify_config.BAIDU_VECTOR_DB_ACCOUNT, + api_key=dify_config.BAIDU_VECTOR_DB_API_KEY, + database=dify_config.BAIDU_VECTOR_DB_DATABASE, + shard=dify_config.BAIDU_VECTOR_DB_SHARD, + replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS, + ), + ) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 943b23870cc5cb..1f4a4d44a23eea 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -103,6 +103,10 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]: from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory return AnalyticdbVectorFactory + case VectorType.BAIDU: + from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory + + return BaiduVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index ba04ea879d9b43..996ff48615c901 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -16,3 +16,4 @@ class VectorType(str, Enum): TENCENT = "tencent" ORACLE = "oracle" ELASTICSEARCH = "elasticsearch" + BAIDU = "baidu" diff --git a/api/poetry.lock b/api/poetry.lock index 52e0f3031151de..6565db27ad5725 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -732,7 +732,7 @@ name = "bce-python-sdk" version = "0.9.23" description = "BCE SDK for python" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, <4" +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,<4,>=2.7" files = [ {file = "bce_python_sdk-0.9.23-py3-none-any.whl", hash = "sha256:8debe21a040e00060f6044877d594765ed7b18bc765c6bf16b878bca864140a3"}, {file = "bce_python_sdk-0.9.23.tar.gz", hash = 
"sha256:19739fed5cd0725356fc5ffa2acbdd8fb23f2a81edb91db21a03174551d0cf41"}, @@ -847,7 +847,7 @@ name = "botocore" version = "1.35.38" description = "Low-level, data-driven core of boto 3." optional = false -python-versions = ">= 3.8" +python-versions = ">=3.8" files = [ {file = "botocore-1.35.38-py3-none-any.whl", hash = "sha256:2eb17d32fa2d3bb5d475132a83564d28e3acc2161534f24b75a54418a1d51359"}, {file = "botocore-1.35.38.tar.gz", hash = "sha256:55d9305c44e5ba29476df456120fa4fb919f03f066afa82f2ae400485e7465f4"}, @@ -1068,7 +1068,7 @@ name = "build" version = "1.2.2.post1" description = "A simple, correct Python build frontend" optional = false -python-versions = ">= 3.8" +python-versions = ">=3.8" files = [ {file = "build-1.2.2.post1-py3-none-any.whl", hash = "sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5"}, {file = "build-1.2.2.post1.tar.gz", hash = "sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7"}, @@ -3385,7 +3385,7 @@ name = "gotrue" version = "2.9.2" description = "Python Client Library for Supabase Auth" optional = false -python-versions = ">=3.8,<4.0" +python-versions = "<4.0,>=3.8" files = [ {file = "gotrue-2.9.2-py3-none-any.whl", hash = "sha256:fcd5279e8f1cc630f3ac35af5485fe39f8030b23906776920d2c32a4e308cff4"}, {file = "gotrue-2.9.2.tar.gz", hash = "sha256:57b3245e916c5efbf19a21b1181011a903c1276bb1df2d847558f2f24f29abb2"}, @@ -4415,7 +4415,7 @@ name = "langfuse" version = "2.51.5" description = "A client library for accessing langfuse" optional = false -python-versions = ">=3.8.1,<4.0" +python-versions = "<4.0,>=3.8.1" files = [ {file = "langfuse-2.51.5-py3-none-any.whl", hash = "sha256:b95401ca710ef94b521afa6541933b6f93d7cfd4a97523c8fc75bca4d6d219fb"}, {file = "langfuse-2.51.5.tar.gz", hash = "sha256:55bc37b5c5d3ae133c1a95db09117cfb3117add110ba02ebbf2ce45ac4395c5b"}, @@ -4440,7 +4440,7 @@ name = "langsmith" version = "0.1.134" description = "Client library to connect to the LangSmith LLM Tracing and 
Evaluation Platform." optional = false -python-versions = ">=3.8.1,<4.0" +python-versions = "<4.0,>=3.8.1" files = [ {file = "langsmith-0.1.134-py3-none-any.whl", hash = "sha256:ada98ad80ef38807725f32441a472da3dd28394010877751f48f458d3289da04"}, {file = "langsmith-0.1.134.tar.gz", hash = "sha256:23abee3b508875a0e63c602afafffc02442a19cfd88f9daae05b3e9054fd6b61"}, @@ -6429,7 +6429,7 @@ name = "postgrest" version = "0.17.1" description = "PostgREST client for Python. This library provides an ORM interface to PostgREST." optional = false -python-versions = ">=3.8,<4.0" +python-versions = "<4.0,>=3.8" files = [ {file = "postgrest-0.17.1-py3-none-any.whl", hash = "sha256:ec1d00dc8532fe5ffb342cfc7c4e610a1e0e2272eb14f78f9b2b61094f9be510"}, {file = "postgrest-0.17.1.tar.gz", hash = "sha256:e31d9977dbb80dc5f9fdd4d444014686606692dc4ddb9adc85639e56c6d54c92"}, @@ -7047,6 +7047,22 @@ bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "r dev = ["black", "grpcio (==1.62.2)", "grpcio-testing (==1.62.2)", "grpcio-tools (==1.62.2)", "pytest (>=5.3.4)", "pytest-cov (>=2.8.1)", "pytest-timeout (>=1.3.4)", "ruff (>0.4.0)"] model = ["milvus-model (>=0.1.0)"] +[[package]] +name = "pymochow" +version = "1.3.1" +description = "Python SDK for mochow" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pymochow-1.3.1-py3-none-any.whl", hash = "sha256:a7f3b34fd6ea5d1d8413650bb6678365aa148fc396ae945e4ccb4f2365a52327"}, + {file = "pymochow-1.3.1.tar.gz", hash = "sha256:1693d10cd0bb7bce45327890a90adafb503155922ccc029acb257699a73a20ba"}, +] + +[package.dependencies] +future = "*" +orjson = "*" +requests = "*" + [[package]] name = "pymysql" version = "1.1.1" @@ -7746,7 +7762,7 @@ name = "realtime" version = "2.0.2" description = "" optional = false -python-versions = ">=3.9,<4.0" +python-versions = "<4.0,>=3.9" files = [ {file = "realtime-2.0.2-py3-none-any.whl", hash = "sha256:2634c915bc38807f2013f21e8bcc4d2f79870dfd81460ddb9393883d0489928a"}, {file = 
"realtime-2.0.2.tar.gz", hash = "sha256:519da9325b3b8102139d51785013d592f6b2403d81fa21d838a0b0234723ed7d"}, @@ -8173,7 +8189,7 @@ name = "s3transfer" version = "0.10.3" description = "An Amazon S3 Transfer Manager" optional = false -python-versions = ">= 3.8" +python-versions = ">=3.8" files = [ {file = "s3transfer-0.10.3-py3-none-any.whl", hash = "sha256:263ed587a5803c6c708d3ce44dc4dfedaab4c1a32e8329bab818933d79ddcf5d"}, {file = "s3transfer-0.10.3.tar.gz", hash = "sha256:4f50ed74ab84d474ce614475e0b8d5047ff080810aac5d01ea25231cfc944b0c"}, @@ -8417,6 +8433,11 @@ files = [ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, + {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, 
{file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, @@ -8836,7 +8857,7 @@ name = "storage3" version = "0.8.1" description = "Supabase Storage client for Python." optional = false -python-versions = ">=3.8,<4.0" +python-versions = "<4.0,>=3.8" files = [ {file = "storage3-0.8.1-py3-none-any.whl", hash = "sha256:0b21205f43eaf0d1dd33bde6c6d0612f88524b7865f017d2ae9827e3f63d9cdc"}, {file = "storage3-0.8.1.tar.gz", hash = "sha256:ea60b68b2221b3868ccc1a7f1294d57d0d9c51642cdc639d8115fe5d0adc8892"}, @@ -8882,7 +8903,7 @@ name = "supabase" version = "2.8.1" description = "Supabase client for Python." optional = false -python-versions = ">=3.9,<4.0" +python-versions = "<4.0,>=3.9" files = [ {file = "supabase-2.8.1-py3-none-any.whl", hash = "sha256:dfa8bef89b54129093521d5bba2136ff765baf67cd76d8ad0aa4984d61a7815c"}, {file = "supabase-2.8.1.tar.gz", hash = "sha256:711c70e6acd9e2ff48ca0dc0b1bb70c01c25378cc5189ec9f5ed9655b30bc41d"}, @@ -8902,7 +8923,7 @@ name = "supafunc" version = "0.6.1" description = "Library for Supabase Functions" optional = false -python-versions = ">=3.8,<4.0" +python-versions = "<4.0,>=3.8" files = [ {file = "supafunc-0.6.1-py3-none-any.whl", hash = "sha256:01aeeeb4bf429977664454a32c86418345140faf6d2e6eb0636d52e4547c5fbb"}, {file = "supafunc-0.6.1.tar.gz", hash = "sha256:3c8761e3999336ccdb7550498a395fd08afc8469382f55ea56f7f640e5a909aa"}, @@ -10615,4 +10636,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "cc10ee218369eb5576d1e5ac8aeeb72e8927bbcb8bd1ac1594167c45aa9d9a21" +content-hash = "375ac3a91760513924647e67376cb6018505ec61d967651b254c68af9808d774" diff --git a/api/pyproject.toml b/api/pyproject.toml index 
277d1690c72702..594517771b34f2 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -242,6 +242,7 @@ oracledb = "~2.2.1" pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] } pgvector = "0.2.5" pymilvus = "~2.4.4" +pymochow = "1.3.1" qdrant-client = "1.7.3" tcvectordb = "1.3.2" tidb-vector = "0.0.9" diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py new file mode 100644 index 00000000000000..a8eaf42b7de1de --- /dev/null +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -0,0 +1,154 @@ +import os + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from pymochow import MochowClient +from pymochow.model.database import Database +from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState +from pymochow.model.schema import HNSWParams, VectorIndex +from pymochow.model.table import Table +from requests.adapters import HTTPAdapter + + +class MockBaiduVectorDBClass: + def mock_vector_db_client( + self, + config=None, + adapter: HTTPAdapter = None, + ): + self._conn = None + self._config = None + + def list_databases(self, config=None) -> list[Database]: + return [ + Database( + conn=self._conn, + database_name="dify", + config=self._config, + ) + ] + + def create_database(self, database_name: str, config=None) -> Database: + return Database(conn=self._conn, database_name=database_name, config=config) + + def list_table(self, config=None) -> list[Table]: + return [] + + def drop_table(self, table_name: str, config=None): + return {"code": 0, "msg": "Success"} + + def create_table( + self, + table_name: str, + replication: int, + partition: int, + schema, + enable_dynamic_field=False, + description: str = "", + config=None, + ) -> Table: + return Table(self, table_name, replication, partition, schema, enable_dynamic_field, description, config) + + def describe_table(self, table_name: str, config=None) -> Table: + return Table( 
+ self, + table_name, + 3, + 1, + None, + enable_dynamic_field=False, + description="table for dify", + config=config, + state=TableState.NORMAL, + ) + + def upsert(self, rows, config=None): + return {"code": 0, "msg": "operation success", "affectedCount": 1} + + def rebuild_index(self, index_name: str, config=None): + return {"code": 0, "msg": "Success"} + + def describe_index(self, index_name: str, config=None): + return VectorIndex( + index_name=index_name, + index_type=IndexType.HNSW, + field="vector", + metric_type=MetricType.L2, + params=HNSWParams(m=16, efconstruction=200), + auto_build=False, + state=IndexState.NORMAL, + ) + + def query( + self, + primary_key, + partition_key=None, + projections=None, + retrieve_vector=False, + read_consistency=ReadConsistency.EVENTUAL, + config=None, + ): + return { + "row": { + "id": "doc_id_001", + "vector": [0.23432432, 0.8923744, 0.89238432], + "text": "text", + "metadata": {"doc_id": "doc_id_001"}, + }, + "code": 0, + "msg": "Success", + } + + def delete(self, primary_key=None, partition_key=None, filter=None, config=None): + return {"code": 0, "msg": "Success"} + + def search( + self, + anns, + partition_key=None, + projections=None, + retrieve_vector=False, + read_consistency=ReadConsistency.EVENTUAL, + config=None, + ): + return { + "rows": [ + { + "row": { + "id": "doc_id_001", + "vector": [0.23432432, 0.8923744, 0.89238432], + "text": "text", + "metadata": {"doc_id": "doc_id_001"}, + }, + "distance": 0.1, + "score": 0.5, + } + ], + "code": 0, + "msg": "Success", + } + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(MochowClient, "__init__", MockBaiduVectorDBClass.mock_vector_db_client) + monkeypatch.setattr(MochowClient, "list_databases", MockBaiduVectorDBClass.list_databases) + monkeypatch.setattr(MochowClient, "create_database", MockBaiduVectorDBClass.create_database) + 
monkeypatch.setattr(Database, "table", MockBaiduVectorDBClass.describe_table) + monkeypatch.setattr(Database, "list_table", MockBaiduVectorDBClass.list_table) + monkeypatch.setattr(Database, "create_table", MockBaiduVectorDBClass.create_table) + monkeypatch.setattr(Database, "drop_table", MockBaiduVectorDBClass.drop_table) + monkeypatch.setattr(Database, "describe_table", MockBaiduVectorDBClass.describe_table) + monkeypatch.setattr(Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index) + monkeypatch.setattr(Table, "describe_index", MockBaiduVectorDBClass.describe_index) + monkeypatch.setattr(Table, "delete", MockBaiduVectorDBClass.delete) + monkeypatch.setattr(Table, "search", MockBaiduVectorDBClass.search) + + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/integration_tests/vdb/baidu/__init__.py b/api/tests/integration_tests/vdb/baidu/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/integration_tests/vdb/baidu/test_baidu.py b/api/tests/integration_tests/vdb/baidu/test_baidu.py new file mode 100644 index 00000000000000..01a7f8853ac367 --- /dev/null +++ b/api/tests/integration_tests/vdb/baidu/test_baidu.py @@ -0,0 +1,36 @@ +from unittest.mock import MagicMock + +from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector +from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis + +mock_client = MagicMock() +mock_client.list_databases.return_value = [{"name": "test"}] + + +class BaiduVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = BaiduVector( + "dify", + BaiduConfig( + endpoint="http://127.0.0.1:5287", + account="root", + api_key="dify", + database="dify", + shard=1, + replicas=3, + ), + ) + + def search_by_vector(self): + hits_by_vector = 
self.vector.search_by_vector(query_vector=self.example_embedding) + assert len(hits_by_vector) == 1 + + def search_by_full_text(self): + hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + +def test_baidu_vector(setup_mock_redis, setup_baiduvectordb_mock): + BaiduVectorTest().run_all_tests() diff --git a/docker/.env.example b/docker/.env.example index 87d7709a1830af..c4eae46cb0215a 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -462,6 +462,15 @@ ELASTICSEARCH_PORT=9200 ELASTICSEARCH_USERNAME=elastic ELASTICSEARCH_PASSWORD=elastic +# baidu vector configurations, only available when VECTOR_STORE is `baidu` +BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287 +BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000 +BAIDU_VECTOR_DB_ACCOUNT=root +BAIDU_VECTOR_DB_API_KEY=dify +BAIDU_VECTOR_DB_DATABASE=dify +BAIDU_VECTOR_DB_SHARD=1 +BAIDU_VECTOR_DB_REPLICAS=3 + # ------------------------------ # Knowledge Configuration # ------------------------------ diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 62d798a695967a..c046c17ef8f2b2 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -165,6 +165,13 @@ x-shared-env: &shared-api-worker-env TENCENT_VECTOR_DB_DATABASE: ${TENCENT_VECTOR_DB_DATABASE:-dify} TENCENT_VECTOR_DB_SHARD: ${TENCENT_VECTOR_DB_SHARD:-1} TENCENT_VECTOR_DB_REPLICAS: ${TENCENT_VECTOR_DB_REPLICAS:-2} + BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-http://127.0.0.1:5287} + BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: ${BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS:-30000} + BAIDU_VECTOR_DB_ACCOUNT: ${BAIDU_VECTOR_DB_ACCOUNT:-root} + BAIDU_VECTOR_DB_API_KEY: ${BAIDU_VECTOR_DB_API_KEY:-dify} + BAIDU_VECTOR_DB_DATABASE: ${BAIDU_VECTOR_DB_DATABASE:-dify} + BAIDU_VECTOR_DB_SHARD: ${BAIDU_VECTOR_DB_SHARD:-1} + BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3} UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: 
${UPLOAD_FILE_BATCH_LIMIT:-5} ETL_TYPE: ${ETL_TYPE:-dify} From 70c5b23089584bcd700ddba3fa9d1ced4d0a57fd Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sat, 12 Oct 2024 23:27:11 +0800 Subject: [PATCH 11/25] chore: translate i18n files (#9284) Co-authored-by: YIXIAO0 <54782454+YIXIAO0@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> --- web/i18n/de-DE/workflow.ts | 1 + web/i18n/es-ES/workflow.ts | 1 + web/i18n/fa-IR/workflow.ts | 1 + web/i18n/fr-FR/workflow.ts | 1 + web/i18n/hi-IN/workflow.ts | 1 + web/i18n/it-IT/workflow.ts | 1 + web/i18n/ja-JP/workflow.ts | 1 + web/i18n/ko-KR/workflow.ts | 1 + web/i18n/pl-PL/workflow.ts | 1 + web/i18n/pt-BR/workflow.ts | 1 + web/i18n/ro-RO/workflow.ts | 1 + web/i18n/ru-RU/workflow.ts | 1 + web/i18n/tr-TR/workflow.ts | 1 + web/i18n/uk-UA/workflow.ts | 1 + web/i18n/vi-VN/workflow.ts | 1 + web/i18n/zh-Hant/workflow.ts | 1 + 16 files changed, 16 insertions(+) diff --git a/web/i18n/de-DE/workflow.ts b/web/i18n/de-DE/workflow.ts index c01d0e6f99418c..b6d0e8cde4ee21 100644 --- a/web/i18n/de-DE/workflow.ts +++ b/web/i18n/de-DE/workflow.ts @@ -182,6 +182,7 @@ const translation = { rerankModel: 'Neusortierungsmodell', }, invalidVariable: 'Ungültige Variable', + rerankModelRequired: 'Bevor Sie das Rerank-Modell aktivieren, bestätigen Sie bitte, dass das Modell in den Einstellungen erfolgreich konfiguriert wurde.', }, singleRun: { testRun: 'Testlauf ', diff --git a/web/i18n/es-ES/workflow.ts b/web/i18n/es-ES/workflow.ts index 2260631d0fa963..275149a0560946 100644 --- a/web/i18n/es-ES/workflow.ts +++ b/web/i18n/es-ES/workflow.ts @@ -182,6 +182,7 @@ const translation = { rerankModel: 'Modelo de reordenamiento', }, invalidVariable: 'Variable no válida', + rerankModelRequired: 'Antes de activar el modelo de reclasificación, confirme que el modelo se ha configurado correctamente en la configuración.', }, singleRun: { testRun: 
'Ejecución de prueba', diff --git a/web/i18n/fa-IR/workflow.ts b/web/i18n/fa-IR/workflow.ts index eb36dfdc888362..609f446b43e649 100644 --- a/web/i18n/fa-IR/workflow.ts +++ b/web/i18n/fa-IR/workflow.ts @@ -182,6 +182,7 @@ const translation = { rerankModel: 'مدل مجدد رتبه‌بندی', }, invalidVariable: 'متغیر نامعتبر', + rerankModelRequired: 'قبل از روشن کردن Rerank Model، لطفا تأیید کنید که مدل با موفقیت در تنظیمات پیکربندی شده است.', }, singleRun: { testRun: 'اجرای آزمایشی', diff --git a/web/i18n/fr-FR/workflow.ts b/web/i18n/fr-FR/workflow.ts index 878d25804e3b36..068c41b853d83c 100644 --- a/web/i18n/fr-FR/workflow.ts +++ b/web/i18n/fr-FR/workflow.ts @@ -182,6 +182,7 @@ const translation = { rerankModel: 'Modèle de rerank', }, invalidVariable: 'Variable invalide', + rerankModelRequired: 'Avant d’activer le modèle de reclassement, veuillez confirmer que le modèle a été correctement configuré dans les paramètres.', }, singleRun: { testRun: 'Exécution de test', diff --git a/web/i18n/hi-IN/workflow.ts b/web/i18n/hi-IN/workflow.ts index ac356c206758fb..e402200462d215 100644 --- a/web/i18n/hi-IN/workflow.ts +++ b/web/i18n/hi-IN/workflow.ts @@ -185,6 +185,7 @@ const translation = { rerankModel: 'पुनः रैंक मॉडल', }, invalidVariable: 'अमान्य वेरिएबल', + rerankModelRequired: 'Rerank मॉडल चालू करने से पहले, कृपया पुष्टि करें कि मॉडल को सेटिंग्स में सफलतापूर्वक कॉन्फ़िगर किया गया है।', }, singleRun: { testRun: 'परीक्षण रन', diff --git a/web/i18n/it-IT/workflow.ts b/web/i18n/it-IT/workflow.ts index 0427a45cd95f08..ce460ed252d7fd 100644 --- a/web/i18n/it-IT/workflow.ts +++ b/web/i18n/it-IT/workflow.ts @@ -187,6 +187,7 @@ const translation = { rerankModel: 'Modello Rerank', }, invalidVariable: 'Variabile non valida', + rerankModelRequired: 'Prima di attivare il modello di reranking, conferma che il modello è stato configurato correttamente nelle impostazioni.', }, singleRun: { testRun: 'Esecuzione Test ', diff --git a/web/i18n/ja-JP/workflow.ts b/web/i18n/ja-JP/workflow.ts index 
48c20196016b03..2906f7ef8c90e5 100644 --- a/web/i18n/ja-JP/workflow.ts +++ b/web/i18n/ja-JP/workflow.ts @@ -182,6 +182,7 @@ const translation = { rerankModel: 'Rerankモデル', }, invalidVariable: '無効な変数', + rerankModelRequired: 'モデルの再ランク付けをオンにする前に、モデルが設定で正常に構成されていることを確認してください。', }, singleRun: { testRun: 'テスト実行', diff --git a/web/i18n/ko-KR/workflow.ts b/web/i18n/ko-KR/workflow.ts index 4a97943790903a..99d5c47c0bffb8 100644 --- a/web/i18n/ko-KR/workflow.ts +++ b/web/i18n/ko-KR/workflow.ts @@ -182,6 +182,7 @@ const translation = { rerankModel: '재정렬 모델', }, invalidVariable: '잘못된 변수', + rerankModelRequired: 'Rerank Model을 켜기 전에 설정에서 모델이 성공적으로 구성되었는지 확인하십시오.', }, singleRun: { testRun: '테스트 실행', diff --git a/web/i18n/pl-PL/workflow.ts b/web/i18n/pl-PL/workflow.ts index 41927668f7bcd8..b26c429fb19393 100644 --- a/web/i18n/pl-PL/workflow.ts +++ b/web/i18n/pl-PL/workflow.ts @@ -182,6 +182,7 @@ const translation = { rerankModel: 'Model rerank', }, invalidVariable: 'Nieprawidłowa zmienna', + rerankModelRequired: 'Przed włączeniem Rerank Model upewnij się, że model został pomyślnie skonfigurowany w ustawieniach.', }, singleRun: { testRun: 'Testowe uruchomienie ', diff --git a/web/i18n/pt-BR/workflow.ts b/web/i18n/pt-BR/workflow.ts index 222fc788bfdfec..9092ccda3e41de 100644 --- a/web/i18n/pt-BR/workflow.ts +++ b/web/i18n/pt-BR/workflow.ts @@ -182,6 +182,7 @@ const translation = { rerankModel: 'Modelo de reordenação', }, invalidVariable: 'Variável inválida', + rerankModelRequired: 'Antes de ativar o modelo de reclassificação, confirme se o modelo foi configurado com sucesso nas configurações.', }, singleRun: { testRun: 'Execução de teste ', diff --git a/web/i18n/ro-RO/workflow.ts b/web/i18n/ro-RO/workflow.ts index ac4b718b072abc..bb66169da89f82 100644 --- a/web/i18n/ro-RO/workflow.ts +++ b/web/i18n/ro-RO/workflow.ts @@ -182,6 +182,7 @@ const translation = { rerankModel: 'Model de rerankare', }, invalidVariable: 'Variabilă invalidă', + rerankModelRequired: 'Înainte de a activa 
modelul de reclasificare, vă rugăm să confirmați că modelul a fost configurat cu succes în setări.', }, singleRun: { testRun: 'Rulare de test ', diff --git a/web/i18n/ru-RU/workflow.ts b/web/i18n/ru-RU/workflow.ts index 19318638957f0f..5b2bc7e2902c54 100644 --- a/web/i18n/ru-RU/workflow.ts +++ b/web/i18n/ru-RU/workflow.ts @@ -182,6 +182,7 @@ const translation = { rerankModel: 'Модель переранжирования', }, invalidVariable: 'Неверная переменная', + rerankModelRequired: 'Перед включением модели повторного ранжирования убедитесь, что модель успешно настроена в настройках.', }, singleRun: { testRun: 'Тестовый запуск ', diff --git a/web/i18n/tr-TR/workflow.ts b/web/i18n/tr-TR/workflow.ts index a33a3724ad48ea..8e1ce596309d86 100644 --- a/web/i18n/tr-TR/workflow.ts +++ b/web/i18n/tr-TR/workflow.ts @@ -182,6 +182,7 @@ const translation = { rerankModel: 'Yeniden Sıralama Modeli', }, invalidVariable: 'Geçersiz değişken', + rerankModelRequired: 'Yeniden Sıralama Modelini açmadan önce, lütfen ayarlarda modelin başarıyla yapılandırıldığını onaylayın.', }, singleRun: { testRun: 'Test Çalıştırma', diff --git a/web/i18n/uk-UA/workflow.ts b/web/i18n/uk-UA/workflow.ts index e1bea99bcd2a88..f7747541ccfca1 100644 --- a/web/i18n/uk-UA/workflow.ts +++ b/web/i18n/uk-UA/workflow.ts @@ -182,6 +182,7 @@ const translation = { rerankModel: 'Модель повторного ранжування', }, invalidVariable: 'Недійсна змінна', + rerankModelRequired: 'Перед увімкненням Rerank Model, будь ласка, підтвердьте, що модель успішно налаштована в налаштуваннях.', }, singleRun: { testRun: 'Тестовий запуск', diff --git a/web/i18n/vi-VN/workflow.ts b/web/i18n/vi-VN/workflow.ts index 5e6f0e00631f37..aa9fbf865d6b3b 100644 --- a/web/i18n/vi-VN/workflow.ts +++ b/web/i18n/vi-VN/workflow.ts @@ -182,6 +182,7 @@ const translation = { rerankModel: 'Mô hình xếp hạng lại', }, invalidVariable: 'Biến không hợp lệ', + rerankModelRequired: 'Trước khi bật Mô hình xếp hạng lại, vui lòng xác nhận rằng mô hình đã được định cấu hình thành 
công trong cài đặt.', }, singleRun: { testRun: 'Chạy thử nghiệm ', diff --git a/web/i18n/zh-Hant/workflow.ts b/web/i18n/zh-Hant/workflow.ts index 8e1b7529fe18db..35ed68c437ed58 100644 --- a/web/i18n/zh-Hant/workflow.ts +++ b/web/i18n/zh-Hant/workflow.ts @@ -182,6 +182,7 @@ const translation = { rerankModel: 'Rerank 模型', }, invalidVariable: '無效的變量', + rerankModelRequired: '在開啟 Rerank 模型之前,請在設置中確認模型配置成功。', }, singleRun: { testRun: '測試運行', From dbfbc56de7e840b1e76a5a72d81dbd23c623f45a Mon Sep 17 00:00:00 2001 From: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Date: Sat, 12 Oct 2024 23:40:38 +0800 Subject: [PATCH 12/25] feat: refresh-token (#9286) Co-authored-by: NFish --- web/app/components/swr-initor.tsx | 24 ++++++-- web/app/signin/normalForm.tsx | 6 +- web/app/signin/userSSOForm.tsx | 11 +++- web/hooks/use-refresh-token.ts | 92 +++++++++++++++++++++++++++++++ web/package.json | 1 + web/service/common.ts | 17 +++++- web/utils/index.ts | 18 ++++++ web/yarn.lock | 5 ++ 8 files changed, 163 insertions(+), 11 deletions(-) create mode 100644 web/hooks/use-refresh-token.ts diff --git a/web/app/components/swr-initor.tsx b/web/app/components/swr-initor.tsx index afb4c58cbcda1d..85e05499e6d1ee 100644 --- a/web/app/components/swr-initor.tsx +++ b/web/app/components/swr-initor.tsx @@ -4,6 +4,7 @@ import { SWRConfig } from 'swr' import { useEffect, useState } from 'react' import type { ReactNode } from 'react' import { useRouter, useSearchParams } from 'next/navigation' +import useRefreshToken from '@/hooks/use-refresh-token' type SwrInitorProps = { children: ReactNode @@ -13,18 +14,31 @@ const SwrInitor = ({ }: SwrInitorProps) => { const router = useRouter() const searchParams = useSearchParams() - const consoleToken = searchParams.get('console_token') + const consoleToken = searchParams.get('access_token') + const refreshToken = searchParams.get('refresh_token') const consoleTokenFromLocalStorage = localStorage?.getItem('console_token') + const 
refreshTokenFromLocalStorage = localStorage?.getItem('refresh_token') const [init, setInit] = useState(false) + const { getNewAccessToken } = useRefreshToken() useEffect(() => { - if (!(consoleToken || consoleTokenFromLocalStorage)) + if (!(consoleToken || refreshToken || consoleTokenFromLocalStorage || refreshTokenFromLocalStorage)) { router.replace('/signin') + return + } + if (consoleTokenFromLocalStorage && refreshTokenFromLocalStorage) + getNewAccessToken(consoleTokenFromLocalStorage, refreshTokenFromLocalStorage) - if (consoleToken) { - localStorage?.setItem('console_token', consoleToken!) - router.replace('/apps', { forceOptimisticNavigation: false } as any) + if (consoleToken && refreshToken) { + localStorage.setItem('console_token', consoleToken) + localStorage.setItem('refresh_token', refreshToken) + getNewAccessToken(consoleToken, refreshToken).then(() => { + router.replace('/apps', { forceOptimisticNavigation: false } as any) + }).catch(() => { + router.replace('/signin') + }) } + setInit(true) }, []) diff --git a/web/app/signin/normalForm.tsx b/web/app/signin/normalForm.tsx index 816df8007d9e7c..0ae4eb1f432106 100644 --- a/web/app/signin/normalForm.tsx +++ b/web/app/signin/normalForm.tsx @@ -11,6 +11,7 @@ import { IS_CE_EDITION, SUPPORT_MAIL_LOGIN, apiPrefix, emailRegex } from '@/conf import Button from '@/app/components/base/button' import { login, oauth } from '@/service/common' import { getPurifyHref } from '@/utils' +import useRefreshToken from '@/hooks/use-refresh-token' type IState = { formValid: boolean @@ -61,6 +62,7 @@ function reducer(state: IState, action: IAction) { const NormalForm = () => { const { t } = useTranslation() + const { getNewAccessToken } = useRefreshToken() const useEmailLogin = IS_CE_EDITION || SUPPORT_MAIL_LOGIN const router = useRouter() @@ -95,7 +97,9 @@ const NormalForm = () => { }, }) if (res.result === 'success') { - localStorage.setItem('console_token', res.data) + localStorage.setItem('console_token', 
res.data.access_token) + localStorage.setItem('refresh_token', res.data.refresh_token) + getNewAccessToken(res.data.access_token, res.data.refresh_token) router.replace('/apps') } else { diff --git a/web/app/signin/userSSOForm.tsx b/web/app/signin/userSSOForm.tsx index 9cd889a0a51291..e4b61413bc0ef9 100644 --- a/web/app/signin/userSSOForm.tsx +++ b/web/app/signin/userSSOForm.tsx @@ -7,6 +7,7 @@ import cn from '@/utils/classnames' import Toast from '@/app/components/base/toast' import { getUserOAuth2SSOUrl, getUserOIDCSSOUrl, getUserSAMLSSOUrl } from '@/service/sso' import Button from '@/app/components/base/button' +import useRefreshToken from '@/hooks/use-refresh-token' type UserSSOFormProps = { protocol: string @@ -15,8 +16,10 @@ type UserSSOFormProps = { const UserSSOForm: FC = ({ protocol, }) => { + const { getNewAccessToken } = useRefreshToken() const searchParams = useSearchParams() - const consoleToken = searchParams.get('console_token') + const consoleToken = searchParams.get('access_token') + const refreshToken = searchParams.get('refresh_token') const message = searchParams.get('message') const router = useRouter() @@ -25,8 +28,10 @@ const UserSSOForm: FC = ({ const [isLoading, setIsLoading] = useState(false) useEffect(() => { - if (consoleToken) { + if (refreshToken && consoleToken) { localStorage.setItem('console_token', consoleToken) + localStorage.setItem('refresh_token', refreshToken) + getNewAccessToken(consoleToken, refreshToken) router.replace('/apps') } @@ -36,7 +41,7 @@ const UserSSOForm: FC = ({ message, }) } - }, []) + }, [consoleToken, refreshToken, message, router]) const handleSSOLogin = () => { setIsLoading(true) diff --git a/web/hooks/use-refresh-token.ts b/web/hooks/use-refresh-token.ts new file mode 100644 index 00000000000000..3d8779636f4b4c --- /dev/null +++ b/web/hooks/use-refresh-token.ts @@ -0,0 +1,92 @@ +'use client' +import { useCallback, useEffect, useRef } from 'react' +import { jwtDecode } from 'jwt-decode' +import dayjs from 
'dayjs' +import utc from 'dayjs/plugin/utc' +import { useRouter } from 'next/navigation' +import type { CommonResponse } from '@/models/common' +import { fetchNewToken } from '@/service/common' +import { fetchWithRetry } from '@/utils' + +dayjs.extend(utc) + +const useRefreshToken = () => { + const router = useRouter() + const timer = useRef() + const advanceTime = useRef(5 * 60 * 1000) + const interval = useRef(55 * 60 * 1000) + + const getExpireTime = useCallback((token: string) => { + if (!token) + return 0 + const decoded = jwtDecode(token) + return (decoded.exp || 0) * 1000 + }, []) + + const getCurrentTimeStamp = useCallback(() => { + return dayjs.utc().valueOf() + }, []) + + const handleError = useCallback(() => { + localStorage?.removeItem('is_refreshing') + localStorage?.removeItem('console_token') + localStorage?.removeItem('refresh_token') + localStorage?.removeItem('last_refresh_time') + router.replace('/signin') + }, []) + + const getNewAccessToken = useCallback(async (currentAccessToken: string, currentRefreshToken: string) => { + if (localStorage?.getItem('is_refreshing') === '1') + return null + const currentTokenExpireTime = getExpireTime(currentAccessToken) + let lastRefreshTime = parseInt(localStorage?.getItem('last_refresh_time') || '0') + lastRefreshTime = isNaN(lastRefreshTime) ? 
0 : lastRefreshTime + if (getCurrentTimeStamp() + advanceTime.current > currentTokenExpireTime + && lastRefreshTime + interval.current < getCurrentTimeStamp()) { + localStorage?.setItem('is_refreshing', '1') + const [e, res] = await fetchWithRetry(fetchNewToken({ + body: { refresh_token: currentRefreshToken }, + }) as Promise) + if (e) { + handleError() + return e + } + const { access_token, refresh_token } = res.data + localStorage?.setItem('is_refreshing', '0') + localStorage?.setItem('last_refresh_time', getCurrentTimeStamp().toString()) + localStorage?.setItem('console_token', access_token) + localStorage?.setItem('refresh_token', refresh_token) + const newTokenExpireTime = getExpireTime(access_token) + timer.current = setTimeout(() => { + const consoleTokenFromLocalStorage = localStorage?.getItem('console_token') + const refreshTokenFromLocalStorage = localStorage?.getItem('refresh_token') + if (consoleTokenFromLocalStorage && refreshTokenFromLocalStorage) + getNewAccessToken(consoleTokenFromLocalStorage, refreshTokenFromLocalStorage) + }, newTokenExpireTime - advanceTime.current - getCurrentTimeStamp()) + } + else { + const newTokenExpireTime = getExpireTime(currentAccessToken) + timer.current = setTimeout(() => { + const consoleTokenFromLocalStorage = localStorage?.getItem('console_token') + const refreshTokenFromLocalStorage = localStorage?.getItem('refresh_token') + if (consoleTokenFromLocalStorage && refreshTokenFromLocalStorage) + getNewAccessToken(consoleTokenFromLocalStorage, refreshTokenFromLocalStorage) + }, newTokenExpireTime - advanceTime.current - getCurrentTimeStamp()) + } + return null + }, [getExpireTime, getCurrentTimeStamp, handleError]) + + useEffect(() => { + return () => { + clearTimeout(timer.current) + localStorage?.removeItem('is_refreshing') + localStorage?.removeItem('last_refresh_time') + } + }, []) + + return { + getNewAccessToken, + } +} + +export default useRefreshToken diff --git a/web/package.json b/web/package.json index 
96b89c92315d9f..8a17997fe963cd 100644 --- a/web/package.json +++ b/web/package.json @@ -55,6 +55,7 @@ "immer": "^9.0.19", "js-audio-recorder": "^1.0.7", "js-cookie": "^3.0.1", + "jwt-decode": "^4.0.0", "katex": "^0.16.10", "lamejs": "^1.2.1", "lexical": "^0.16.0", diff --git a/web/service/common.ts b/web/service/common.ts index 3fbcde2a2ab3bd..bd3dbca0aaedbf 100644 --- a/web/service/common.ts +++ b/web/service/common.ts @@ -38,8 +38,21 @@ import type { import type { RETRIEVE_METHOD } from '@/types/app' import type { SystemFeatures } from '@/types/feature' -export const login: Fetcher }> = ({ url, body }) => { - return post(url, { body }) as Promise +type LoginSuccess = { + result: 'success' + data: { access_token: string;refresh_token: string } +} +type LoginFail = { + result: 'fail' + data: string +} +type LoginResponse = LoginSuccess | LoginFail +export const login: Fetcher }> = ({ url, body }) => { + return post(url, { body }) as Promise +} + +export const fetchNewToken: Fetcher }> = ({ body }) => { + return post('/refresh-token', { body }) as Promise } export const setup: Fetcher }> = ({ body }) => { diff --git a/web/utils/index.ts b/web/utils/index.ts index 8afd8afae70d6a..7aa6fef0a88c09 100644 --- a/web/utils/index.ts +++ b/web/utils/index.ts @@ -39,3 +39,21 @@ export const getPurifyHref = (href: string) => { return escape(href) } + +export async function fetchWithRetry(fn: Promise, retries = 3): Promise<[Error] | [null, T]> { + const [error, res] = await asyncRunSafe(fn) + if (error) { + if (retries > 0) { + const res = await fetchWithRetry(fn, retries - 1) + return res + } + else { + if (error instanceof Error) + return [error] + return [new Error('unknown error')] + } + } + else { + return [null, res] + } +} diff --git a/web/yarn.lock b/web/yarn.lock index 5d693ff786d8fa..4121752dd7de06 100644 --- a/web/yarn.lock +++ b/web/yarn.lock @@ -6205,6 +6205,11 @@ jsonc-eslint-parser@^2.0.4, jsonc-eslint-parser@^2.1.0: array-includes "^3.1.5" object.assign "^4.1.3" 
+jwt-decode@^4.0.0: + version "4.0.0" + resolved "https://registry.npmmirror.com/jwt-decode/-/jwt-decode-4.0.0.tgz#2270352425fd413785b2faf11f6e755c5151bd4b" + integrity sha512-+KJGIyHgkGuIq3IEBNftfhW/LfWhXUIY6OmyVWjliu5KH1y0fw7VQ8YndE2O4qZdMSd9SqbnC8GOcZEy0Om7sA== + katex@^0.16.0, katex@^0.16.10: version "0.16.10" resolved "https://registry.npmjs.org/katex/-/katex-0.16.10.tgz" From f73751843f9c88dea82821e6a109a0e6d225c4e5 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 12 Oct 2024 23:46:30 +0800 Subject: [PATCH 13/25] Feat/implement-refresh-tokens (#9233) --- api/.env.example | 3 + api/app.py | 2 +- api/configs/feature/__init__.py | 15 +++-- api/controllers/console/auth/login.py | 23 +++++-- api/controllers/console/auth/oauth.py | 11 ++- api/controllers/console/setup.py | 4 +- api/libs/helper.py | 2 +- api/services/account_service.py | 97 +++++++++++++++++++++------ docker/.env.example | 3 + docker/docker-compose.yaml | 1 + 10 files changed, 123 insertions(+), 38 deletions(-) diff --git a/api/.env.example b/api/.env.example index 3f88fb3cdf5b09..468130b1628e9e 100644 --- a/api/.env.example +++ b/api/.env.example @@ -20,6 +20,9 @@ FILES_URL=http://127.0.0.1:5001 # The time in seconds after the signature is rejected FILES_ACCESS_TIMEOUT=300 +# Access token expiration time in minutes +ACCESS_TOKEN_EXPIRE_MINUTES=60 + # celery configuration CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1 diff --git a/api/app.py b/api/app.py index a251ef5f0f72c3..52dd492225339f 100644 --- a/api/app.py +++ b/api/app.py @@ -183,7 +183,7 @@ def load_user_from_request(request_from_flask_login): decoded = PassportService().verify(auth_token) user_id = decoded.get("user_id") - logged_in_account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token) + logged_in_account = AccountService.load_logged_in_account(account_id=user_id) if logged_in_account: contexts.tenant_id.set(logged_in_account.current_tenant_id) return logged_in_account diff --git 
a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 93dbc1367f394c..a3334d16345e96 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -360,9 +360,9 @@ class WorkflowConfig(BaseSettings): ) -class OAuthConfig(BaseSettings): +class AuthConfig(BaseSettings): """ - Configuration for OAuth authentication + Configuration for authentication and OAuth """ OAUTH_REDIRECT_PATH: str = Field( @@ -371,7 +371,7 @@ class OAuthConfig(BaseSettings): ) GITHUB_CLIENT_ID: Optional[str] = Field( - description="GitHub OAuth client secret", + description="GitHub OAuth client ID", default=None, ) @@ -390,6 +390,11 @@ class OAuthConfig(BaseSettings): default=None, ) + ACCESS_TOKEN_EXPIRE_MINUTES: PositiveInt = Field( + description="Expiration time for access tokens in minutes", + default=60, + ) + class ModerationConfig(BaseSettings): """ @@ -607,6 +612,7 @@ def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]: class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, + AuthConfig, # Changed from OAuthConfig to AuthConfig BillingConfig, CodeExecutionSandboxConfig, DataSetConfig, @@ -621,14 +627,13 @@ class FeatureConfig( MailConfig, ModelLoadBalanceConfig, ModerationConfig, - OAuthConfig, + PositionConfig, RagEtlConfig, SecurityConfig, ToolConfig, UpdateConfig, WorkflowConfig, WorkspaceConfig, - PositionConfig, # hosted services config HostedServiceConfig, CeleryBeatConfig, diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 62837af2b9b0eb..18a7b2316660c6 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -7,7 +7,7 @@ import services from controllers.console import api from controllers.console.setup import setup_required -from libs.helper import email, get_remote_ip +from libs.helper import email, extract_remote_ip from libs.password import valid_password from models.account import Account from services.account_service import 
AccountService, TenantService @@ -40,17 +40,16 @@ def post(self): "data": "workspace not found, please contact system admin to invite you to join in a workspace", } - token = AccountService.login(account, ip_address=get_remote_ip(request)) + token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) - return {"result": "success", "data": token} + return {"result": "success", "data": token_pair.model_dump()} class LogoutApi(Resource): @setup_required def get(self): account = cast(Account, flask_login.current_user) - token = request.headers.get("Authorization", "").split(" ")[1] - AccountService.logout(account=account, token=token) + AccountService.logout(account=account) flask_login.logout_user() return {"result": "success"} @@ -106,5 +105,19 @@ def get(self): return {"result": "success"} +class RefreshTokenApi(Resource): + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("refresh_token", type=str, required=True, location="json") + args = parser.parse_args() + + try: + new_token_pair = AccountService.refresh_token(args["refresh_token"]) + return {"result": "success", "data": new_token_pair.model_dump()} + except Exception as e: + return {"result": "fail", "data": str(e)}, 401 + + api.add_resource(LoginApi, "/login") api.add_resource(LogoutApi, "/logout") +api.add_resource(RefreshTokenApi, "/refresh-token") diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index ad0c0580aeaba4..c5909b8c1092e3 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -9,7 +9,7 @@ from configs import dify_config from constants.languages import languages from extensions.ext_database import db -from libs.helper import get_remote_ip +from libs.helper import extract_remote_ip from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from models.account import Account, AccountStatus from services.account_service import AccountService, RegisterService, 
TenantService @@ -81,9 +81,14 @@ def get(self, provider: str): TenantService.create_owner_tenant_if_not_exist(account) - token = AccountService.login(account, ip_address=get_remote_ip(request)) + token_pair = AccountService.login( + account=account, + ip_address=extract_remote_ip(request), + ) - return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}") + return redirect( + f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}" + ) def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 46b4ef5d87a8a0..15a4af118b05be 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -4,7 +4,7 @@ from flask_restful import Resource, reqparse from configs import dify_config -from libs.helper import StrLen, email, get_remote_ip +from libs.helper import StrLen, email, extract_remote_ip from libs.password import valid_password from models.model import DifySetup from services.account_service import RegisterService, TenantService @@ -46,7 +46,7 @@ def post(self): # setup RegisterService.setup( - email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request) + email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request) ) return {"result": "success"}, 201 diff --git a/api/libs/helper.py b/api/libs/helper.py index 9c3a1ff04d1f8f..d8a8e7f4118d8d 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -162,7 +162,7 @@ def generate_string(n): return result -def get_remote_ip(request) -> str: +def extract_remote_ip(request) -> str: if request.headers.get("CF-Connecting-IP"): return request.headers.get("Cf-Connecting-Ip") elif request.headers.getlist("X-Forwarded-For"): diff --git a/api/services/account_service.py b/api/services/account_service.py index 
05b505f8a62184..eda6011aef5164 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -7,6 +7,7 @@ from hashlib import sha256 from typing import Any, Optional +from pydantic import BaseModel from sqlalchemy import func from werkzeug.exceptions import Unauthorized @@ -49,9 +50,39 @@ from tasks.mail_reset_password_task import send_reset_password_mail_task +class TokenPair(BaseModel): + access_token: str + refresh_token: str + + +REFRESH_TOKEN_PREFIX = "refresh_token:" +ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:" +REFRESH_TOKEN_EXPIRY = timedelta(days=30) + + class AccountService: reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60) + @staticmethod + def _get_refresh_token_key(refresh_token: str) -> str: + return f"{REFRESH_TOKEN_PREFIX}{refresh_token}" + + @staticmethod + def _get_account_refresh_token_key(account_id: str) -> str: + return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}" + + @staticmethod + def _store_refresh_token(refresh_token: str, account_id: str) -> None: + redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id) + redis_client.setex( + AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token + ) + + @staticmethod + def _delete_refresh_token(refresh_token: str, account_id: str) -> None: + redis_client.delete(AccountService._get_refresh_token_key(refresh_token)) + redis_client.delete(AccountService._get_account_refresh_token_key(account_id)) + @staticmethod def load_user(user_id: str) -> None | Account: account = Account.query.filter_by(id=user_id).first() @@ -61,9 +92,7 @@ def load_user(user_id: str) -> None | Account: if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: raise Unauthorized("Account is banned or closed.") - current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by( - account_id=account.id, current=True - 
).first() + current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() if current_tenant: account.current_tenant_id = current_tenant.tenant_id else: @@ -84,10 +113,12 @@ def load_user(user_id: str) -> None | Account: return account @staticmethod - def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)): + def get_account_jwt_token(account: Account) -> str: + exp_dt = datetime.now(timezone.utc) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES) + exp = int(exp_dt.timestamp()) payload = { "user_id": account.id, - "exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp, + "exp": exp, "iss": dify_config.EDITION, "sub": "Console API Passport", } @@ -213,7 +244,7 @@ def update_account(account, **kwargs): return account @staticmethod - def update_last_login(account: Account, *, ip_address: str) -> None: + def update_login_info(account: Account, *, ip_address: str) -> None: """Update last login time and ip""" account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None) account.last_login_ip = ip_address @@ -221,22 +252,45 @@ def update_last_login(account: Account, *, ip_address: str) -> None: db.session.commit() @staticmethod - def login(account: Account, *, ip_address: Optional[str] = None): + def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair: if ip_address: - AccountService.update_last_login(account, ip_address=ip_address) - exp = timedelta(days=30) - token = AccountService.get_account_jwt_token(account, exp=exp) - redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds())) - return token + AccountService.update_login_info(account=account, ip_address=ip_address) + + access_token = AccountService.get_account_jwt_token(account=account) + refresh_token = _generate_refresh_token() + + AccountService._store_refresh_token(refresh_token, account.id) + + return TokenPair(access_token=access_token, 
refresh_token=refresh_token) @staticmethod - def logout(*, account: Account, token: str): - redis_client.delete(_get_login_cache_key(account_id=account.id, token=token)) + def logout(*, account: Account) -> None: + refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id)) + if refresh_token: + AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id) @staticmethod - def load_logged_in_account(*, account_id: str, token: str): - if not redis_client.get(_get_login_cache_key(account_id=account_id, token=token)): - return None + def refresh_token(refresh_token: str) -> TokenPair: + # Verify the refresh token + account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token)) + if not account_id: + raise ValueError("Invalid refresh token") + + account = AccountService.load_user(account_id.decode("utf-8")) + if not account: + raise ValueError("Invalid account") + + # Generate new access token and refresh token + new_access_token = AccountService.get_account_jwt_token(account) + new_refresh_token = _generate_refresh_token() + + AccountService._delete_refresh_token(refresh_token, account.id) + AccountService._store_refresh_token(new_refresh_token, account.id) + + return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token) + + @staticmethod + def load_logged_in_account(*, account_id: str): return AccountService.load_user(account_id) @classmethod @@ -258,10 +312,6 @@ def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: return TokenManager.get_token_data(token, "reset_password") -def _get_login_cache_key(*, account_id: str, token: str): - return f"account_login:{account_id}:{token}" - - class TenantService: @staticmethod def create_tenant(name: str) -> Tenant: @@ -698,3 +748,8 @@ def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> invitation = json.loads(data) return invitation + + +def _generate_refresh_token(length: int = 64): + 
token = secrets.token_hex(length) + return token diff --git a/docker/.env.example b/docker/.env.example index c4eae46cb0215a..c7b4f38d2ec112 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -91,6 +91,9 @@ MIGRATION_ENABLED=true # The default value is 300 seconds. FILES_ACCESS_TIMEOUT=300 +# Access token expiration time in minutes +ACCESS_TOKEN_EXPIRE_MINUTES=60 + # The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer. APP_MAX_ACTIVE_REQUESTS=0 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index c046c17ef8f2b2..cb6ecba2791db1 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -47,6 +47,7 @@ x-shared-env: &shared-api-worker-env REDIS_SENTINEL_SERVICE_NAME: ${REDIS_SENTINEL_SERVICE_NAME:-} REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-} REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-} + ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1} CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} BROKER_USE_SSL: ${BROKER_USE_SSL:-false} From ea6734f55009e6a031ff19255a7bc630728818da Mon Sep 17 00:00:00 2001 From: NFish Date: Sat, 12 Oct 2024 23:49:18 +0800 Subject: [PATCH 14/25] Feat/new account page (#9236) --- web/app/account/account-page/index.module.css | 9 + web/app/account/account-page/index.tsx | 304 ++++++++++++++++++ web/app/account/avatar.tsx | 94 ++++++ web/app/account/header.tsx | 37 +++ web/app/account/layout.tsx | 40 +++ web/app/account/page.tsx | 7 + .../header/account-dropdown/index.tsx | 11 +- .../header/account-setting/index.tsx | 22 +- web/app/components/header/header-wrapper.tsx | 2 +- web/i18n/en-US/common.ts | 5 +- web/i18n/zh-Hans/common.ts | 6 +- 11 files changed, 511 insertions(+), 26 deletions(-) create mode 100644 web/app/account/account-page/index.module.css create mode 100644 
web/app/account/account-page/index.tsx create mode 100644 web/app/account/avatar.tsx create mode 100644 web/app/account/header.tsx create mode 100644 web/app/account/layout.tsx create mode 100644 web/app/account/page.tsx diff --git a/web/app/account/account-page/index.module.css b/web/app/account/account-page/index.module.css new file mode 100644 index 00000000000000..949d1257e9820c --- /dev/null +++ b/web/app/account/account-page/index.module.css @@ -0,0 +1,9 @@ +.modal { + padding: 24px 32px !important; + width: 400px !important; +} + +.bg { + background: linear-gradient(180deg, rgba(217, 45, 32, 0.05) 0%, rgba(217, 45, 32, 0.00) 24.02%), #F9FAFB; +} + diff --git a/web/app/account/account-page/index.tsx b/web/app/account/account-page/index.tsx new file mode 100644 index 00000000000000..53f7692e6c8f5b --- /dev/null +++ b/web/app/account/account-page/index.tsx @@ -0,0 +1,304 @@ +'use client' +import { useState } from 'react' +import { useTranslation } from 'react-i18next' + +import { useContext } from 'use-context-selector' +import s from './index.module.css' +import Collapse from '@/app/components/header/account-setting/collapse' +import type { IItem } from '@/app/components/header/account-setting/collapse' +import Modal from '@/app/components/base/modal' +import Confirm from '@/app/components/base/confirm' +import Button from '@/app/components/base/button' +import { updateUserProfile } from '@/service/common' +import { useAppContext } from '@/context/app-context' +import { ToastContext } from '@/app/components/base/toast' +import AppIcon from '@/app/components/base/app-icon' +import Avatar from '@/app/components/base/avatar' +import { IS_CE_EDITION } from '@/config' + +const titleClassName = ` + text-sm font-medium text-gray-900 +` +const descriptionClassName = ` + mt-1 text-xs font-normal text-gray-500 +` +const inputClassName = ` + mt-2 w-full px-3 py-2 bg-gray-100 rounded + text-sm font-normal text-gray-800 +` + +const validPassword = 
/^(?=.*[a-zA-Z])(?=.*\d).{8,}$/ + +export default function AccountPage() { + const { t } = useTranslation() + const { mutateUserProfile, userProfile, apps } = useAppContext() + const { notify } = useContext(ToastContext) + const [editNameModalVisible, setEditNameModalVisible] = useState(false) + const [editName, setEditName] = useState('') + const [editing, setEditing] = useState(false) + const [editPasswordModalVisible, setEditPasswordModalVisible] = useState(false) + const [currentPassword, setCurrentPassword] = useState('') + const [password, setPassword] = useState('') + const [confirmPassword, setConfirmPassword] = useState('') + const [showDeleteAccountModal, setShowDeleteAccountModal] = useState(false) + + const handleEditName = () => { + setEditNameModalVisible(true) + setEditName(userProfile.name) + } + const handleSaveName = async () => { + try { + setEditing(true) + await updateUserProfile({ url: 'account/name', body: { name: editName } }) + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + mutateUserProfile() + setEditNameModalVisible(false) + setEditing(false) + } + catch (e) { + notify({ type: 'error', message: (e as Error).message }) + setEditNameModalVisible(false) + setEditing(false) + } + } + + const showErrorMessage = (message: string) => { + notify({ + type: 'error', + message, + }) + } + const valid = () => { + if (!password.trim()) { + showErrorMessage(t('login.error.passwordEmpty')) + return false + } + if (!validPassword.test(password)) { + showErrorMessage(t('login.error.passwordInvalid')) + return false + } + if (password !== confirmPassword) { + showErrorMessage(t('common.account.notEqual')) + return false + } + + return true + } + const resetPasswordForm = () => { + setCurrentPassword('') + setPassword('') + setConfirmPassword('') + } + const handleSavePassword = async () => { + if (!valid()) + return + try { + setEditing(true) + await updateUserProfile({ + url: 'account/password', + body: { + password: 
currentPassword, + new_password: password, + repeat_new_password: confirmPassword, + }, + }) + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + mutateUserProfile() + setEditPasswordModalVisible(false) + resetPasswordForm() + setEditing(false) + } + catch (e) { + notify({ type: 'error', message: (e as Error).message }) + setEditPasswordModalVisible(false) + setEditing(false) + } + } + + const renderAppItem = (item: IItem) => { + return ( +
+
+ +
+
{item.name}
+
+ ) + } + + return ( + <> +
+

{t('common.account.myAccount')}

+
+
+ +
+

{userProfile.name}

+

{userProfile.email}

+
+
+
+
{t('common.account.name')}
+
+
+ {userProfile.name} +
+
+ {t('common.operation.edit')} +
+
+
+
+
{t('common.account.email')}
+
+
+ {userProfile.email} +
+
+
+ { + IS_CE_EDITION && ( +
+
+
{t('common.account.password')}
+
{t('common.account.passwordTip')}
+
+ +
+ ) + } +
+
+
{t('common.account.langGeniusAccount')}
+
{t('common.account.langGeniusAccountTip')}
+ {!!apps.length && ( + ({ key: app.id, name: app.name }))} + renderItem={renderAppItem} + wrapperClassName='mt-2' + /> + )} + {!IS_CE_EDITION && } +
+ { + editNameModalVisible && ( + setEditNameModalVisible(false)} + className={s.modal} + > +
{t('common.account.editName')}
+
{t('common.account.name')}
+ setEditName(e.target.value)} + /> +
+ + +
+
+ ) + } + { + editPasswordModalVisible && ( + { + setEditPasswordModalVisible(false) + resetPasswordForm() + }} + className={s.modal} + > +
{userProfile.is_password_set ? t('common.account.resetPassword') : t('common.account.setPassword')}
+ {userProfile.is_password_set && ( + <> +
{t('common.account.currentPassword')}
+ setCurrentPassword(e.target.value)} + /> + + )} +
+ {userProfile.is_password_set ? t('common.account.newPassword') : t('common.account.password')} +
+ setPassword(e.target.value)} + /> +
{t('common.account.confirmPassword')}
+ setConfirmPassword(e.target.value)} + /> +
+ + +
+
+ ) + } + { + showDeleteAccountModal && ( + setShowDeleteAccountModal(false)} + onConfirm={() => setShowDeleteAccountModal(false)} + showCancel={false} + type='warning' + title={t('common.account.delete')} + content={ + <> +
+ {t('common.account.deleteTip')} +
+ +
{`${t('common.account.delete')}: ${userProfile.email}`}
+ + } + confirmText={t('common.operation.ok') as string} + /> + ) + } + + ) +} diff --git a/web/app/account/avatar.tsx b/web/app/account/avatar.tsx new file mode 100644 index 00000000000000..29bd0cb5a581c6 --- /dev/null +++ b/web/app/account/avatar.tsx @@ -0,0 +1,94 @@ +'use client' +import { useTranslation } from 'react-i18next' +import { Fragment } from 'react' +import { useRouter } from 'next/navigation' +import { Menu, Transition } from '@headlessui/react' +import Avatar from '@/app/components/base/avatar' +import { logout } from '@/service/common' +import { useAppContext } from '@/context/app-context' +import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general' + +export type IAppSelector = { + isMobile: boolean +} + +export default function AppSelector() { + const router = useRouter() + const { t } = useTranslation() + const { userProfile } = useAppContext() + + const handleLogout = async () => { + await logout({ + url: '/logout', + params: {}, + }) + + if (localStorage?.getItem('console_token')) + localStorage.removeItem('console_token') + + router.push('/signin') + } + + return ( + + { + ({ open }) => ( + <> +
+ + + +
+ + + +
+
+
+
{userProfile.name}
+
{userProfile.email}
+
+ +
+
+
+ +
handleLogout()}> +
+ +
{t('common.userProfile.logout')}
+
+
+
+
+
+ + ) + } +
+ ) +} diff --git a/web/app/account/header.tsx b/web/app/account/header.tsx new file mode 100644 index 00000000000000..694533e5ab7cb6 --- /dev/null +++ b/web/app/account/header.tsx @@ -0,0 +1,37 @@ +'use client' +import { useTranslation } from 'react-i18next' +import { RiArrowRightUpLine, RiRobot2Line } from '@remixicon/react' +import { useRouter } from 'next/navigation' +import Button from '../components/base/button' +import Avatar from './avatar' +import LogoSite from '@/app/components/base/logo/logo-site' + +const Header = () => { + const { t } = useTranslation() + const router = useRouter() + + const back = () => { + router.back() + } + return ( +
+
+
+ +
+
+

{t('common.account.account')}

+
+
+ +
+ +
+
+ ) +} +export default Header diff --git a/web/app/account/layout.tsx b/web/app/account/layout.tsx new file mode 100644 index 00000000000000..5aa8b05cbfd07b --- /dev/null +++ b/web/app/account/layout.tsx @@ -0,0 +1,40 @@ +import React from 'react' +import type { ReactNode } from 'react' +import Header from './header' +import SwrInitor from '@/app/components/swr-initor' +import { AppContextProvider } from '@/context/app-context' +import GA, { GaType } from '@/app/components/base/ga' +import HeaderWrapper from '@/app/components/header/header-wrapper' +import { EventEmitterContextProvider } from '@/context/event-emitter' +import { ProviderContextProvider } from '@/context/provider-context' +import { ModalContextProvider } from '@/context/modal-context' + +const Layout = ({ children }: { children: ReactNode }) => { + return ( + <> + + + + + + + +
+ +
+ {children} +
+ + + + + + + ) +} + +export const metadata = { + title: 'Dify', +} + +export default Layout diff --git a/web/app/account/page.tsx b/web/app/account/page.tsx new file mode 100644 index 00000000000000..bb7e7f7feb1840 --- /dev/null +++ b/web/app/account/page.tsx @@ -0,0 +1,7 @@ +import AccountPage from './account-page' + +export default function Account() { + return
+ +
+} diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index 03157ed7cb4e96..712906ebae3815 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -107,7 +107,16 @@ export default function AppSelector({ isMobile }: IAppSelector) {
-
setShowAccountSettingModal({ payload: 'account' })}> + +
{t('common.account.account')}
+ + + + +
setShowAccountSettingModal({ payload: 'members' })}>
{t('common.userProfile.settings')}
diff --git a/web/app/components/header/account-setting/index.tsx b/web/app/components/header/account-setting/index.tsx index 253b9f1b4c2cd2..d829f6b77b0cc9 100644 --- a/web/app/components/header/account-setting/index.tsx +++ b/web/app/components/header/account-setting/index.tsx @@ -2,10 +2,6 @@ import { useTranslation } from 'react-i18next' import { useEffect, useRef, useState } from 'react' import { - RiAccountCircleFill, - RiAccountCircleLine, - RiApps2AddFill, - RiApps2AddLine, RiBox3Fill, RiBox3Line, RiCloseLine, @@ -21,9 +17,7 @@ import { RiPuzzle2Line, RiTranslate2, } from '@remixicon/react' -import AccountPage from './account-page' import MembersPage from './members-page' -import IntegrationsPage from './Integrations-page' import LanguagePage from './language-page' import ApiBasedExtensionPage from './api-based-extension-page' import DataSourcePage from './data-source-page' @@ -60,7 +54,7 @@ type GroupItem = { export default function AccountSetting({ onCancel, - activeTab = 'account', + activeTab = 'members', }: IAccountSettingProps) { const [activeMenu, setActiveMenu] = useState(activeTab) const { t } = useTranslation() @@ -125,18 +119,6 @@ export default function AccountSetting({ key: 'account-group', name: t('common.settings.accountGroup'), items: [ - { - key: 'account', - name: t('common.settings.account'), - icon: , - activeIcon: , - }, - { - key: 'integrations', - name: t('common.settings.integrations'), - icon: , - activeIcon: , - }, { key: 'language', name: t('common.settings.language'), @@ -217,10 +199,8 @@ export default function AccountSetting({
- {activeMenu === 'account' && } {activeMenu === 'members' && } {activeMenu === 'billing' && } - {activeMenu === 'integrations' && } {activeMenu === 'language' && } {activeMenu === 'provider' && } {activeMenu === 'data-source' && } diff --git a/web/app/components/header/header-wrapper.tsx b/web/app/components/header/header-wrapper.tsx index 205a379a903018..360cf8e5607ba1 100644 --- a/web/app/components/header/header-wrapper.tsx +++ b/web/app/components/header/header-wrapper.tsx @@ -11,7 +11,7 @@ const HeaderWrapper = ({ children, }: HeaderWrapperProps) => { const pathname = usePathname() - const isBordered = ['/apps', '/datasets', '/datasets/create', '/tools'].includes(pathname) + const isBordered = ['/apps', '/datasets', '/datasets/create', '/tools', '/account'].includes(pathname) return (
Date: Sat, 12 Oct 2024 23:58:41 +0800 Subject: [PATCH 15/25] chore: add abstract decorator and output log when query embedding fails (#9264) --- api/core/embedding/cached_embedding.py | 7 ++++++- api/core/rag/datasource/keyword/keyword_base.py | 2 ++ .../datasource/vdb/elasticsearch/elasticsearch_vector.py | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index 75219051cd30cd..31d2171e7267df 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -5,6 +5,7 @@ import numpy as np from sqlalchemy.exc import IntegrityError +from configs import dify_config from core.embedding.embedding_constant import EmbeddingInputType from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelPropertyKey @@ -110,6 +111,8 @@ def embed_query(self, text: str) -> list[float]: embedding_results = embedding_result.embeddings[0] embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() except Exception as ex: + if dify_config.DEBUG: + logging.exception(f"Failed to embed query text: {ex}") raise ex try: @@ -122,6 +125,8 @@ def embed_query(self, text: str) -> list[float]: encoded_str = encoded_vector.decode("utf-8") redis_client.setex(embedding_cache_key, 600, encoded_str) except Exception as ex: - logging.exception("Failed to add embedding to redis %s", ex) + if dify_config.DEBUG: + logging.exception("Failed to add embedding to redis %s", ex) + raise ex return embedding_results diff --git a/api/core/rag/datasource/keyword/keyword_base.py b/api/core/rag/datasource/keyword/keyword_base.py index 4b9ec460e61fe8..be00687abd5025 100644 --- a/api/core/rag/datasource/keyword/keyword_base.py +++ b/api/core/rag/datasource/keyword/keyword_base.py @@ -27,9 +27,11 @@ def text_exists(self, id: str) -> bool: def delete_by_ids(self, ids: list[str]) -> None: raise NotImplementedError + 
@abstractmethod def delete(self) -> None: raise NotImplementedError + @abstractmethod def search(self, query: str, **kwargs: Any) -> list[Document]: raise NotImplementedError diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index f585e12b2e99c3..66bc31a4bfaffa 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -77,7 +77,7 @@ def _check_version(self): raise ValueError("Elasticsearch vector database version must be greater than 8.0.0") def get_type(self) -> str: - return "elasticsearch" + return VectorType.ELASTICSEARCH def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) From 9275760599abf6df762b6f8babf7eba54554119d Mon Sep 17 00:00:00 2001 From: zhuhao <37029601+hwzhuhao@users.noreply.github.com> Date: Sun, 13 Oct 2024 09:44:53 +0800 Subject: [PATCH 16/25] chore: add baidu-obs and supabase for .env.example (#9289) --- docker/.env.example | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/docker/.env.example b/docker/.env.example index c7b4f38d2ec112..969deadf679212 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -264,7 +264,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=* # ------------------------------ # The type of storage to use for storing user files. -# Supported values are `local` and `s3` and `azure-blob` and `google-storage` and `tencent-cos` and `huawei-obs` +# Supported values are `local` , `s3` , `azure-blob` , `google-storage`, `tencent-cos`, `huawei-obs`, `volcengine-tos`, `baidu-obs`, `supabase` # Default: `local` STORAGE_TYPE=local @@ -344,6 +344,24 @@ VOLCENGINE_TOS_ENDPOINT=your-server-url # The region of the Volcengine TOS service. 
VOLCENGINE_TOS_REGION=your-region +# Baidu OBS Storage Configuration +# The name of the Baidu OBS bucket to use for storing files. +BAIDU_OBS_BUCKET_NAME=your-bucket-name +# The secret key to use for authenticating with the Baidu OBS service. +BAIDU_OBS_SECRET_KEY=your-secret-key +# The access key to use for authenticating with the Baidu OBS service. +BAIDU_OBS_ACCESS_KEY=your-access-key +# The endpoint of the Baidu OBS service. +BAIDU_OBS_ENDPOINT=your-server-url + +# Supabase Storage Configuration +# The name of the Supabase bucket to use for storing files. +SUPABASE_BUCKET_NAME=your-bucket-name +# The api key to use for authenticating with the Supabase service. +SUPABASE_API_KEY=your-access-key +# The project endpoint url of the Supabase service. +SUPABASE_URL=your-server-url + # ------------------------------ # Vector Database Configuration # ------------------------------ From 1ec83e4969a2a78cd5a2e563a3cc8815328562ef Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sun, 13 Oct 2024 14:56:26 +0800 Subject: [PATCH 17/25] chore: translate i18n files (#9288) Co-authored-by: douxc <7553076+douxc@users.noreply.github.com> Co-authored-by: crazywoola <427733928@qq.com> --- web/i18n/de-DE/common.ts | 3 +++ web/i18n/es-ES/common.ts | 3 +++ web/i18n/fa-IR/common.ts | 3 +++ web/i18n/fr-FR/common.ts | 3 +++ web/i18n/hi-IN/common.ts | 3 +++ web/i18n/it-IT/common.ts | 3 +++ web/i18n/ja-JP/common.ts | 3 +++ web/i18n/ko-KR/common.ts | 3 +++ web/i18n/pl-PL/common.ts | 3 +++ web/i18n/pt-BR/common.ts | 3 +++ web/i18n/ro-RO/common.ts | 3 +++ web/i18n/ru-RU/common.ts | 3 +++ web/i18n/tr-TR/common.ts | 3 +++ web/i18n/uk-UA/common.ts | 3 +++ web/i18n/vi-VN/common.ts | 3 +++ web/i18n/zh-Hans/common.ts | 1 + web/i18n/zh-Hant/common.ts | 3 +++ 17 files changed, 49 insertions(+) diff --git a/web/i18n/de-DE/common.ts b/web/i18n/de-DE/common.ts index 6ea06bc8b1c72f..8b221ca3bbbf4a 100644 --- a/web/i18n/de-DE/common.ts +++ 
b/web/i18n/de-DE/common.ts @@ -167,6 +167,9 @@ const translation = { delete: 'Konto löschen', deleteTip: 'Wenn Sie Ihr Konto löschen, werden alle Ihre Daten dauerhaft gelöscht und können nicht wiederhergestellt werden.', deleteConfirmTip: 'Zur Bestätigung senden Sie bitte Folgendes von Ihrer registrierten E-Mail-Adresse an ', + myAccount: 'Mein Konto', + studio: 'Dify Studio', + account: 'Konto', }, members: { team: 'Team', diff --git a/web/i18n/es-ES/common.ts b/web/i18n/es-ES/common.ts index 59a05f63d8a3e2..748c9d152d4c19 100644 --- a/web/i18n/es-ES/common.ts +++ b/web/i18n/es-ES/common.ts @@ -171,6 +171,9 @@ const translation = { delete: 'Eliminar cuenta', deleteTip: 'Eliminar tu cuenta borrará permanentemente todos tus datos y no se podrán recuperar.', deleteConfirmTip: 'Para confirmar, por favor envía lo siguiente desde tu correo electrónico registrado a ', + account: 'Cuenta', + myAccount: 'Mi Cuenta', + studio: 'Estudio Dify', }, members: { team: 'Equipo', diff --git a/web/i18n/fa-IR/common.ts b/web/i18n/fa-IR/common.ts index c75ab11a63cd89..a369a0ba5ebb14 100644 --- a/web/i18n/fa-IR/common.ts +++ b/web/i18n/fa-IR/common.ts @@ -171,6 +171,9 @@ const translation = { delete: 'حذف حساب کاربری', deleteTip: 'حذف حساب کاربری شما تمام داده‌های شما را به طور دائمی پاک می‌کند و قابل بازیابی نیست.', deleteConfirmTip: 'برای تأیید، لطفاً موارد زیر را از ایمیل ثبت‌نام شده خود به این آدرس ارسال کنید ', + account: 'حساب', + myAccount: 'حساب من', + studio: 'استودیو Dify', }, members: { team: 'تیم', diff --git a/web/i18n/fr-FR/common.ts b/web/i18n/fr-FR/common.ts index c4fed4405dc4fe..0cd301aed2c5e6 100644 --- a/web/i18n/fr-FR/common.ts +++ b/web/i18n/fr-FR/common.ts @@ -167,6 +167,9 @@ const translation = { delete: 'Supprimer le compte', deleteTip: 'La suppression de votre compte effacera définitivement toutes vos données et elles ne pourront pas être récupérées.', deleteConfirmTip: 'Pour confirmer, veuillez envoyer ce qui suit depuis votre adresse e-mail enregistrée à ', + 
myAccount: 'Mon compte', + account: 'Compte', + studio: 'Dify Studio', }, members: { team: 'Équipe', diff --git a/web/i18n/hi-IN/common.ts b/web/i18n/hi-IN/common.ts index 256cb9d426f238..224090437e4496 100644 --- a/web/i18n/hi-IN/common.ts +++ b/web/i18n/hi-IN/common.ts @@ -177,6 +177,9 @@ const translation = { deleteConfirmTip: 'पुष्टि करने के लिए, कृपया अपने पंजीकृत ईमेल से निम्नलिखित भेजें', delete: 'खाता हटाएं', deleteTip: 'अपना खाता हटाने से आपका सारा डेटा स्थायी रूप से मिट जाएगा और इसे पुनर्प्राप्त नहीं किया जा सकता है।', + account: 'खाता', + studio: 'डिफाई स्टूडियो', + myAccount: 'मेरा खाता', }, members: { team: 'टीम', diff --git a/web/i18n/it-IT/common.ts b/web/i18n/it-IT/common.ts index aa675bb4718299..5c180a8b697107 100644 --- a/web/i18n/it-IT/common.ts +++ b/web/i18n/it-IT/common.ts @@ -179,6 +179,9 @@ const translation = { 'Eliminando il tuo account cancellerai permanentemente tutti i tuoi dati e non sarà possibile recuperarli.', deleteConfirmTip: 'Per confermare, invia il seguente messaggio dalla tua email registrata a ', + myAccount: 'Il mio account', + account: 'Conto', + studio: 'Dify Studio', }, members: { team: 'Team', diff --git a/web/i18n/ja-JP/common.ts b/web/i18n/ja-JP/common.ts index e2517a619dc97b..bd50e68b0949e9 100644 --- a/web/i18n/ja-JP/common.ts +++ b/web/i18n/ja-JP/common.ts @@ -171,6 +171,9 @@ const translation = { delete: 'アカウントを削除', deleteTip: 'アカウントを削除すると、すべてのデータが完全に消去され、復元できなくなります。', deleteConfirmTip: '確認のため、登録したメールから次の内容をに送信してください ', + account: 'アカウント', + myAccount: 'マイアカウント', + studio: 'Difyスタジオ', }, members: { team: 'チーム', diff --git a/web/i18n/ko-KR/common.ts b/web/i18n/ko-KR/common.ts index 8ef55da3f7153e..d70b7ebb108d47 100644 --- a/web/i18n/ko-KR/common.ts +++ b/web/i18n/ko-KR/common.ts @@ -163,6 +163,9 @@ const translation = { delete: '계정 삭제', deleteTip: '계정을 삭제하면 모든 데이터가 영구적으로 지워지며 복구할 수 없습니다.', deleteConfirmTip: '확인하려면 등록된 이메일에서 다음 내용을 로 보내주세요 ', + myAccount: '내 계정', + studio: '디파이 스튜디오', + account: '계좌', }, members: { 
team: '팀', diff --git a/web/i18n/pl-PL/common.ts b/web/i18n/pl-PL/common.ts index 91f5fb28992966..b0706787855d36 100644 --- a/web/i18n/pl-PL/common.ts +++ b/web/i18n/pl-PL/common.ts @@ -173,6 +173,9 @@ const translation = { delete: 'Usuń konto', deleteTip: 'Usunięcie konta spowoduje trwałe usunięcie wszystkich danych i nie będzie można ich odzyskać.', deleteConfirmTip: 'Aby potwierdzić, wyślij następujące informacje z zarejestrowanego adresu e-mail na adres ', + myAccount: 'Moje konto', + studio: 'Dify Studio', + account: 'Rachunek', }, members: { team: 'Zespół', diff --git a/web/i18n/pt-BR/common.ts b/web/i18n/pt-BR/common.ts index f9e9eb78889e1e..9343fdf5605e7e 100644 --- a/web/i18n/pt-BR/common.ts +++ b/web/i18n/pt-BR/common.ts @@ -167,6 +167,9 @@ const translation = { delete: 'Excluir conta', deleteTip: 'Excluir sua conta apagará permanentemente todos os seus dados e eles não poderão ser recuperados.', deleteConfirmTip: 'Para confirmar, envie o seguinte do seu e-mail registrado para ', + myAccount: 'Minha Conta', + account: 'Conta', + studio: 'Estúdio Dify', }, members: { team: 'Equipe', diff --git a/web/i18n/ro-RO/common.ts b/web/i18n/ro-RO/common.ts index 1fd87781062138..dc3bfcc45a845a 100644 --- a/web/i18n/ro-RO/common.ts +++ b/web/i18n/ro-RO/common.ts @@ -167,6 +167,9 @@ const translation = { delete: 'Șterge contul', deleteTip: 'Ștergerea contului vă va șterge definitiv toate datele și nu pot fi recuperate.', deleteConfirmTip: 'Pentru a confirma, trimiteți următoarele din e-mailul înregistrat la ', + account: 'Cont', + studio: 'Dify Studio', + myAccount: 'Contul meu', }, members: { team: 'Echipă', diff --git a/web/i18n/ru-RU/common.ts b/web/i18n/ru-RU/common.ts index 82e3471e607dd6..a829fb27b11517 100644 --- a/web/i18n/ru-RU/common.ts +++ b/web/i18n/ru-RU/common.ts @@ -171,6 +171,9 @@ const translation = { delete: 'Удалить учетную запись', deleteTip: 'Удаление вашей учетной записи приведет к безвозвратному удалению всех ваших данных, и их невозможно будет 
восстановить.', deleteConfirmTip: 'Для подтверждения, пожалуйста, отправьте следующее с вашего зарегистрированного адреса электронной почты на ', + account: 'Счет', + studio: 'Студия Dify', + myAccount: 'Моя учетная запись', }, members: { team: 'Команда', diff --git a/web/i18n/tr-TR/common.ts b/web/i18n/tr-TR/common.ts index a41925cd2002ae..dc4b1cccbab5a6 100644 --- a/web/i18n/tr-TR/common.ts +++ b/web/i18n/tr-TR/common.ts @@ -171,6 +171,9 @@ const translation = { delete: 'Hesabı Sil', deleteTip: 'Hesabınızı silmek tüm verilerinizi kalıcı olarak siler ve geri alınamaz.', deleteConfirmTip: 'Onaylamak için, kayıtlı e-postanızdan şu adrese e-posta gönderin: ', + account: 'Hesap', + myAccount: 'Hesabım', + studio: 'Dify Stüdyo', }, members: { team: 'Takım', diff --git a/web/i18n/uk-UA/common.ts b/web/i18n/uk-UA/common.ts index cc70772be33fc6..ef0bc55203787e 100644 --- a/web/i18n/uk-UA/common.ts +++ b/web/i18n/uk-UA/common.ts @@ -167,6 +167,9 @@ const translation = { delete: 'Видалити обліковий запис', deleteTip: 'Видалення вашого облікового запису призведе до остаточного видалення всіх ваших даних, і їх неможливо буде відновити.', deleteConfirmTip: 'Щоб підтвердити, будь ласка, надішліть наступне з вашої зареєстрованої електронної пошти на ', + account: 'Рахунок', + studio: 'Студія Dify', + myAccount: 'Особистий кабінет', }, members: { team: 'Команда', diff --git a/web/i18n/vi-VN/common.ts b/web/i18n/vi-VN/common.ts index 252fa7e1dfa955..5336ec4f667f45 100644 --- a/web/i18n/vi-VN/common.ts +++ b/web/i18n/vi-VN/common.ts @@ -167,6 +167,9 @@ const translation = { delete: 'Xóa tài khoản', deleteTip: 'Xóa tài khoản của bạn sẽ xóa vĩnh viễn tất cả dữ liệu của bạn và không thể khôi phục được.', deleteConfirmTip: 'Để xác nhận, vui lòng gửi thông tin sau từ email đã đăng ký của bạn tới ', + studio: 'Dify Studio', + myAccount: 'Tài khoản của tôi', + account: 'Tài khoản', }, members: { team: 'Nhóm', diff --git a/web/i18n/zh-Hans/common.ts b/web/i18n/zh-Hans/common.ts index 
fe2b3bf92d0b28..3c9b61f56b0e1c 100644 --- a/web/i18n/zh-Hans/common.ts +++ b/web/i18n/zh-Hans/common.ts @@ -142,6 +142,7 @@ const translation = { settings: { accountGroup: '通用', workplaceGroup: '工作空间', + account: '我的账户', members: '成员', billing: '账单', integrations: '集成', diff --git a/web/i18n/zh-Hant/common.ts b/web/i18n/zh-Hant/common.ts index 8cd51b1991f1a2..e43b49bd3c6189 100644 --- a/web/i18n/zh-Hant/common.ts +++ b/web/i18n/zh-Hant/common.ts @@ -167,6 +167,9 @@ const translation = { delete: '刪除帳戶', deleteTip: '刪除您的帳戶將永久刪除您的所有資料並且無法恢復。', deleteConfirmTip: '請將以下內容從您的註冊電子郵件發送至 ', + account: '帳戶', + myAccount: '我的帳戶', + studio: 'Dify 工作室', }, members: { team: '團隊', From d15ba3939d39c14fd52cd2566bc3bf04cdf5f1f2 Mon Sep 17 00:00:00 2001 From: ice yao Date: Sun, 13 Oct 2024 21:26:05 +0800 Subject: [PATCH 18/25] Add Volcengine VikingDB as new vector provider (#9287) --- api/.env.example | 11 +- api/configs/middleware/__init__.py | 2 + api/configs/middleware/vdb/vikingdb_config.py | 37 +++ api/controllers/console/datasets/datasets.py | 2 + api/core/rag/datasource/vdb/vector_factory.py | 4 + api/core/rag/datasource/vdb/vector_type.py | 1 + .../rag/datasource/vdb/vikingdb/__init__.py | 0 .../vdb/vikingdb/vikingdb_vector.py | 239 ++++++++++++++++++ api/poetry.lock | 73 +++++- api/pyproject.toml | 1 + .../integration_tests/vdb/__mock/vikingdb.py | 215 ++++++++++++++++ .../vdb/vikingdb/__init__.py | 0 .../vdb/vikingdb/test_vikingdb.py | 37 +++ dev/pytest/pytest_vdb.sh | 3 +- docker/docker-compose.yaml | 5 + 15 files changed, 627 insertions(+), 3 deletions(-) create mode 100644 api/configs/middleware/vdb/vikingdb_config.py create mode 100644 api/core/rag/datasource/vdb/vikingdb/__init__.py create mode 100644 api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py create mode 100644 api/tests/integration_tests/vdb/__mock/vikingdb.py create mode 100644 api/tests/integration_tests/vdb/vikingdb/__init__.py create mode 100644 
api/tests/integration_tests/vdb/vikingdb/test_vikingdb.py diff --git a/api/.env.example b/api/.env.example index 468130b1628e9e..aa155003abd3f4 100644 --- a/api/.env.example +++ b/api/.env.example @@ -111,7 +111,7 @@ SUPABASE_URL=your-server-url WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* -# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector +# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, vikingdb VECTOR_STORE=weaviate # Weaviate configuration @@ -220,6 +220,15 @@ BAIDU_VECTOR_DB_DATABASE=dify BAIDU_VECTOR_DB_SHARD=1 BAIDU_VECTOR_DB_REPLICAS=3 +# VikingDB configuration +VIKINGDB_ACCESS_KEY=your-ak +VIKINGDB_SECRET_KEY=your-sk +VIKINGDB_REGION=cn-shanghai +VIKINGDB_HOST=api-vikingdb.xxx.volces.com +VIKINGDB_SCHEME=http +VIKINGDB_CONNECTION_TIMEOUT=30 +VIKINGDB_SOCKET_TIMEOUT=30 + # Upload configuration UPLOAD_FILE_SIZE_LIMIT=15 UPLOAD_FILE_BATCH_LIMIT=5 diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 25f3df6dde41d7..fa7f41d630965a 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -28,6 +28,7 @@ from configs.middleware.vdb.relyt_config import RelytConfig from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig +from configs.middleware.vdb.vikingdb_config import VikingDBConfig from configs.middleware.vdb.weaviate_config import WeaviateConfig @@ -243,5 +244,6 @@ class MiddlewareConfig( WeaviateConfig, ElasticsearchConfig, InternalTestConfig, + VikingDBConfig, ): pass diff --git a/api/configs/middleware/vdb/vikingdb_config.py b/api/configs/middleware/vdb/vikingdb_config.py new file mode 100644 index 00000000000000..5ad98d898a16e3 --- /dev/null
+++ b/api/configs/middleware/vdb/vikingdb_config.py @@ -0,0 +1,37 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class VikingDBConfig(BaseModel): + """ + Configuration for connecting to Volcengine VikingDB. + Refer to the following documentation for details on obtaining credentials: + https://www.volcengine.com/docs/6291/65568 + """ + + VIKINGDB_ACCESS_KEY: Optional[str] = Field( + default=None, description="The Access Key provided by Volcengine VikingDB for API authentication." + ) + VIKINGDB_SECRET_KEY: Optional[str] = Field( + default=None, description="The Secret Key provided by Volcengine VikingDB for API authentication." + ) + VIKINGDB_REGION: Optional[str] = Field( + default="cn-shanghai", + description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').", + ) + VIKINGDB_HOST: Optional[str] = Field( + default="api-vikingdb.mlp.cn-shanghai.volces.com", + description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \ + 'api-vikingdb.mlp.cn-shanghai.volces.com')", + ) + VIKINGDB_SCHEME: Optional[str] = Field( + default="http", + description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').", + ) + VIKINGDB_CONNECTION_TIMEOUT: Optional[int] = Field( + default=30, description="The connection timeout of the Volcengine VikingDB service." + ) + VIKINGDB_SOCKET_TIMEOUT: Optional[int] = Field( + default=30, description="The socket timeout of the Volcengine VikingDB service." 
+ ) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 102089bf071ac2..6583356d23c9eb 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -618,6 +618,7 @@ def get(self): | VectorType.TENCENT | VectorType.PGVECTO_RS | VectorType.BAIDU + | VectorType.VIKINGDB ): return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} case ( @@ -655,6 +656,7 @@ def get(self, vector_type): | VectorType.TENCENT | VectorType.PGVECTO_RS | VectorType.BAIDU + | VectorType.VIKINGDB ): return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} case ( diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 1f4a4d44a23eea..873b2890277d63 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -107,6 +107,10 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]: from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory return BaiduVectorFactory + case VectorType.VIKINGDB: + from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory + + return VikingDBVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 996ff48615c901..b4d604a080899b 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -17,3 +17,4 @@ class VectorType(str, Enum): ORACLE = "oracle" ELASTICSEARCH = "elasticsearch" BAIDU = "baidu" + VIKINGDB = "vikingdb" diff --git a/api/core/rag/datasource/vdb/vikingdb/__init__.py b/api/core/rag/datasource/vdb/vikingdb/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py 
b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py new file mode 100644 index 00000000000000..22d0e92586fcb7 --- /dev/null +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -0,0 +1,239 @@ +import json +from typing import Any + +from pydantic import BaseModel +from volcengine.viking_db import ( + Data, + DistanceType, + Field, + FieldType, + IndexType, + QuantType, + VectorIndexParams, + VikingDBService, +) + +from configs import dify_config +from core.rag.datasource.entity.embedding import Embeddings +from core.rag.datasource.vdb.field import Field as vdb_Field +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset + + +class VikingDBConfig(BaseModel): + access_key: str + secret_key: str + host: str + region: str + scheme: str + connection_timeout: int + socket_timeout: int + index_type: str = IndexType.HNSW + distance: str = DistanceType.L2 + quant: str = QuantType.Float + + +class VikingDBVector(BaseVector): + def __init__(self, collection_name: str, group_id: str, config: VikingDBConfig): + super().__init__(collection_name) + self._group_id = group_id + self._client_config = config + self._index_name = f"{self._collection_name}_idx" + self._client = VikingDBService( + host=config.host, + region=config.region, + scheme=config.scheme, + connection_timeout=config.connection_timeout, + socket_timeout=config.socket_timeout, + ak=config.access_key, + sk=config.secret_key, + ) + + def _has_collection(self) -> bool: + try: + self._client.get_collection(self._collection_name) + except Exception: + return False + return True + + def _has_index(self) -> bool: + try: + self._client.get_index(self._collection_name, self._index_name) + except Exception: + return False + return True + + def 
_create_collection(self, dimension: int): + lock_name = f"vector_indexing_lock_{self._collection_name}" + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" + if redis_client.get(collection_exist_cache_key): + return + + if not self._has_collection(): + fields = [ + Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True), + Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String), + Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String), + Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text), + Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=dimension), + ] + + self._client.create_collection( + collection_name=self._collection_name, + fields=fields, + description="Collection For Dify", + ) + + if not self._has_index(): + vector_index = VectorIndexParams( + distance=self._client_config.distance, + index_type=self._client_config.index_type, + quant=self._client_config.quant, + ) + + self._client.create_index( + collection_name=self._collection_name, + index_name=self._index_name, + vector_index=vector_index, + partition_by=vdb_Field.GROUP_KEY.value, + description="Index For Dify", + ) + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + def get_type(self) -> str: + return VectorType.VIKINGDB + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + dimension = len(embeddings[0]) + self._create_collection(dimension) + self.add_texts(texts, embeddings, **kwargs) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + page_contents = [doc.page_content for doc in documents] + metadatas = [doc.metadata for doc in documents] + docs = [] + + for i, page_content in enumerate(page_contents): + metadata = {} + if metadatas is not None: + for key, val in metadatas[i].items(): + metadata[key] = val + doc = Data( 
+ { + vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], + vdb_Field.VECTOR.value: embeddings[i] if embeddings else None, + vdb_Field.CONTENT_KEY.value: page_content, + vdb_Field.METADATA_KEY.value: json.dumps(metadata), + vdb_Field.GROUP_KEY.value: self._group_id, + } + ) + docs.append(doc) + + self._client.get_collection(self._collection_name).upsert_data(docs) + + def text_exists(self, id: str) -> bool: + docs = self._client.get_collection(self._collection_name).fetch_data(id) + not_exists_str = "data does not exist" + if docs is not None and not_exists_str not in docs.fields.get("message", ""): + return True + return False + + def delete_by_ids(self, ids: list[str]) -> None: + self._client.get_collection(self._collection_name).delete_data(ids) + + def get_ids_by_metadata_field(self, key: str, value: str): + # Note: Metadata field value is an dict, but vikingdb field + # not support json type + results = self._client.get_index(self._collection_name, self._index_name).search( + filter={"op": "must", "field": vdb_Field.GROUP_KEY.value, "conds": [self._group_id]}, + # max value is 5000 + limit=5000, + ) + + if not results: + return [] + + ids = [] + for result in results: + metadata = result.fields.get(vdb_Field.METADATA_KEY.value) + if metadata is not None: + metadata = json.loads(metadata) + if metadata.get(key) == value: + ids.append(result.id) + return ids + + def delete_by_metadata_field(self, key: str, value: str) -> None: + ids = self.get_ids_by_metadata_field(key, value) + self.delete_by_ids(ids) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + results = self._client.get_index(self._collection_name, self._index_name).search_by_vector( + query_vector, limit=kwargs.get("top_k", 50) + ) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + return self._get_search_res(results, score_threshold) + + def _get_search_res(self, results, score_threshold): + if len(results) == 0: + return [] + + docs = [] + 
for result in results: + metadata = result.fields.get(vdb_Field.METADATA_KEY.value) + if metadata is not None: + metadata = json.loads(metadata) + if result.score > score_threshold: + metadata["score"] = result.score + doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata) + docs.append(doc) + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + return [] + + def delete(self) -> None: + if self._has_index(): + self._client.drop_index(self._collection_name, self._index_name) + if self._has_collection(): + self._client.drop_collection(self._collection_name) + + +class VikingDBVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> VikingDBVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix.lower() + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.VIKINGDB, collection_name)) + + if dify_config.VIKINGDB_ACCESS_KEY is None: + raise ValueError("VIKINGDB_ACCESS_KEY should not be None") + if dify_config.VIKINGDB_SECRET_KEY is None: + raise ValueError("VIKINGDB_SECRET_KEY should not be None") + if dify_config.VIKINGDB_HOST is None: + raise ValueError("VIKINGDB_HOST should not be None") + if dify_config.VIKINGDB_REGION is None: + raise ValueError("VIKINGDB_REGION should not be None") + if dify_config.VIKINGDB_SCHEME is None: + raise ValueError("VIKINGDB_SCHEME should not be None") + return VikingDBVector( + collection_name=collection_name, + group_id=dataset.id, + config=VikingDBConfig( + access_key=dify_config.VIKINGDB_ACCESS_KEY, + secret_key=dify_config.VIKINGDB_SECRET_KEY, + host=dify_config.VIKINGDB_HOST, + 
region=dify_config.VIKINGDB_REGION, + scheme=dify_config.VIKINGDB_SCHEME, + connection_timeout=dify_config.VIKINGDB_CONNECTION_TIMEOUT, + socket_timeout=dify_config.VIKINGDB_SOCKET_TIMEOUT, + ), + ) diff --git a/api/poetry.lock b/api/poetry.lock index 6565db27ad5725..efefedfb21cd2a 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -2038,6 +2038,17 @@ packaging = ">=17.0" pandas = ">=0.24.2" pyarrow = ">=3.0.0" +[[package]] +name = "decorator" +version = "5.1.1" +description = "Decorators for Humans" +optional = false +python-versions = ">=3.5" +files = [ + {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, + {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, +] + [[package]] name = "defusedxml" version = "0.7.1" @@ -3027,6 +3038,20 @@ files = [ docs = ["sphinx (>=4)", "sphinx-rtd-theme (>=1)"] tests = ["cython", "hypothesis", "mpmath", "pytest", "setuptools"] +[[package]] +name = "google" +version = "3.0.0" +description = "Python bindings to the Google search engine." 
+optional = false +python-versions = "*" +files = [ + {file = "google-3.0.0-py2.py3-none-any.whl", hash = "sha256:889cf695f84e4ae2c55fbc0cfdaf4c1e729417fa52ab1db0485202ba173e4935"}, + {file = "google-3.0.0.tar.gz", hash = "sha256:143530122ee5130509ad5e989f0512f7cb218b2d4eddbafbad40fd10e8d8ccbe"}, +] + +[package.dependencies] +beautifulsoup4 = "*" + [[package]] name = "google-ai-generativelanguage" version = "0.6.9" @@ -6670,6 +6695,17 @@ files = [ {file = "psycopg2_binary-2.9.9-cp39-cp39-win_amd64.whl", hash = "sha256:f7ae5d65ccfbebdfa761585228eb4d0df3a8b15cfb53bd953e713e09fbb12957"}, ] +[[package]] +name = "py" +version = "1.11.0" +description = "library with cross-python path, ini-parsing, io, code, log facilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, + {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, +] + [[package]] name = "py-cpuinfo" version = "9.0.0" @@ -8012,6 +8048,21 @@ files = [ [package.dependencies] requests = "2.31.0" +[[package]] +name = "retry" +version = "0.9.2" +description = "Easy to use retry decorator." 
+optional = false +python-versions = "*" +files = [ + {file = "retry-0.9.2-py2.py3-none-any.whl", hash = "sha256:ccddf89761fa2c726ab29391837d4327f819ea14d244c232a1d24c67a2f98606"}, + {file = "retry-0.9.2.tar.gz", hash = "sha256:f8bfa8b99b69c4506d6f5bd3b0aabf77f98cdb17f3c9fc3f5ca820033336fba4"}, +] + +[package.dependencies] +decorator = ">=3.4.2" +py = ">=1.4.26,<2.0.0" + [[package]] name = "rich" version = "13.9.2" @@ -9829,6 +9880,26 @@ files = [ {file = "vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0"}, ] +[[package]] +name = "volcengine-compat" +version = "1.0.156" +description = "Be Compatible with the Volcengine SDK for Python, The version of package dependencies has been modified. like pycryptodome, pytz." +optional = false +python-versions = "*" +files = [ + {file = "volcengine_compat-1.0.156-py3-none-any.whl", hash = "sha256:4abc149a7601ebad8fa2d28fab50c7945145cf74daecb71bca797b0bdc82c5a5"}, + {file = "volcengine_compat-1.0.156.tar.gz", hash = "sha256:e357d096828e31a202dc6047bbc5bf6fff3f54a98cd35a99ab5f965ea741a267"}, +] + +[package.dependencies] +google = ">=3.0.0" +protobuf = ">=3.18.3" +pycryptodome = ">=3.9.9" +pytz = ">=2020.5" +requests = ">=2.25.1" +retry = ">=0.9.2" +six = ">=1.0" + [[package]] name = "volcengine-python-sdk" version = "1.0.103" @@ -10636,4 +10707,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "375ac3a91760513924647e67376cb6018505ec61d967651b254c68af9808d774" +content-hash = "edb5e3b0d50e84a239224cc77f3f615fdbdd6b504bce5b1075b29363f3054957" diff --git a/api/pyproject.toml b/api/pyproject.toml index 594517771b34f2..dff74750f0f558 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -246,6 +246,7 @@ pymochow = "1.3.1" qdrant-client = "1.7.3" tcvectordb = "1.3.2" tidb-vector = "0.0.9" +volcengine-compat = "~1.0.156" weaviate-client = "~3.21.0" ############################################################ diff --git 
a/api/tests/integration_tests/vdb/__mock/vikingdb.py b/api/tests/integration_tests/vdb/__mock/vikingdb.py new file mode 100644 index 00000000000000..0f40337feba6ee --- /dev/null +++ b/api/tests/integration_tests/vdb/__mock/vikingdb.py @@ -0,0 +1,215 @@ +import os +from typing import Union +from unittest.mock import MagicMock + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from volcengine.viking_db import ( + Collection, + Data, + DistanceType, + Field, + FieldType, + Index, + IndexType, + QuantType, + VectorIndexParams, + VikingDBService, +) + +from core.rag.datasource.vdb.field import Field as vdb_Field + + +class MockVikingDBClass: + def __init__( + self, + host="api-vikingdb.volces.com", + region="cn-north-1", + ak="", + sk="", + scheme="http", + connection_timeout=30, + socket_timeout=30, + proxy=None, + ): + self._viking_db_service = MagicMock() + self._viking_db_service.get_exception = MagicMock(return_value='{"data": {"primary_key": "test_id"}}') + + def get_collection(self, collection_name) -> Collection: + return Collection( + collection_name=collection_name, + description="Collection For Dify", + viking_db_service=self._viking_db_service, + primary_key=vdb_Field.PRIMARY_KEY.value, + fields=[ + Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True), + Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String), + Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String), + Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text), + Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=768), + ], + indexes=[ + Index( + collection_name=collection_name, + index_name=f"{collection_name}_idx", + vector_index=VectorIndexParams( + distance=DistanceType.L2, + index_type=IndexType.HNSW, + quant=QuantType.Float, + ), + scalar_index=None, + stat=None, + viking_db_service=self._viking_db_service, + ) + ], + ) + + def drop_collection(self, 
collection_name): + assert collection_name != "" + + def create_collection(self, collection_name, fields, description="") -> Collection: + return Collection( + collection_name=collection_name, + description=description, + primary_key=vdb_Field.PRIMARY_KEY.value, + viking_db_service=self._viking_db_service, + fields=fields, + ) + + def get_index(self, collection_name, index_name) -> Index: + return Index( + collection_name=collection_name, + index_name=index_name, + viking_db_service=self._viking_db_service, + stat=None, + scalar_index=None, + vector_index=VectorIndexParams( + distance=DistanceType.L2, + index_type=IndexType.HNSW, + quant=QuantType.Float, + ), + ) + + def create_index( + self, + collection_name, + index_name, + vector_index=None, + cpu_quota=2, + description="", + partition_by="", + scalar_index=None, + shard_count=None, + shard_policy=None, + ): + return Index( + collection_name=collection_name, + index_name=index_name, + vector_index=vector_index, + cpu_quota=cpu_quota, + description=description, + partition_by=partition_by, + scalar_index=scalar_index, + shard_count=shard_count, + shard_policy=shard_policy, + viking_db_service=self._viking_db_service, + stat=None, + ) + + def drop_index(self, collection_name, index_name): + assert collection_name != "" + assert index_name != "" + + def upsert_data(self, data: Union[Data, list[Data]]): + assert data is not None + + def fetch_data(self, id: Union[str, list[str], int, list[int]]): + return Data( + fields={ + vdb_Field.GROUP_KEY.value: "test_group", + vdb_Field.METADATA_KEY.value: "{}", + vdb_Field.CONTENT_KEY.value: "content", + vdb_Field.PRIMARY_KEY.value: id, + vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398], + }, + id=id, + ) + + def delete_data(self, id: Union[str, list[str], int, list[int]]): + assert id is not None + + def search_by_vector( + self, + vector, + sparse_vectors=None, + filter=None, + limit=10, + output_fields=None, + partition="default", 
+ dense_weight=None, + ) -> list[Data]: + return [ + Data( + fields={ + vdb_Field.GROUP_KEY.value: "test_group", + vdb_Field.METADATA_KEY.value: '\ + {"source": "/var/folders/ml/xxx/xxx.txt", \ + "document_id": "test_document_id", \ + "dataset_id": "test_dataset_id", \ + "doc_id": "test_id", \ + "doc_hash": "test_hash"}', + vdb_Field.CONTENT_KEY.value: "content", + vdb_Field.PRIMARY_KEY.value: "test_id", + vdb_Field.VECTOR.value: vector, + }, + id="test_id", + score=0.10, + ) + ] + + def search( + self, order=None, filter=None, limit=10, output_fields=None, partition="default", dense_weight=None + ) -> list[Data]: + return [ + Data( + fields={ + vdb_Field.GROUP_KEY.value: "test_group", + vdb_Field.METADATA_KEY.value: '\ + {"source": "/var/folders/ml/xxx/xxx.txt", \ + "document_id": "test_document_id", \ + "dataset_id": "test_dataset_id", \ + "doc_id": "test_id", \ + "doc_hash": "test_hash"}', + vdb_Field.CONTENT_KEY.value: "content", + vdb_Field.PRIMARY_KEY.value: "test_id", + vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398], + }, + id="test_id", + score=0.10, + ) + ] + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_vikingdb_mock(monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(VikingDBService, "__init__", MockVikingDBClass.__init__) + monkeypatch.setattr(VikingDBService, "get_collection", MockVikingDBClass.get_collection) + monkeypatch.setattr(VikingDBService, "create_collection", MockVikingDBClass.create_collection) + monkeypatch.setattr(VikingDBService, "drop_collection", MockVikingDBClass.drop_collection) + monkeypatch.setattr(VikingDBService, "get_index", MockVikingDBClass.get_index) + monkeypatch.setattr(VikingDBService, "create_index", MockVikingDBClass.create_index) + monkeypatch.setattr(VikingDBService, "drop_index", MockVikingDBClass.drop_index) + monkeypatch.setattr(Collection, "upsert_data", MockVikingDBClass.upsert_data) + monkeypatch.setattr(Collection, 
"fetch_data", MockVikingDBClass.fetch_data) + monkeypatch.setattr(Collection, "delete_data", MockVikingDBClass.delete_data) + monkeypatch.setattr(Index, "search_by_vector", MockVikingDBClass.search_by_vector) + monkeypatch.setattr(Index, "search", MockVikingDBClass.search) + + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/integration_tests/vdb/vikingdb/__init__.py b/api/tests/integration_tests/vdb/vikingdb/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/integration_tests/vdb/vikingdb/test_vikingdb.py b/api/tests/integration_tests/vdb/vikingdb/test_vikingdb.py new file mode 100644 index 00000000000000..2572012ea03aa1 --- /dev/null +++ b/api/tests/integration_tests/vdb/vikingdb/test_vikingdb.py @@ -0,0 +1,37 @@ +from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBConfig, VikingDBVector +from tests.integration_tests.vdb.__mock.vikingdb import setup_vikingdb_mock +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis + + +class VikingDBVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = VikingDBVector( + "test_collection", + "test_group", + config=VikingDBConfig( + access_key="test_access_key", + host="test_host", + region="test_region", + scheme="test_scheme", + secret_key="test_secret_key", + connection_timeout=30, + socket_timeout=30, + ), + ) + + def search_by_vector(self): + hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding) + assert len(hits_by_vector) == 1 + + def search_by_full_text(self): + hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key="document_id", value="test_document_id") + assert len(ids) > 0 + + +def test_vikingdb_vector(setup_mock_redis, setup_vikingdb_mock): + 
VikingDBVectorTest().run_all_tests() diff --git a/dev/pytest/pytest_vdb.sh b/dev/pytest/pytest_vdb.sh index bad809cbfdb90a..6809ef7c6f74c2 100755 --- a/dev/pytest/pytest_vdb.sh +++ b/dev/pytest/pytest_vdb.sh @@ -7,4 +7,5 @@ pytest api/tests/integration_tests/vdb/chroma \ api/tests/integration_tests/vdb/pgvector \ api/tests/integration_tests/vdb/qdrant \ api/tests/integration_tests/vdb/weaviate \ - api/tests/integration_tests/vdb/elasticsearch \ No newline at end of file + api/tests/integration_tests/vdb/elasticsearch \ + api/tests/integration_tests/vdb/vikingdb diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index cb6ecba2791db1..5db11d1961ce00 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -173,6 +173,11 @@ x-shared-env: &shared-api-worker-env BAIDU_VECTOR_DB_DATABASE: ${BAIDU_VECTOR_DB_DATABASE:-dify} BAIDU_VECTOR_DB_SHARD: ${BAIDU_VECTOR_DB_SHARD:-1} BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3} + VIKINGDB_ACCESS_KEY: ${VIKINGDB_ACCESS_KEY:-dify} + VIKINGDB_SECRET_KEY: ${VIKINGDB_SECRET_KEY:-dify} + VIKINGDB_REGION: ${VIKINGDB_REGION:-cn-shanghai} + VIKINGDB_HOST: ${VIKINGDB_HOST:-api-vikingdb.xxx.volces.com} + VIKINGDB_SCHEMA: ${VIKINGDB_SCHEMA:-http} UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} ETL_TYPE: ${ETL_TYPE:-dify} From 857055b797b88fd0fc01a4f8423acd3c399454d5 Mon Sep 17 00:00:00 2001 From: kurokobo Date: Mon, 14 Oct 2024 00:25:50 +0900 Subject: [PATCH 19/25] fix: remove the latest message from the user that does not have any answer yet (#9297) --- api/core/memory/token_buffer_memory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 72da3b0c6f251d..bc94912c1eb846 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -60,8 +60,8 @@ def get_history_prompt_messages( thread_messages 
= extract_thread_messages(messages) # for newly created message, its answer is temporarily empty, we don't need to add it to memory - if thread_messages and not thread_messages[-1].answer: - thread_messages.pop() + if thread_messages and not thread_messages[0].answer: + thread_messages.pop(0) messages = list(reversed(thread_messages)) From ffc3f3367058b48dab4a801d3031251339811f1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Mon, 14 Oct 2024 10:53:45 +0800 Subject: [PATCH 20/25] chore: remove the copied zhipu_ai sdk (#9270) --- .../model_providers/zhipuai/llm/llm.py | 7 +- .../zhipuai/text_embedding/text_embedding.py | 3 +- .../zhipuai/zhipuai_sdk/__init__.py | 15 - .../zhipuai/zhipuai_sdk/__version__.py | 1 - .../zhipuai/zhipuai_sdk/_client.py | 82 -- .../zhipuai_sdk/api_resource/__init__.py | 34 - .../api_resource/assistant/__init__.py | 3 - .../api_resource/assistant/assistant.py | 122 --- .../zhipuai_sdk/api_resource/batches.py | 146 --- .../zhipuai_sdk/api_resource/chat/__init__.py | 5 - .../api_resource/chat/async_completions.py | 115 --- .../zhipuai_sdk/api_resource/chat/chat.py | 18 - .../api_resource/chat/completions.py | 108 --- .../zhipuai_sdk/api_resource/embeddings.py | 50 - .../zhipuai/zhipuai_sdk/api_resource/files.py | 194 ---- .../api_resource/fine_tuning/__init__.py | 5 - .../api_resource/fine_tuning/fine_tuning.py | 18 - .../api_resource/fine_tuning/jobs/__init__.py | 3 - .../api_resource/fine_tuning/jobs/jobs.py | 152 --- .../fine_tuning/models/__init__.py | 3 - .../fine_tuning/models/fine_tuned_models.py | 41 - .../zhipuai_sdk/api_resource/images.py | 59 -- .../api_resource/knowledge/__init__.py | 3 - .../knowledge/document/__init__.py | 3 - .../knowledge/document/document.py | 217 ----- .../api_resource/knowledge/knowledge.py | 173 ---- .../api_resource/tools/__init__.py | 3 - .../zhipuai_sdk/api_resource/tools/tools.py | 65 -- .../api_resource/videos/__init__.py | 7 - 
.../zhipuai_sdk/api_resource/videos/videos.py | 77 -- .../zhipuai/zhipuai_sdk/core/__init__.py | 108 --- .../zhipuai/zhipuai_sdk/core/_base_api.py | 19 - .../zhipuai/zhipuai_sdk/core/_base_compat.py | 209 ---- .../zhipuai/zhipuai_sdk/core/_base_models.py | 670 ------------- .../zhipuai/zhipuai_sdk/core/_base_type.py | 170 ---- .../zhipuai/zhipuai_sdk/core/_constants.py | 12 - .../zhipuai/zhipuai_sdk/core/_errors.py | 86 -- .../zhipuai/zhipuai_sdk/core/_files.py | 75 -- .../zhipuai/zhipuai_sdk/core/_http_client.py | 910 ------------------ .../zhipuai/zhipuai_sdk/core/_jwt_token.py | 31 - .../core/_legacy_binary_response.py | 207 ---- .../zhipuai_sdk/core/_legacy_response.py | 341 ------- .../zhipuai/zhipuai_sdk/core/_request_opt.py | 97 -- .../zhipuai/zhipuai_sdk/core/_response.py | 398 -------- .../zhipuai/zhipuai_sdk/core/_sse_client.py | 206 ---- .../zhipuai_sdk/core/_utils/__init__.py | 52 - .../zhipuai_sdk/core/_utils/_transform.py | 383 -------- .../zhipuai_sdk/core/_utils/_typing.py | 122 --- .../zhipuai/zhipuai_sdk/core/_utils/_utils.py | 409 -------- .../zhipuai/zhipuai_sdk/core/logs.py | 78 -- .../zhipuai/zhipuai_sdk/core/pagination.py | 62 -- .../zhipuai/zhipuai_sdk/types/__init__.py | 0 .../zhipuai_sdk/types/assistant/__init__.py | 5 - .../types/assistant/assistant_completion.py | 40 - .../assistant_conversation_params.py | 7 - .../assistant/assistant_conversation_resp.py | 29 - .../assistant/assistant_create_params.py | 32 - .../types/assistant/assistant_support_resp.py | 21 - .../types/assistant/message/__init__.py | 3 - .../assistant/message/message_content.py | 13 - .../assistant/message/text_content_block.py | 14 - .../tools/code_interpreter_delta_block.py | 27 - .../message/tools/drawing_tool_delta_block.py | 21 - .../message/tools/function_delta_block.py | 22 - .../message/tools/retrieval_delta_black.py | 41 - .../assistant/message/tools/tools_type.py | 16 - .../message/tools/web_browser_delta_block.py | 48 - 
.../assistant/message/tools_delta_block.py | 16 - .../zhipuai/zhipuai_sdk/types/batch.py | 82 -- .../zhipuai_sdk/types/batch_create_params.py | 37 - .../zhipuai/zhipuai_sdk/types/batch_error.py | 21 - .../zhipuai_sdk/types/batch_list_params.py | 20 - .../zhipuai_sdk/types/batch_request_counts.py | 14 - .../zhipuai_sdk/types/chat/__init__.py | 0 .../types/chat/async_chat_completion.py | 22 - .../zhipuai_sdk/types/chat/chat_completion.py | 43 - .../types/chat/chat_completion_chunk.py | 57 -- .../chat/chat_completions_create_param.py | 8 - .../types/chat/code_geex/code_geex_params.py | 146 --- .../zhipuai/zhipuai_sdk/types/embeddings.py | 21 - .../zhipuai_sdk/types/files/__init__.py | 5 - .../types/files/file_create_params.py | 38 - .../zhipuai_sdk/types/files/file_deleted.py | 13 - .../zhipuai_sdk/types/files/file_object.py | 22 - .../zhipuai_sdk/types/files/upload_detail.py | 13 - .../zhipuai_sdk/types/fine_tuning/__init__.py | 4 - .../types/fine_tuning/fine_tuning_job.py | 51 - .../fine_tuning/fine_tuning_job_event.py | 35 - .../types/fine_tuning/job_create_params.py | 15 - .../types/fine_tuning/models/__init__.py | 1 - .../fine_tuning/models/fine_tuned_models.py | 13 - .../zhipuai/zhipuai_sdk/types/image.py | 18 - .../zhipuai_sdk/types/knowledge/__init__.py | 8 - .../types/knowledge/document/__init__.py | 8 - .../types/knowledge/document/document.py | 51 - .../document/document_edit_params.py | 29 - .../document/document_list_params.py | 26 - .../knowledge/document/document_list_resp.py | 11 - .../zhipuai_sdk/types/knowledge/knowledge.py | 21 - .../knowledge/knowledge_create_params.py | 30 - .../types/knowledge/knowledge_list_params.py | 15 - .../types/knowledge/knowledge_list_resp.py | 11 - .../types/knowledge/knowledge_used.py | 21 - .../types/sensitive_word_check/__init__.py | 3 - .../sensitive_word_check.py | 14 - .../zhipuai_sdk/types/tools/__init__.py | 9 - .../types/tools/tools_web_search_params.py | 35 - .../zhipuai_sdk/types/tools/web_search.py | 71 -- 
.../types/tools/web_search_chunk.py | 33 - .../zhipuai_sdk/types/video/__init__.py | 3 - .../types/video/video_create_params.py | 27 - .../zhipuai_sdk/types/video/video_object.py | 30 - .../builtin/cogview/tools/cogview3.py | 3 +- api/poetry.lock | 36 +- api/pyproject.toml | 2 +- 115 files changed, 19 insertions(+), 7909 deletions(-) delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/assistant/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/assistant/assistant.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/batches.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py delete mode 100644 
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs/jobs.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/models/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/models/fine_tuned_models.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/document/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/document/document.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/knowledge.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/tools/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/tools/tools.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/videos/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/videos/videos.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_compat.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_models.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py delete mode 100644 
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_constants.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_legacy_binary_response.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_legacy_response.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_transform.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_typing.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_utils.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/logs.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/pagination.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_completion.py delete mode 100644 
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_conversation_params.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_conversation_resp.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_create_params.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_support_resp.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/message_content.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/text_content_block.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/code_interpreter_delta_block.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/drawing_tool_delta_block.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/function_delta_block.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/retrieval_delta_black.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/tools_type.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/web_browser_delta_block.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools_delta_block.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_create_params.py delete mode 100644 
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_error.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_list_params.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_request_counts.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completions_create_param.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/code_geex/code_geex_params.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_create_params.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_deleted.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_object.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/upload_detail.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py delete mode 100644 
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/models/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/models/fine_tuned_models.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_edit_params.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_list_params.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_list_resp.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_create_params.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_list_params.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_list_resp.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_used.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/sensitive_word_check/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/sensitive_word_check/sensitive_word_check.py delete mode 100644 
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/tools_web_search_params.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/web_search.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/web_search_chunk.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/__init__.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/video_create_params.py delete mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/video_object.py diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index e0c49805230759..43bffad2a0bc34 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -1,6 +1,10 @@ from collections.abc import Generator from typing import Optional, Union +from zhipuai import ZhipuAI +from zhipuai.types.chat.chat_completion import Completion +from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -16,9 +20,6 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI -from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI -from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion -from 
core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk from core.model_runtime.utils import helper GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object. diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 14a529dddf82d1..5a34a3d5939974 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -1,13 +1,14 @@ import time from typing import Optional +from zhipuai import ZhipuAI + from core.embedding.embedding_constant import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI -from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py deleted file mode 100644 index fc71d64714bd96..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from .__version__ import __version__ -from ._client import ZhipuAI -from .core import ( - APIAuthenticationError, - APIConnectionError, - APIInternalError, - APIReachLimitError, - APIRequestFailedError, - APIResponseError, - APIResponseValidationError, - APIServerFlowExceedError, - 
APIStatusError, - APITimeoutError, - ZhipuAIError, -) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py deleted file mode 100644 index 51f8c49ecb827d..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "v2.1.0" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py deleted file mode 100644 index 705d371e628f08..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py +++ /dev/null @@ -1,82 +0,0 @@ -from __future__ import annotations - -import os -from collections.abc import Mapping -from typing import Union - -import httpx -from httpx import Timeout -from typing_extensions import override - -from . import api_resource -from .core import NOT_GIVEN, ZHIPUAI_DEFAULT_MAX_RETRIES, HttpClient, NotGiven, ZhipuAIError, _jwt_token - - -class ZhipuAI(HttpClient): - chat: api_resource.chat.Chat - api_key: str - _disable_token_cache: bool = True - - def __init__( - self, - *, - api_key: str | None = None, - base_url: str | httpx.URL | None = None, - timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, - max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, - http_client: httpx.Client | None = None, - custom_headers: Mapping[str, str] | None = None, - disable_token_cache: bool = True, - _strict_response_validation: bool = False, - ) -> None: - if api_key is None: - api_key = os.environ.get("ZHIPUAI_API_KEY") - if api_key is None: - raise ZhipuAIError("未提供api_key,请通过参数或环境变量提供") - self.api_key = api_key - self._disable_token_cache = disable_token_cache - - if base_url is None: - base_url = os.environ.get("ZHIPUAI_BASE_URL") - if base_url is None: - base_url = "https://open.bigmodel.cn/api/paas/v4" - from .__version__ import __version__ - - 
super().__init__( - version=__version__, - base_url=base_url, - max_retries=max_retries, - timeout=timeout, - custom_httpx_client=http_client, - custom_headers=custom_headers, - _strict_response_validation=_strict_response_validation, - ) - self.chat = api_resource.chat.Chat(self) - self.images = api_resource.images.Images(self) - self.embeddings = api_resource.embeddings.Embeddings(self) - self.files = api_resource.files.Files(self) - self.fine_tuning = api_resource.fine_tuning.FineTuning(self) - self.batches = api_resource.Batches(self) - self.knowledge = api_resource.Knowledge(self) - self.tools = api_resource.Tools(self) - self.videos = api_resource.Videos(self) - self.assistant = api_resource.Assistant(self) - - @property - @override - def auth_headers(self) -> dict[str, str]: - api_key = self.api_key - if self._disable_token_cache: - return {"Authorization": f"Bearer {api_key}"} - else: - return {"Authorization": f"Bearer {_jwt_token.generate_token(api_key)}"} - - def __del__(self) -> None: - if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close") or not hasattr(self, "_client"): - # if the '__init__' method raised an error, self would not have client attr - return - - if self._has_custom_http_client: - return - - self.close() diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py deleted file mode 100644 index 4fe0719dde3e0b..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -from .assistant import ( - Assistant, -) -from .batches import Batches -from .chat import ( - AsyncCompletions, - Chat, - Completions, -) -from .embeddings import Embeddings -from .files import Files, FilesWithRawResponse -from .fine_tuning import FineTuning -from .images import Images -from .knowledge import Knowledge -from .tools import Tools -from 
.videos import ( - Videos, -) - -__all__ = [ - "Videos", - "AsyncCompletions", - "Chat", - "Completions", - "Images", - "Embeddings", - "Files", - "FilesWithRawResponse", - "FineTuning", - "Batches", - "Knowledge", - "Tools", - "Assistant", -] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/assistant/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/assistant/__init__.py deleted file mode 100644 index ce619aa7f09222..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/assistant/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .assistant import Assistant - -__all__ = ["Assistant"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/assistant/assistant.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/assistant/assistant.py deleted file mode 100644 index c29b05749847cc..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/assistant/assistant.py +++ /dev/null @@ -1,122 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional - -import httpx - -from ...core import ( - NOT_GIVEN, - BaseAPI, - Body, - Headers, - NotGiven, - StreamResponse, - deepcopy_minimal, - make_request_options, - maybe_transform, -) -from ...types.assistant import AssistantCompletion -from ...types.assistant.assistant_conversation_resp import ConversationUsageListResp -from ...types.assistant.assistant_support_resp import AssistantSupportResp - -if TYPE_CHECKING: - from ..._client import ZhipuAI - -from ...types.assistant import assistant_conversation_params, assistant_create_params - -__all__ = ["Assistant"] - - -class Assistant(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - def conversation( - self, - assistant_id: str, - model: str, - messages: list[assistant_create_params.ConversationMessage], - *, - stream: bool 
= True, - conversation_id: Optional[str] = None, - attachments: Optional[list[assistant_create_params.AssistantAttachments]] = None, - metadata: dict | None = None, - request_id: Optional[str] = None, - user_id: Optional[str] = None, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> StreamResponse[AssistantCompletion]: - body = deepcopy_minimal( - { - "assistant_id": assistant_id, - "model": model, - "messages": messages, - "stream": stream, - "conversation_id": conversation_id, - "attachments": attachments, - "metadata": metadata, - "request_id": request_id, - "user_id": user_id, - } - ) - return self._post( - "/assistant", - body=maybe_transform(body, assistant_create_params.AssistantParameters), - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=AssistantCompletion, - stream=stream or True, - stream_cls=StreamResponse[AssistantCompletion], - ) - - def query_support( - self, - *, - assistant_id_list: Optional[list[str]] = None, - request_id: Optional[str] = None, - user_id: Optional[str] = None, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> AssistantSupportResp: - body = deepcopy_minimal( - { - "assistant_id_list": assistant_id_list, - "request_id": request_id, - "user_id": user_id, - } - ) - return self._post( - "/assistant/list", - body=body, - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=AssistantSupportResp, - ) - - def query_conversation_usage( - self, - assistant_id: str, - page: int = 1, - page_size: int = 10, - *, - request_id: Optional[str] = None, - user_id: Optional[str] = None, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> 
ConversationUsageListResp: - body = deepcopy_minimal( - { - "assistant_id": assistant_id, - "page": page, - "page_size": page_size, - "request_id": request_id, - "user_id": user_id, - } - ) - return self._post( - "/assistant/conversation/list", - body=maybe_transform(body, assistant_conversation_params.ConversationParameters), - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=ConversationUsageListResp, - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/batches.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/batches.py deleted file mode 100644 index ae2f2be85eb9b4..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/batches.py +++ /dev/null @@ -1,146 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Literal, Optional - -import httpx - -from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options, maybe_transform -from ..core.pagination import SyncCursorPage -from ..types import batch_create_params, batch_list_params -from ..types.batch import Batch - -if TYPE_CHECKING: - from .._client import ZhipuAI - - -class Batches(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - def create( - self, - *, - completion_window: str | None = None, - endpoint: Literal["/v1/chat/completions", "/v1/embeddings"], - input_file_id: str, - metadata: Optional[dict[str, str]] | NotGiven = NOT_GIVEN, - auto_delete_input_file: bool = True, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Batch: - return self._post( - "/batches", - body=maybe_transform( - { - "completion_window": completion_window, - "endpoint": endpoint, - "input_file_id": input_file_id, - "metadata": metadata, - "auto_delete_input_file": auto_delete_input_file, - }, - 
batch_create_params.BatchCreateParams, - ), - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=Batch, - ) - - def retrieve( - self, - batch_id: str, - *, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Batch: - """ - Retrieves a batch. - - Args: - extra_headers: Send extra headers - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not batch_id: - raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}") - return self._get( - f"/batches/{batch_id}", - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=Batch, - ) - - def list( - self, - *, - after: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> SyncCursorPage[Batch]: - """List your organization's batches. - - Args: - after: A cursor for use in pagination. - - `after` is an object ID that defines your place - in the list. For instance, if you make a list request and receive 100 objects, - ending with obj_foo, your subsequent call can include after=obj_foo in order to - fetch the next page of the list. - - limit: A limit on the number of objects to be returned. Limit can range between 1 and - 100, and the default is 20. 
- - extra_headers: Send extra headers - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return self._get_api_list( - "/batches", - page=SyncCursorPage[Batch], - options=make_request_options( - extra_headers=extra_headers, - extra_body=extra_body, - timeout=timeout, - query=maybe_transform( - { - "after": after, - "limit": limit, - }, - batch_list_params.BatchListParams, - ), - ), - model=Batch, - ) - - def cancel( - self, - batch_id: str, - *, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Batch: - """ - Cancels an in-progress batch. - - Args: - batch_id: The ID of the batch to cancel. - extra_headers: Send extra headers - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - - """ - if not batch_id: - raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}") - return self._post( - f"/batches/{batch_id}/cancel", - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=Batch, - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py deleted file mode 100644 index 5cd8dc6f339a60..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .async_completions import AsyncCompletions -from .chat import Chat -from .completions import Completions - -__all__ = ["AsyncCompletions", "Chat", "Completions"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py 
b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py deleted file mode 100644 index 05510a3ec421d0..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py +++ /dev/null @@ -1,115 +0,0 @@ -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING, Literal, Optional, Union - -import httpx - -from ...core import ( - NOT_GIVEN, - BaseAPI, - Body, - Headers, - NotGiven, - drop_prefix_image_data, - make_request_options, - maybe_transform, -) -from ...types.chat.async_chat_completion import AsyncCompletion, AsyncTaskStatus -from ...types.chat.code_geex import code_geex_params -from ...types.sensitive_word_check import SensitiveWordCheckRequest - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from ..._client import ZhipuAI - - -class AsyncCompletions(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - def create( - self, - *, - model: str, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - user_id: Optional[str] | NotGiven = NOT_GIVEN, - do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, list[str], list[int], list[list[int]], None], - stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN, - tools: Optional[object] | NotGiven = NOT_GIVEN, - tool_choice: str | NotGiven = NOT_GIVEN, - meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN, - extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> 
AsyncTaskStatus: - _cast_type = AsyncTaskStatus - logger.debug(f"temperature:{temperature}, top_p:{top_p}") - if temperature is not None and temperature != NOT_GIVEN: - if temperature <= 0: - do_sample = False - temperature = 0.01 - # logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间,do_sample重写为:false(参数top_p temperature不生效)") # noqa: E501 - if temperature >= 1: - temperature = 0.99 - # logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间") - if top_p is not None and top_p != NOT_GIVEN: - if top_p >= 1: - top_p = 0.99 - # logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1") - if top_p <= 0: - top_p = 0.01 - # logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1") - - logger.debug(f"temperature:{temperature}, top_p:{top_p}") - if isinstance(messages, list): - for item in messages: - if item.get("content"): - item["content"] = drop_prefix_image_data(item["content"]) - - body = { - "model": model, - "request_id": request_id, - "user_id": user_id, - "temperature": temperature, - "top_p": top_p, - "do_sample": do_sample, - "max_tokens": max_tokens, - "seed": seed, - "messages": messages, - "stop": stop, - "sensitive_word_check": sensitive_word_check, - "tools": tools, - "tool_choice": tool_choice, - "meta": meta, - "extra": maybe_transform(extra, code_geex_params.CodeGeexExtra), - } - return self._post( - "/async/chat/completions", - body=body, - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=_cast_type, - stream=False, - ) - - def retrieve_completion_result( - self, - id: str, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Union[AsyncCompletion, AsyncTaskStatus]: - _cast_type = Union[AsyncCompletion, AsyncTaskStatus] - return self._get( - path=f"/async-result/{id}", - cast_type=_cast_type, - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - ) diff --git 
a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py deleted file mode 100644 index b3cc46566c7bf3..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import TYPE_CHECKING - -from ...core import BaseAPI, cached_property -from .async_completions import AsyncCompletions -from .completions import Completions - -if TYPE_CHECKING: - pass - - -class Chat(BaseAPI): - @cached_property - def completions(self) -> Completions: - return Completions(self._client) - - @cached_property - def asyncCompletions(self) -> AsyncCompletions: # noqa: N802 - return AsyncCompletions(self._client) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py deleted file mode 100644 index 8e5bb454e6ce7e..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING, Literal, Optional, Union - -import httpx - -from ...core import ( - NOT_GIVEN, - BaseAPI, - Body, - Headers, - NotGiven, - StreamResponse, - deepcopy_minimal, - drop_prefix_image_data, - make_request_options, - maybe_transform, -) -from ...types.chat.chat_completion import Completion -from ...types.chat.chat_completion_chunk import ChatCompletionChunk -from ...types.chat.code_geex import code_geex_params -from ...types.sensitive_word_check import SensitiveWordCheckRequest - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from ..._client import ZhipuAI - - -class Completions(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - def create( - self, - *, - model: str, 
- request_id: Optional[str] | NotGiven = NOT_GIVEN, - user_id: Optional[str] | NotGiven = NOT_GIVEN, - do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, list[str], list[int], object, None], - stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN, - tools: Optional[object] | NotGiven = NOT_GIVEN, - tool_choice: str | NotGiven = NOT_GIVEN, - meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN, - extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Completion | StreamResponse[ChatCompletionChunk]: - logger.debug(f"temperature:{temperature}, top_p:{top_p}") - if temperature is not None and temperature != NOT_GIVEN: - if temperature <= 0: - do_sample = False - temperature = 0.01 - # logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间,do_sample重写为:false(参数top_p temperature不生效)") # noqa: E501 - if temperature >= 1: - temperature = 0.99 - # logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间") - if top_p is not None and top_p != NOT_GIVEN: - if top_p >= 1: - top_p = 0.99 - # logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1") - if top_p <= 0: - top_p = 0.01 - # logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1") - - logger.debug(f"temperature:{temperature}, top_p:{top_p}") - if isinstance(messages, list): - for item in messages: - if item.get("content"): - item["content"] = drop_prefix_image_data(item["content"]) - - body = deepcopy_minimal( - { - "model": model, - "request_id": request_id, - "user_id": user_id, 
- "temperature": temperature, - "top_p": top_p, - "do_sample": do_sample, - "max_tokens": max_tokens, - "seed": seed, - "messages": messages, - "stop": stop, - "sensitive_word_check": sensitive_word_check, - "stream": stream, - "tools": tools, - "tool_choice": tool_choice, - "meta": meta, - "extra": maybe_transform(extra, code_geex_params.CodeGeexExtra), - } - ) - return self._post( - "/chat/completions", - body=body, - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=Completion, - stream=stream or False, - stream_cls=StreamResponse[ChatCompletionChunk], - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py deleted file mode 100644 index 4b4baef9421ba6..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional, Union - -import httpx - -from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options -from ..types.embeddings import EmbeddingsResponded - -if TYPE_CHECKING: - from .._client import ZhipuAI - - -class Embeddings(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - def create( - self, - *, - input: Union[str, list[str], list[int], list[list[int]]], - model: Union[str], - dimensions: Union[int] | NotGiven = NOT_GIVEN, - encoding_format: str | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> EmbeddingsResponded: - _cast_type 
= EmbeddingsResponded - if disable_strict_validation: - _cast_type = object - return self._post( - "/embeddings", - body={ - "input": input, - "model": model, - "dimensions": dimensions, - "encoding_format": encoding_format, - "user": user, - "request_id": request_id, - "sensitive_word_check": sensitive_word_check, - }, - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=_cast_type, - stream=False, - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py deleted file mode 100644 index c723f6f66e41cf..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py +++ /dev/null @@ -1,194 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from typing import TYPE_CHECKING, Literal, Optional, cast - -import httpx - -from ..core import ( - NOT_GIVEN, - BaseAPI, - Body, - FileTypes, - Headers, - NotGiven, - _legacy_binary_response, - _legacy_response, - deepcopy_minimal, - extract_files, - make_request_options, - maybe_transform, -) -from ..types.files import FileDeleted, FileObject, ListOfFileObject, UploadDetail, file_create_params - -if TYPE_CHECKING: - from .._client import ZhipuAI - -__all__ = ["Files", "FilesWithRawResponse"] - - -class Files(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - def create( - self, - *, - file: Optional[FileTypes] = None, - upload_detail: Optional[list[UploadDetail]] = None, - purpose: Literal["fine-tune", "retrieval", "batch"], - knowledge_id: Optional[str] = None, - sentence_size: Optional[int] = None, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> FileObject: - if not file and not upload_detail: - raise ValueError("At least one of `file` and 
`upload_detail` must be provided.") - body = deepcopy_minimal( - { - "file": file, - "upload_detail": upload_detail, - "purpose": purpose, - "knowledge_id": knowledge_id, - "sentence_size": sentence_size, - } - ) - files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} - return self._post( - "/files", - body=maybe_transform(body, file_create_params.FileCreateParams), - files=files, - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=FileObject, - ) - - # def retrieve( - # self, - # file_id: str, - # *, - # extra_headers: Headers | None = None, - # extra_body: Body | None = None, - # timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - # ) -> FileObject: - # """ - # Returns information about a specific file. 
- # - # Args: - # file_id: The ID of the file to retrieve information about - # extra_headers: Send extra headers - # - # extra_body: Add additional JSON properties to the request - # - # timeout: Override the client-level default timeout for this request, in seconds - # """ - # if not file_id: - # raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") - # return self._get( - # f"/files/{file_id}", - # options=make_request_options( - # extra_headers=extra_headers, extra_body=extra_body, timeout=timeout - # ), - # cast_type=FileObject, - # ) - - def list( - self, - *, - purpose: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - after: str | NotGiven = NOT_GIVEN, - order: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> ListOfFileObject: - return self._get( - "/files", - cast_type=ListOfFileObject, - options=make_request_options( - extra_headers=extra_headers, - extra_body=extra_body, - timeout=timeout, - query={ - "purpose": purpose, - "limit": limit, - "after": after, - "order": order, - }, - ), - ) - - def delete( - self, - file_id: str, - *, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> FileDeleted: - """ - Delete a file. 
- - Args: - file_id: The ID of the file to delete - extra_headers: Send extra headers - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not file_id: - raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") - return self._delete( - f"/files/{file_id}", - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=FileDeleted, - ) - - def content( - self, - file_id: str, - *, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> _legacy_response.HttpxBinaryResponseContent: - """ - Returns the contents of the specified file. - - Args: - extra_headers: Send extra headers - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not file_id: - raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") - extra_headers = {"Accept": "application/binary", **(extra_headers or {})} - return self._get( - f"/files/{file_id}/content", - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=_legacy_binary_response.HttpxBinaryResponseContent, - ) - - -class FilesWithRawResponse: - def __init__(self, files: Files) -> None: - self._files = files - - self.create = _legacy_response.to_raw_response_wrapper( - files.create, - ) - self.list = _legacy_response.to_raw_response_wrapper( - files.list, - ) - self.content = _legacy_response.to_raw_response_wrapper( - files.content, - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py deleted file mode 100644 index 
7c309b83416803..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .fine_tuning import FineTuning -from .jobs import Jobs -from .models import FineTunedModels - -__all__ = ["Jobs", "FineTunedModels", "FineTuning"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py deleted file mode 100644 index 8670f7de00df84..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import TYPE_CHECKING - -from ...core import BaseAPI, cached_property -from .jobs import Jobs -from .models import FineTunedModels - -if TYPE_CHECKING: - pass - - -class FineTuning(BaseAPI): - @cached_property - def jobs(self) -> Jobs: - return Jobs(self._client) - - @cached_property - def models(self) -> FineTunedModels: - return FineTunedModels(self._client) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs/__init__.py deleted file mode 100644 index 40777a153f272a..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .jobs import Jobs - -__all__ = ["Jobs"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs/jobs.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs/jobs.py deleted file mode 100644 index 8b038cadc06407..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs/jobs.py +++ /dev/null @@ -1,152 +0,0 @@ -from __future__ import 
annotations - -from typing import TYPE_CHECKING, Optional - -import httpx - -from ....core import ( - NOT_GIVEN, - BaseAPI, - Body, - Headers, - NotGiven, - make_request_options, -) -from ....types.fine_tuning import ( - FineTuningJob, - FineTuningJobEvent, - ListOfFineTuningJob, - job_create_params, -) - -if TYPE_CHECKING: - from ...._client import ZhipuAI - -__all__ = ["Jobs"] - - -class Jobs(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - def create( - self, - *, - model: str, - training_file: str, - hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN, - suffix: Optional[str] | NotGiven = NOT_GIVEN, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - validation_file: Optional[str] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> FineTuningJob: - return self._post( - "/fine_tuning/jobs", - body={ - "model": model, - "training_file": training_file, - "hyperparameters": hyperparameters, - "suffix": suffix, - "validation_file": validation_file, - "request_id": request_id, - }, - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=FineTuningJob, - ) - - def retrieve( - self, - fine_tuning_job_id: str, - *, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> FineTuningJob: - return self._get( - f"/fine_tuning/jobs/{fine_tuning_job_id}", - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=FineTuningJob, - ) - - def list( - self, - *, - after: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> ListOfFineTuningJob: - return 
self._get( - "/fine_tuning/jobs", - cast_type=ListOfFineTuningJob, - options=make_request_options( - extra_headers=extra_headers, - extra_body=extra_body, - timeout=timeout, - query={ - "after": after, - "limit": limit, - }, - ), - ) - - def cancel( - self, - fine_tuning_job_id: str, - *, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # noqa: E501 - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> FineTuningJob: - if not fine_tuning_job_id: - raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}") - return self._post( - f"/fine_tuning/jobs/{fine_tuning_job_id}/cancel", - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=FineTuningJob, - ) - - def list_events( - self, - fine_tuning_job_id: str, - *, - after: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> FineTuningJobEvent: - return self._get( - f"/fine_tuning/jobs/{fine_tuning_job_id}/events", - cast_type=FineTuningJobEvent, - options=make_request_options( - extra_headers=extra_headers, - extra_body=extra_body, - timeout=timeout, - query={ - "after": after, - "limit": limit, - }, - ), - ) - - def delete( - self, - fine_tuning_job_id: str, - *, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> FineTuningJob: - if not fine_tuning_job_id: - raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}") - return self._delete( - 
f"/fine_tuning/jobs/{fine_tuning_job_id}", - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=FineTuningJob, - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/models/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/models/__init__.py deleted file mode 100644 index d832635bafbc6f..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/models/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .fine_tuned_models import FineTunedModels - -__all__ = ["FineTunedModels"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/models/fine_tuned_models.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/models/fine_tuned_models.py deleted file mode 100644 index 29c023e3b1cd5a..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/models/fine_tuned_models.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import httpx - -from ....core import ( - NOT_GIVEN, - BaseAPI, - Body, - Headers, - NotGiven, - make_request_options, -) -from ....types.fine_tuning.models import FineTunedModelsStatus - -if TYPE_CHECKING: - from ...._client import ZhipuAI - -__all__ = ["FineTunedModels"] - - -class FineTunedModels(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - def delete( - self, - fine_tuned_model: str, - *, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> FineTunedModelsStatus: - if not fine_tuned_model: - raise ValueError(f"Expected a non-empty value for `fine_tuned_model` but received {fine_tuned_model!r}") - return self._delete( - 
f"fine_tuning/fine_tuned_models/{fine_tuned_model}", - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=FineTunedModelsStatus, - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py deleted file mode 100644 index 8ad411913fa115..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional - -import httpx - -from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options -from ..types.image import ImagesResponded -from ..types.sensitive_word_check import SensitiveWordCheckRequest - -if TYPE_CHECKING: - from .._client import ZhipuAI - - -class Images(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - def generations( - self, - *, - prompt: str, - model: str | NotGiven = NOT_GIVEN, - n: Optional[int] | NotGiven = NOT_GIVEN, - quality: Optional[str] | NotGiven = NOT_GIVEN, - response_format: Optional[str] | NotGiven = NOT_GIVEN, - size: Optional[str] | NotGiven = NOT_GIVEN, - style: Optional[str] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - user_id: Optional[str] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> ImagesResponded: - _cast_type = ImagesResponded - if disable_strict_validation: - _cast_type = object - return self._post( - "/images/generations", - body={ - "prompt": prompt, - "model": model, - "n": n, - "quality": quality, - 
"response_format": response_format, - "sensitive_word_check": sensitive_word_check, - "size": size, - "style": style, - "user": user, - "user_id": user_id, - "request_id": request_id, - }, - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=_cast_type, - stream=False, - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/__init__.py deleted file mode 100644 index 5a67d743c35b9b..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .knowledge import Knowledge - -__all__ = ["Knowledge"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/document/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/document/__init__.py deleted file mode 100644 index fd289e2232b955..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/document/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .document import Document - -__all__ = ["Document"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/document/document.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/document/document.py deleted file mode 100644 index 492c49da6636c2..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/document/document.py +++ /dev/null @@ -1,217 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from typing import TYPE_CHECKING, Literal, Optional, cast - -import httpx - -from ....core import ( - NOT_GIVEN, - BaseAPI, - Body, - FileTypes, - Headers, - NotGiven, - deepcopy_minimal, - extract_files, - make_request_options, 
- maybe_transform, -) -from ....types.files import UploadDetail, file_create_params -from ....types.knowledge.document import DocumentData, DocumentObject, document_edit_params, document_list_params -from ....types.knowledge.document.document_list_resp import DocumentPage - -if TYPE_CHECKING: - from ...._client import ZhipuAI - -__all__ = ["Document"] - - -class Document(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - def create( - self, - *, - file: Optional[FileTypes] = None, - custom_separator: Optional[list[str]] = None, - upload_detail: Optional[list[UploadDetail]] = None, - purpose: Literal["retrieval"], - knowledge_id: Optional[str] = None, - sentence_size: Optional[int] = None, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> DocumentObject: - if not file and not upload_detail: - raise ValueError("At least one of `file` and `upload_detail` must be provided.") - body = deepcopy_minimal( - { - "file": file, - "upload_detail": upload_detail, - "purpose": purpose, - "custom_separator": custom_separator, - "knowledge_id": knowledge_id, - "sentence_size": sentence_size, - } - ) - files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. 
- # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} - return self._post( - "/files", - body=maybe_transform(body, file_create_params.FileCreateParams), - files=files, - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=DocumentObject, - ) - - def edit( - self, - document_id: str, - knowledge_type: str, - *, - custom_separator: Optional[list[str]] = None, - sentence_size: Optional[int] = None, - callback_url: Optional[str] = None, - callback_header: Optional[dict[str, str]] = None, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> httpx.Response: - """ - - Args: - document_id: 知识id - knowledge_type: 知识类型: - 1:文章知识: 支持pdf,url,docx - 2.问答知识-文档: 支持pdf,url,docx - 3.问答知识-表格: 支持xlsx - 4.商品库-表格: 支持xlsx - 5.自定义: 支持pdf,url,docx - extra_headers: Send extra headers - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - :param knowledge_type: - :param document_id: - :param timeout: - :param extra_body: - :param callback_header: - :param sentence_size: - :param extra_headers: - :param callback_url: - :param custom_separator: - """ - if not document_id: - raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}") - - body = deepcopy_minimal( - { - "id": document_id, - "knowledge_type": knowledge_type, - "custom_separator": custom_separator, - "sentence_size": sentence_size, - "callback_url": callback_url, - "callback_header": callback_header, - } - ) - - return self._put( - f"/document/{document_id}", - body=maybe_transform(body, document_edit_params.DocumentEditParams), - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=httpx.Response, - ) - - def list( - self, - 
knowledge_id: str, - *, - purpose: str | NotGiven = NOT_GIVEN, - page: str | NotGiven = NOT_GIVEN, - limit: str | NotGiven = NOT_GIVEN, - order: Literal["desc", "asc"] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> DocumentPage: - return self._get( - "/files", - options=make_request_options( - extra_headers=extra_headers, - extra_body=extra_body, - timeout=timeout, - query=maybe_transform( - { - "knowledge_id": knowledge_id, - "purpose": purpose, - "page": page, - "limit": limit, - "order": order, - }, - document_list_params.DocumentListParams, - ), - ), - cast_type=DocumentPage, - ) - - def delete( - self, - document_id: str, - *, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> httpx.Response: - """ - Delete a file. - - Args: - - document_id: 知识id - extra_headers: Send extra headers - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not document_id: - raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}") - - return self._delete( - f"/document/{document_id}", - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=httpx.Response, - ) - - def retrieve( - self, - document_id: str, - *, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> DocumentData: - """ - - Args: - extra_headers: Send extra headers - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not document_id: - raise ValueError(f"Expected a non-empty value for `document_id` but received 
{document_id!r}") - - return self._get( - f"/document/{document_id}", - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=DocumentData, - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/knowledge.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/knowledge.py deleted file mode 100644 index fea4c73ac997c3..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/knowledge.py +++ /dev/null @@ -1,173 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Literal, Optional - -import httpx - -from ...core import ( - NOT_GIVEN, - BaseAPI, - Body, - Headers, - NotGiven, - cached_property, - deepcopy_minimal, - make_request_options, - maybe_transform, -) -from ...types.knowledge import KnowledgeInfo, KnowledgeUsed, knowledge_create_params, knowledge_list_params -from ...types.knowledge.knowledge_list_resp import KnowledgePage -from .document import Document - -if TYPE_CHECKING: - from ..._client import ZhipuAI - -__all__ = ["Knowledge"] - - -class Knowledge(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - @cached_property - def document(self) -> Document: - return Document(self._client) - - def create( - self, - embedding_id: int, - name: str, - *, - customer_identifier: Optional[str] = None, - description: Optional[str] = None, - background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None, - icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None, - bucket_id: Optional[str] = None, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> KnowledgeInfo: - body = deepcopy_minimal( - { - "embedding_id": embedding_id, - "name": name, - "customer_identifier": customer_identifier, - 
"description": description, - "background": background, - "icon": icon, - "bucket_id": bucket_id, - } - ) - return self._post( - "/knowledge", - body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams), - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=KnowledgeInfo, - ) - - def modify( - self, - knowledge_id: str, - embedding_id: int, - *, - name: str, - description: Optional[str] = None, - background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None, - icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> httpx.Response: - body = deepcopy_minimal( - { - "id": knowledge_id, - "embedding_id": embedding_id, - "name": name, - "description": description, - "background": background, - "icon": icon, - } - ) - return self._put( - f"/knowledge/{knowledge_id}", - body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams), - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=httpx.Response, - ) - - def query( - self, - *, - page: int | NotGiven = 1, - size: int | NotGiven = 10, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> KnowledgePage: - return self._get( - "/knowledge", - options=make_request_options( - extra_headers=extra_headers, - extra_body=extra_body, - timeout=timeout, - query=maybe_transform( - { - "page": page, - "size": size, - }, - knowledge_list_params.KnowledgeListParams, - ), - ), - cast_type=KnowledgePage, - ) - - def delete( - self, - knowledge_id: str, - *, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> 
httpx.Response: - """ - Delete a file. - - Args: - knowledge_id: 知识库ID - extra_headers: Send extra headers - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not knowledge_id: - raise ValueError("Expected a non-empty value for `knowledge_id`") - - return self._delete( - f"/knowledge/{knowledge_id}", - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=httpx.Response, - ) - - def used( - self, - *, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> KnowledgeUsed: - """ - Returns the contents of the specified file. - - Args: - extra_headers: Send extra headers - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return self._get( - "/knowledge/capacity", - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=KnowledgeUsed, - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/tools/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/tools/__init__.py deleted file mode 100644 index 43e4e37da1779f..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/tools/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .tools import Tools - -__all__ = ["Tools"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/tools/tools.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/tools/tools.py deleted file mode 100644 index 3c3a630aff47d7..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/tools/tools.py +++ /dev/null @@ -1,65 +0,0 @@ -from __future__ import annotations - 
-import logging -from typing import TYPE_CHECKING, Literal, Optional, Union - -import httpx - -from ...core import ( - NOT_GIVEN, - BaseAPI, - Body, - Headers, - NotGiven, - StreamResponse, - deepcopy_minimal, - make_request_options, - maybe_transform, -) -from ...types.tools import WebSearch, WebSearchChunk, tools_web_search_params - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from ..._client import ZhipuAI - -__all__ = ["Tools"] - - -class Tools(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - def web_search( - self, - *, - model: str, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - messages: Union[str, list[str], list[int], object, None], - scope: Optional[str] | NotGiven = NOT_GIVEN, - location: Optional[str] | NotGiven = NOT_GIVEN, - recent_days: Optional[int] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> WebSearch | StreamResponse[WebSearchChunk]: - body = deepcopy_minimal( - { - "model": model, - "request_id": request_id, - "messages": messages, - "stream": stream, - "scope": scope, - "location": location, - "recent_days": recent_days, - } - ) - return self._post( - "/tools", - body=maybe_transform(body, tools_web_search_params.WebSearchParams), - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=WebSearch, - stream=stream or False, - stream_cls=StreamResponse[WebSearchChunk], - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/videos/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/videos/__init__.py deleted file mode 100644 index 6b0f99ed09efe3..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/videos/__init__.py +++ /dev/null 
@@ -1,7 +0,0 @@ -from .videos import ( - Videos, -) - -__all__ = [ - "Videos", -] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/videos/videos.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/videos/videos.py deleted file mode 100644 index 71c8316602a089..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/videos/videos.py +++ /dev/null @@ -1,77 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional - -import httpx - -from ...core import ( - NOT_GIVEN, - BaseAPI, - Body, - Headers, - NotGiven, - deepcopy_minimal, - make_request_options, - maybe_transform, -) -from ...types.sensitive_word_check import SensitiveWordCheckRequest -from ...types.video import VideoObject, video_create_params - -if TYPE_CHECKING: - from ..._client import ZhipuAI - -__all__ = ["Videos"] - - -class Videos(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - def generations( - self, - model: str, - *, - prompt: Optional[str] = None, - image_url: Optional[str] = None, - sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN, - request_id: Optional[str] = None, - user_id: Optional[str] = None, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> VideoObject: - if not model and not model: - raise ValueError("At least one of `model` and `prompt` must be provided.") - body = deepcopy_minimal( - { - "model": model, - "prompt": prompt, - "image_url": image_url, - "sensitive_word_check": sensitive_word_check, - "request_id": request_id, - "user_id": user_id, - } - ) - return self._post( - "/videos/generations", - body=maybe_transform(body, video_create_params.VideoCreateParams), - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=VideoObject, - ) - - 
def retrieve_videos_result( - self, - id: str, - *, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> VideoObject: - if not id: - raise ValueError("At least one of `id` must be provided.") - - return self._get( - f"/async-result/{id}", - options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), - cast_type=VideoObject, - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py deleted file mode 100644 index 3d6466d279861a..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py +++ /dev/null @@ -1,108 +0,0 @@ -from ._base_api import BaseAPI -from ._base_compat import ( - PYDANTIC_V2, - ConfigDict, - GenericModel, - cached_property, - field_get_default, - get_args, - get_model_config, - get_model_fields, - get_origin, - is_literal_type, - is_union, - parse_obj, -) -from ._base_models import BaseModel, construct_type -from ._base_type import ( - NOT_GIVEN, - Body, - FileTypes, - Headers, - IncEx, - ModelT, - NotGiven, - Query, -) -from ._constants import ( - ZHIPUAI_DEFAULT_LIMITS, - ZHIPUAI_DEFAULT_MAX_RETRIES, - ZHIPUAI_DEFAULT_TIMEOUT, -) -from ._errors import ( - APIAuthenticationError, - APIConnectionError, - APIInternalError, - APIReachLimitError, - APIRequestFailedError, - APIResponseError, - APIResponseValidationError, - APIServerFlowExceedError, - APIStatusError, - APITimeoutError, - ZhipuAIError, -) -from ._files import is_file_content -from ._http_client import HttpClient, make_request_options -from ._sse_client import StreamResponse -from ._utils import ( - deepcopy_minimal, - drop_prefix_image_data, - extract_files, - is_given, - is_list, - is_mapping, - maybe_transform, - parse_date, - parse_datetime, -) - -__all__ = [ - "BaseModel", - "construct_type", - "BaseAPI", - 
"NOT_GIVEN", - "Headers", - "NotGiven", - "Body", - "IncEx", - "ModelT", - "Query", - "FileTypes", - "PYDANTIC_V2", - "ConfigDict", - "GenericModel", - "get_args", - "is_union", - "parse_obj", - "get_origin", - "is_literal_type", - "get_model_config", - "get_model_fields", - "field_get_default", - "is_file_content", - "ZhipuAIError", - "APIStatusError", - "APIRequestFailedError", - "APIAuthenticationError", - "APIReachLimitError", - "APIInternalError", - "APIServerFlowExceedError", - "APIResponseError", - "APIResponseValidationError", - "APITimeoutError", - "make_request_options", - "HttpClient", - "ZHIPUAI_DEFAULT_TIMEOUT", - "ZHIPUAI_DEFAULT_MAX_RETRIES", - "ZHIPUAI_DEFAULT_LIMITS", - "is_list", - "is_mapping", - "parse_date", - "parse_datetime", - "is_given", - "maybe_transform", - "deepcopy_minimal", - "extract_files", - "StreamResponse", -] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py deleted file mode 100644 index 3592ea6bacd170..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from .._client import ZhipuAI - - -class BaseAPI: - _client: ZhipuAI - - def __init__(self, client: ZhipuAI) -> None: - self._client = client - self._delete = client.delete - self._get = client.get - self._post = client.post - self._put = client.put - self._patch = client.patch - self._get_api_list = client.get_api_list diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_compat.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_compat.py deleted file mode 100644 index 92a5d683be6732..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_compat.py +++ /dev/null @@ -1,209 +0,0 @@ -from __future__ import 
annotations - -from collections.abc import Callable -from datetime import date, datetime -from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, overload - -import pydantic -from pydantic.fields import FieldInfo -from typing_extensions import Self - -from ._base_type import StrBytesIntFloat - -_T = TypeVar("_T") -_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel) - -# --------------- Pydantic v2 compatibility --------------- - -# Pyright incorrectly reports some of our functions as overriding a method when they don't -# pyright: reportIncompatibleMethodOverride=false - -PYDANTIC_V2 = pydantic.VERSION.startswith("2.") - -# v1 re-exports -if TYPE_CHECKING: - - def parse_date(value: date | StrBytesIntFloat) -> date: ... - - def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: ... - - def get_args(t: type[Any]) -> tuple[Any, ...]: ... - - def is_union(tp: type[Any] | None) -> bool: ... - - def get_origin(t: type[Any]) -> type[Any] | None: ... - - def is_literal_type(type_: type[Any]) -> bool: ... - - def is_typeddict(type_: type[Any]) -> bool: ... 
- -else: - if PYDANTIC_V2: - from pydantic.v1.typing import ( # noqa: I001 - get_args as get_args, # noqa: PLC0414 - is_union as is_union, # noqa: PLC0414 - get_origin as get_origin, # noqa: PLC0414 - is_typeddict as is_typeddict, # noqa: PLC0414 - is_literal_type as is_literal_type, # noqa: PLC0414 - ) - from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime # noqa: PLC0414 - else: - from pydantic.typing import ( # noqa: I001 - get_args as get_args, # noqa: PLC0414 - is_union as is_union, # noqa: PLC0414 - get_origin as get_origin, # noqa: PLC0414 - is_typeddict as is_typeddict, # noqa: PLC0414 - is_literal_type as is_literal_type, # noqa: PLC0414 - ) - from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime # noqa: PLC0414 - - -# refactored config -if TYPE_CHECKING: - from pydantic import ConfigDict -else: - if PYDANTIC_V2: - from pydantic import ConfigDict - else: - # TODO: provide an error message here? 
- ConfigDict = None - - -# renamed methods / properties -def parse_obj(model: type[_ModelT], value: object) -> _ModelT: - if PYDANTIC_V2: - return model.model_validate(value) - else: - # pyright: ignore[reportDeprecated, reportUnnecessaryCast] - return cast(_ModelT, model.parse_obj(value)) - - -def field_is_required(field: FieldInfo) -> bool: - if PYDANTIC_V2: - return field.is_required() - return field.required # type: ignore - - -def field_get_default(field: FieldInfo) -> Any: - value = field.get_default() - if PYDANTIC_V2: - from pydantic_core import PydanticUndefined - - if value == PydanticUndefined: - return None - return value - return value - - -def field_outer_type(field: FieldInfo) -> Any: - if PYDANTIC_V2: - return field.annotation - return field.outer_type_ # type: ignore - - -def get_model_config(model: type[pydantic.BaseModel]) -> Any: - if PYDANTIC_V2: - return model.model_config - return model.__config__ # type: ignore - - -def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]: - if PYDANTIC_V2: - return model.model_fields - return model.__fields__ # type: ignore - - -def model_copy(model: _ModelT) -> _ModelT: - if PYDANTIC_V2: - return model.model_copy() - return model.copy() # type: ignore - - -def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str: - if PYDANTIC_V2: - return model.model_dump_json(indent=indent) - return model.json(indent=indent) # type: ignore - - -def model_dump( - model: pydantic.BaseModel, - *, - exclude_unset: bool = False, - exclude_defaults: bool = False, -) -> dict[str, Any]: - if PYDANTIC_V2: - return model.model_dump( - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - ) - return cast( - "dict[str, Any]", - model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast] - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - ), - ) - - -def model_parse(model: type[_ModelT], data: Any) -> _ModelT: - if PYDANTIC_V2: - return 
model.model_validate(data) - return model.parse_obj(data) # pyright: ignore[reportDeprecated] - - -# generic models -if TYPE_CHECKING: - - class GenericModel(pydantic.BaseModel): ... - -else: - if PYDANTIC_V2: - # there no longer needs to be a distinction in v2 but - # we still have to create our own subclass to avoid - # inconsistent MRO ordering errors - class GenericModel(pydantic.BaseModel): ... - - else: - import pydantic.generics - - class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ... - - -# cached properties -if TYPE_CHECKING: - cached_property = property - - # we define a separate type (copied from typeshed) - # that represents that `cached_property` is `set`able - # at runtime, which differs from `@property`. - # - # this is a separate type as editors likely special case - # `@property` and we don't want to cause issues just to have - # more helpful internal types. - - class typed_cached_property(Generic[_T]): # noqa: N801 - func: Callable[[Any], _T] - attrname: str | None - - def __init__(self, func: Callable[[Any], _T]) -> None: ... - - @overload - def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ... - - @overload - def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ... - - def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self: - raise NotImplementedError() - - def __set_name__(self, owner: type[Any], name: str) -> None: ... - - # __set__ is not defined at runtime, but @cached_property is designed to be settable - def __set__(self, instance: object, value: _T) -> None: ... 
-else: - try: - from functools import cached_property - except ImportError: - from cached_property import cached_property - - typed_cached_property = cached_property diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_models.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_models.py deleted file mode 100644 index 69b1d3a83dfef3..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_models.py +++ /dev/null @@ -1,670 +0,0 @@ -from __future__ import annotations - -import inspect -import os -from collections.abc import Callable -from datetime import date, datetime -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeGuard, TypeVar, cast - -import pydantic -import pydantic.generics -from pydantic.fields import FieldInfo -from typing_extensions import ( - ParamSpec, - Protocol, - override, - runtime_checkable, -) - -from ._base_compat import ( - PYDANTIC_V2, - ConfigDict, - field_get_default, - get_args, - get_model_config, - get_model_fields, - get_origin, - is_literal_type, - is_union, - parse_obj, -) -from ._base_compat import ( - GenericModel as BaseGenericModel, -) -from ._base_type import ( - IncEx, - ModelT, -) -from ._utils import ( - PropertyInfo, - coerce_boolean, - extract_type_arg, - is_annotated_type, - is_list, - is_mapping, - parse_date, - parse_datetime, - strip_annotated_type, -) - -if TYPE_CHECKING: - from pydantic_core.core_schema import ModelField - -__all__ = ["BaseModel", "GenericModel"] -_BaseModelT = TypeVar("_BaseModelT", bound="BaseModel") - -_T = TypeVar("_T") -P = ParamSpec("P") - - -@runtime_checkable -class _ConfigProtocol(Protocol): - allow_population_by_field_name: bool - - -class BaseModel(pydantic.BaseModel): - if PYDANTIC_V2: - model_config: ClassVar[ConfigDict] = ConfigDict( - extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")) - ) - else: - - @property - @override - def 
model_fields_set(self) -> set[str]: - # a forwards-compat shim for pydantic v2 - return self.__fields_set__ # type: ignore - - class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] - extra: Any = pydantic.Extra.allow # type: ignore - - def to_dict( - self, - *, - mode: Literal["json", "python"] = "python", - use_api_names: bool = True, - exclude_unset: bool = True, - exclude_defaults: bool = False, - exclude_none: bool = False, - warnings: bool = True, - ) -> dict[str, object]: - """Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude. - - By default, fields that were not set by the API will not be included, - and keys will match the API response, *not* the property names from the model. - - For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property, - the output will use the `"fooBar"` key (unless `use_api_names=False` is passed). - - Args: - mode: - If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`. - If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)` - - use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`. - exclude_unset: Whether to exclude fields that have not been explicitly set. - exclude_defaults: Whether to exclude fields that are set to their default value from the output. - exclude_none: Whether to exclude fields that have a value of `None` from the output. - warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2. 
- """ # noqa: E501 - return self.model_dump( - mode=mode, - by_alias=use_api_names, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - warnings=warnings, - ) - - def to_json( - self, - *, - indent: int | None = 2, - use_api_names: bool = True, - exclude_unset: bool = True, - exclude_defaults: bool = False, - exclude_none: bool = False, - warnings: bool = True, - ) -> str: - """Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation). - - By default, fields that were not set by the API will not be included, - and keys will match the API response, *not* the property names from the model. - - For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property, - the output will use the `"fooBar"` key (unless `use_api_names=False` is passed). - - Args: - indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2` - use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`. - exclude_unset: Whether to exclude fields that have not been explicitly set. - exclude_defaults: Whether to exclude fields that have the default value. - exclude_none: Whether to exclude fields that have a value of `None`. - warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2. - """ # noqa: E501 - return self.model_dump_json( - indent=indent, - by_alias=use_api_names, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - warnings=warnings, - ) - - @override - def __str__(self) -> str: - # mypy complains about an invalid self arg - return f'{self.__repr_name__()}({self.__repr_str__(", ")})' # type: ignore[misc] - - # Override the 'construct' method in a way that supports recursive parsing without validation. 
- # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836. - @classmethod - @override - def construct( - cls: type[ModelT], - _fields_set: set[str] | None = None, - **values: object, - ) -> ModelT: - m = cls.__new__(cls) - fields_values: dict[str, object] = {} - - config = get_model_config(cls) - populate_by_name = ( - config.allow_population_by_field_name - if isinstance(config, _ConfigProtocol) - else config.get("populate_by_name") - ) - - if _fields_set is None: - _fields_set = set() - - model_fields = get_model_fields(cls) - for name, field in model_fields.items(): - key = field.alias - if key is None or (key not in values and populate_by_name): - key = name - - if key in values: - fields_values[name] = _construct_field(value=values[key], field=field, key=key) - _fields_set.add(name) - else: - fields_values[name] = field_get_default(field) - - _extra = {} - for key, value in values.items(): - if key not in model_fields: - if PYDANTIC_V2: - _extra[key] = value - else: - _fields_set.add(key) - fields_values[key] = value - - object.__setattr__(m, "__dict__", fields_values) # noqa: PLC2801 - - if PYDANTIC_V2: - # these properties are copied from Pydantic's `model_construct()` method - object.__setattr__(m, "__pydantic_private__", None) # noqa: PLC2801 - object.__setattr__(m, "__pydantic_extra__", _extra) # noqa: PLC2801 - object.__setattr__(m, "__pydantic_fields_set__", _fields_set) # noqa: PLC2801 - else: - # init_private_attributes() does not exist in v2 - m._init_private_attributes() # type: ignore - - # copied from Pydantic v1's `construct()` method - object.__setattr__(m, "__fields_set__", _fields_set) # noqa: PLC2801 - - return m - - if not TYPE_CHECKING: - # type checkers incorrectly complain about this assignment - # because the type signatures are technically different - # although not in practice - model_construct = construct - - if not PYDANTIC_V2: - # we define aliases for some of the new pydantic v2 methods so - # that we 
can just document these methods without having to specify - # a specific pydantic version as some users may not know which - # pydantic version they are currently using - - @override - def model_dump( - self, - *, - mode: Literal["json", "python"] | str = "python", - include: IncEx = None, - exclude: IncEx = None, - by_alias: bool = False, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - round_trip: bool = False, - warnings: bool | Literal["none", "warn", "error"] = True, - context: dict[str, Any] | None = None, - serialize_as_any: bool = False, - ) -> dict[str, Any]: - """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump - - Generate a dictionary representation of the model, optionally specifying which fields to include or exclude. - - Args: - mode: The mode in which `to_python` should run. - If mode is 'json', the dictionary will only contain JSON serializable types. - If mode is 'python', the dictionary may contain any Python objects. - include: A list of fields to include in the output. - exclude: A list of fields to exclude from the output. - by_alias: Whether to use the field's alias in the dictionary key if defined. - exclude_unset: Whether to exclude fields that are unset or None from the output. - exclude_defaults: Whether to exclude fields that are set to their default value from the output. - exclude_none: Whether to exclude fields that have a value of `None` from the output. - round_trip: Whether to enable serialization and deserialization round-trip support. - warnings: Whether to log warnings when invalid fields are encountered. - - Returns: - A dictionary representation of the model. 
- """ - if mode != "python": - raise ValueError("mode is only supported in Pydantic v2") - if round_trip != False: - raise ValueError("round_trip is only supported in Pydantic v2") - if warnings != True: - raise ValueError("warnings is only supported in Pydantic v2") - if context is not None: - raise ValueError("context is only supported in Pydantic v2") - if serialize_as_any != False: - raise ValueError("serialize_as_any is only supported in Pydantic v2") - return super().dict( # pyright: ignore[reportDeprecated] - include=include, - exclude=exclude, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) - - @override - def model_dump_json( - self, - *, - indent: int | None = None, - include: IncEx = None, - exclude: IncEx = None, - by_alias: bool = False, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - round_trip: bool = False, - warnings: bool | Literal["none", "warn", "error"] = True, - context: dict[str, Any] | None = None, - serialize_as_any: bool = False, - ) -> str: - """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json - - Generates a JSON representation of the model using Pydantic's `to_json` method. - - Args: - indent: Indentation to use in the JSON output. If None is passed, the output will be compact. - include: Field(s) to include in the JSON output. Can take either a string or set of strings. - exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings. - by_alias: Whether to serialize using field aliases. - exclude_unset: Whether to exclude fields that have not been explicitly set. - exclude_defaults: Whether to exclude fields that have the default value. - exclude_none: Whether to exclude fields that have a value of `None`. - round_trip: Whether to use serialization/deserialization between JSON and class instance. 
- warnings: Whether to show any warnings that occurred during serialization. - - Returns: - A JSON string representation of the model. - """ - if round_trip != False: - raise ValueError("round_trip is only supported in Pydantic v2") - if warnings != True: - raise ValueError("warnings is only supported in Pydantic v2") - if context is not None: - raise ValueError("context is only supported in Pydantic v2") - if serialize_as_any != False: - raise ValueError("serialize_as_any is only supported in Pydantic v2") - return super().json( # type: ignore[reportDeprecated] - indent=indent, - include=include, - exclude=exclude, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) - - -def _construct_field(value: object, field: FieldInfo, key: str) -> object: - if value is None: - return field_get_default(field) - - if PYDANTIC_V2: - type_ = field.annotation - else: - type_ = cast(type, field.outer_type_) # type: ignore - - if type_ is None: - raise RuntimeError(f"Unexpected field type is None for {key}") - - return construct_type(value=value, type_=type_) - - -def is_basemodel(type_: type) -> bool: - """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`""" - if is_union(type_): - return any(is_basemodel(variant) for variant in get_args(type_)) - - return is_basemodel_type(type_) - - -def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]: - origin = get_origin(type_) or type_ - return issubclass(origin, BaseModel) or issubclass(origin, GenericModel) - - -def build( - base_model_cls: Callable[P, _BaseModelT], - *args: P.args, - **kwargs: P.kwargs, -) -> _BaseModelT: - """Construct a BaseModel class without validation. - - This is useful for cases where you need to instantiate a `BaseModel` - from an API response as this provides type-safe params which isn't supported - by helpers like `construct_type()`. 
- - ```py - build(MyModel, my_field_a="foo", my_field_b=123) - ``` - """ - if args: - raise TypeError( - "Received positional arguments which are not supported; Keyword arguments must be used instead", - ) - - return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs)) - - -def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T: - """Loose coercion to the expected type with construction of nested values. - - Note: the returned value from this function is not guaranteed to match the - given type. - """ - return cast(_T, construct_type(value=value, type_=type_)) - - -def construct_type(*, value: object, type_: type) -> object: - """Loose coercion to the expected type with construction of nested values. - - If the given value does not match the expected type then it is returned as-is. - """ - # we allow `object` as the input type because otherwise, passing things like - # `Literal['value']` will be reported as a type error by type checkers - type_ = cast("type[object]", type_) - - # unwrap `Annotated[T, ...]` -> `T` - if is_annotated_type(type_): - meta: tuple[Any, ...] = get_args(type_)[1:] - type_ = extract_type_arg(type_, 0) - else: - meta = () - # we need to use the origin class for any types that are subscripted generics - # e.g. Dict[str, object] - origin = get_origin(type_) or type_ - args = get_args(type_) - - if is_union(origin): - try: - return validate_type(type_=cast("type[object]", type_), value=value) - except Exception: - pass - - # if the type is a discriminated union then we want to construct the right variant - # in the union, even if the data doesn't match exactly, otherwise we'd break code - # that relies on the constructed class types, e.g. 
- # - # class FooType: - # kind: Literal['foo'] - # value: str - # - # class BarType: - # kind: Literal['bar'] - # value: int - # - # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then - # we'd end up constructing `FooType` when it should be `BarType`. - discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta) - if discriminator and is_mapping(value): - variant_value = value.get(discriminator.field_alias_from or discriminator.field_name) - if variant_value and isinstance(variant_value, str): - variant_type = discriminator.mapping.get(variant_value) - if variant_type: - return construct_type(type_=variant_type, value=value) - - # if the data is not valid, use the first variant that doesn't fail while deserializing - for variant in args: - try: - return construct_type(value=value, type_=variant) - except Exception: - continue - - raise RuntimeError(f"Could not convert data into a valid instance of {type_}") - if origin == dict: - if not is_mapping(value): - return value - - _, items_type = get_args(type_) # Dict[_, items_type] - return {key: construct_type(value=item, type_=items_type) for key, item in value.items()} - - if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)): - if is_list(value): - return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value] - - if is_mapping(value): - if issubclass(type_, BaseModel): - return type_.construct(**value) # type: ignore[arg-type] - - return cast(Any, type_).construct(**value) - - if origin == list: - if not is_list(value): - return value - - inner_type = args[0] # List[inner_type] - return [construct_type(value=entry, type_=inner_type) for entry in value] - - if origin == float: - if isinstance(value, int): - coerced = float(value) - if coerced != value: - return value - return coerced - - return value - - if type_ == datetime: - try: - return parse_datetime(value) # 
type: ignore - except Exception: - return value - - if type_ == date: - try: - return parse_date(value) # type: ignore - except Exception: - return value - - return value - - -@runtime_checkable -class CachedDiscriminatorType(Protocol): - __discriminator__: DiscriminatorDetails - - -class DiscriminatorDetails: - field_name: str - """The name of the discriminator field in the variant class, e.g. - - ```py - class Foo(BaseModel): - type: Literal['foo'] - ``` - - Will result in field_name='type' - """ - - field_alias_from: str | None - """The name of the discriminator field in the API response, e.g. - - ```py - class Foo(BaseModel): - type: Literal['foo'] = Field(alias='type_from_api') - ``` - - Will result in field_alias_from='type_from_api' - """ - - mapping: dict[str, type] - """Mapping of discriminator value to variant type, e.g. - - {'foo': FooVariant, 'bar': BarVariant} - """ - - def __init__( - self, - *, - mapping: dict[str, type], - discriminator_field: str, - discriminator_alias: str | None, - ) -> None: - self.mapping = mapping - self.field_name = discriminator_field - self.field_alias_from = discriminator_alias - - -def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: - if isinstance(union, CachedDiscriminatorType): - return union.__discriminator__ - - discriminator_field_name: str | None = None - - for annotation in meta_annotations: - if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None: - discriminator_field_name = annotation.discriminator - break - - if not discriminator_field_name: - return None - - mapping: dict[str, type] = {} - discriminator_alias: str | None = None - - for variant in get_args(union): - variant = strip_annotated_type(variant) - if is_basemodel_type(variant): - if PYDANTIC_V2: - field = _extract_field_schema_pv2(variant, discriminator_field_name) - if not field: - continue - - # Note: if one variant defines an alias then they all should - 
discriminator_alias = field.get("serialization_alias") - - field_schema = field["schema"] - - if field_schema["type"] == "literal": - for entry in cast("LiteralSchema", field_schema)["expected"]: - if isinstance(entry, str): - mapping[entry] = variant - else: - field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] - if not field_info: - continue - - # Note: if one variant defines an alias then they all should - discriminator_alias = field_info.alias - - if field_info.annotation and is_literal_type(field_info.annotation): - for entry in get_args(field_info.annotation): - if isinstance(entry, str): - mapping[entry] = variant - - if not mapping: - return None - - details = DiscriminatorDetails( - mapping=mapping, - discriminator_field=discriminator_field_name, - discriminator_alias=discriminator_alias, - ) - cast(CachedDiscriminatorType, union).__discriminator__ = details - return details - - -def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None: - schema = model.__pydantic_core_schema__ - if schema["type"] != "model": - return None - - fields_schema = schema["schema"] - if fields_schema["type"] != "model-fields": - return None - - fields_schema = cast("ModelFieldsSchema", fields_schema) - - field = fields_schema["fields"].get(field_name) - if not field: - return None - - return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast] - - -def validate_type(*, type_: type[_T], value: object) -> _T: - """Strict validation that the given value matches the expected type""" - if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel): - return cast(_T, parse_obj(type_, value)) - - return cast(_T, _validate_non_model_type(type_=type_, value=value)) - - -# Subclassing here confuses type checkers, so we treat this class as non-inheriting. 
-if TYPE_CHECKING: - GenericModel = BaseModel -else: - - class GenericModel(BaseGenericModel, BaseModel): - pass - - -if PYDANTIC_V2: - from pydantic import TypeAdapter - - def _validate_non_model_type(*, type_: type[_T], value: object) -> _T: - return TypeAdapter(type_).validate_python(value) - -elif not TYPE_CHECKING: - - class TypeAdapter(Generic[_T]): - """Used as a placeholder to easily convert runtime types to a Pydantic format - to provide validation. - - For example: - ```py - validated = RootModel[int](__root__="5").__root__ - # validated: 5 - ``` - """ - - def __init__(self, type_: type[_T]): - self.type_ = type_ - - def validate_python(self, value: Any) -> _T: - if not isinstance(value, self.type_): - raise ValueError(f"Invalid type: {value} is not of type {self.type_}") - return value - - def _validate_non_model_type(*, type_: type[_T], value: object) -> _T: - return TypeAdapter(type_).validate_python(value) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py deleted file mode 100644 index ea1d3f09dc42ea..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py +++ /dev/null @@ -1,170 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable, Mapping, Sequence -from os import PathLike -from typing import ( - IO, - TYPE_CHECKING, - Any, - Literal, - Optional, - TypeAlias, - TypeVar, - Union, -) - -import pydantic -from httpx import Response -from typing_extensions import Protocol, TypedDict, override, runtime_checkable - -Query = Mapping[str, object] -Body = object -AnyMapping = Mapping[str, object] -PrimitiveData = Union[str, int, float, bool, None] -Data = Union[PrimitiveData, list[Any], tuple[Any], "Mapping[str, Any]"] -ModelT = TypeVar("ModelT", bound=pydantic.BaseModel) -_T = TypeVar("_T") - -if TYPE_CHECKING: - NoneType: type[None] -else: - NoneType = type(None) - 
- -# Sentinel class used until PEP 0661 is accepted -class NotGiven: - """ - A sentinel singleton class used to distinguish omitted keyword arguments - from those passed in with the value None (which may have different behavior). - - For example: - - ```py - def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ... - - get(timeout=1) # 1s timeout - get(timeout=None) # No timeout - get() # Default timeout behavior, which may not be statically known at the method definition. - ``` - """ - - def __bool__(self) -> Literal[False]: - return False - - @override - def __repr__(self) -> str: - return "NOT_GIVEN" - - -NotGivenOr = Union[_T, NotGiven] -NOT_GIVEN = NotGiven() - - -class Omit: - """In certain situations you need to be able to represent a case where a default value has - to be explicitly removed and `None` is not an appropriate substitute, for example: - - ```py - # as the default `Content-Type` header is `application/json` that will be sent - client.post('/upload/files', files={'file': b'my raw file content'}) - - # you can't explicitly override the header as it has to be dynamically generated - # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983' - client.post(..., headers={'Content-Type': 'multipart/form-data'}) - - # instead you can remove the default `application/json` header by passing Omit - client.post(..., headers={'Content-Type': Omit()}) - ``` - """ - - def __bool__(self) -> Literal[False]: - return False - - -@runtime_checkable -class ModelBuilderProtocol(Protocol): - @classmethod - def build( - cls: type[_T], - *, - response: Response, - data: object, - ) -> _T: ... - - -Headers = Mapping[str, Union[str, Omit]] - - -class HeadersLikeProtocol(Protocol): - def get(self, __key: str) -> str | None: ... 
- - -HeadersLike = Union[Headers, HeadersLikeProtocol] - -ResponseT = TypeVar( - "ResponseT", - bound="Union[str, None, BaseModel, list[Any], dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", # noqa: E501 -) - -StrBytesIntFloat = Union[str, bytes, int, float] - -# Note: copied from Pydantic -# https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49 -IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None" - -PostParser = Callable[[Any], Any] - - -@runtime_checkable -class InheritsGeneric(Protocol): - """Represents a type that has inherited from `Generic` - - The `__orig_bases__` property can be used to determine the resolved - type variable for a given base class. - """ - - __orig_bases__: tuple[_GenericAlias] - - -class _GenericAlias(Protocol): - __origin__: type[object] - - -class HttpxSendArgs(TypedDict, total=False): - auth: httpx.Auth - - -# for user input files -if TYPE_CHECKING: - Base64FileInput = Union[IO[bytes], PathLike[str]] - FileContent = Union[IO[bytes], bytes, PathLike[str]] -else: - Base64FileInput = Union[IO[bytes], PathLike] - FileContent = Union[IO[bytes], bytes, PathLike] - -FileTypes = Union[ - # file (or bytes) - FileContent, - # (filename, file (or bytes)) - tuple[Optional[str], FileContent], - # (filename, file (or bytes), content_type) - tuple[Optional[str], FileContent, Optional[str]], - # (filename, file (or bytes), content_type, headers) - tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]], -] -RequestFiles = Union[Mapping[str, FileTypes], Sequence[tuple[str, FileTypes]]] - -# duplicate of the above but without our custom file support -HttpxFileContent = Union[bytes, IO[bytes]] -HttpxFileTypes = Union[ - # file (or bytes) - HttpxFileContent, - # (filename, file (or bytes)) - tuple[Optional[str], HttpxFileContent], - # (filename, file (or bytes), content_type) - tuple[Optional[str], HttpxFileContent, 
Optional[str]], - # (filename, file (or bytes), content_type, headers) - tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]], -] - -HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[tuple[str, HttpxFileTypes]]] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_constants.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_constants.py deleted file mode 100644 index 8e43bdebecb61f..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_constants.py +++ /dev/null @@ -1,12 +0,0 @@ -import httpx - -RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response" -# 通过 `Timeout` 控制接口`connect` 和 `read` 超时时间,默认为`timeout=300.0, connect=8.0` -ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0) -# 通过 `retry` 参数控制重试次数,默认为3次 -ZHIPUAI_DEFAULT_MAX_RETRIES = 3 -# 通过 `Limits` 控制最大连接数和保持连接数,默认为`max_connections=50, max_keepalive_connections=10` -ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10) - -INITIAL_RETRY_DELAY = 0.5 -MAX_RETRY_DELAY = 8.0 diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py deleted file mode 100644 index e2c9d24c6c0d24..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py +++ /dev/null @@ -1,86 +0,0 @@ -from __future__ import annotations - -import httpx - -__all__ = [ - "ZhipuAIError", - "APIStatusError", - "APIRequestFailedError", - "APIAuthenticationError", - "APIReachLimitError", - "APIInternalError", - "APIServerFlowExceedError", - "APIResponseError", - "APIResponseValidationError", - "APITimeoutError", - "APIConnectionError", -] - - -class ZhipuAIError(Exception): - def __init__( - self, - message: str, - ) -> None: - super().__init__(message) - - -class APIStatusError(ZhipuAIError): - response: httpx.Response - status_code: int - - def __init__(self, 
message: str, *, response: httpx.Response) -> None: - super().__init__(message) - self.response = response - self.status_code = response.status_code - - -class APIRequestFailedError(APIStatusError): ... - - -class APIAuthenticationError(APIStatusError): ... - - -class APIReachLimitError(APIStatusError): ... - - -class APIInternalError(APIStatusError): ... - - -class APIServerFlowExceedError(APIStatusError): ... - - -class APIResponseError(ZhipuAIError): - message: str - request: httpx.Request - json_data: object - - def __init__(self, message: str, request: httpx.Request, json_data: object): - self.message = message - self.request = request - self.json_data = json_data - super().__init__(message) - - -class APIResponseValidationError(APIResponseError): - status_code: int - response: httpx.Response - - def __init__(self, response: httpx.Response, json_data: object | None, *, message: str | None = None) -> None: - super().__init__( - message=message or "Data returned by API invalid for expected schema.", - request=response.request, - json_data=json_data, - ) - self.response = response - self.status_code = response.status_code - - -class APIConnectionError(APIResponseError): - def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None: - super().__init__(message, request, json_data=None) - - -class APITimeoutError(APIConnectionError): - def __init__(self, request: httpx.Request) -> None: - super().__init__(message="Request timed out.", request=request) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py deleted file mode 100644 index f9d2e14d9ecb93..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py +++ /dev/null @@ -1,75 +0,0 @@ -from __future__ import annotations - -import io -import os -import pathlib -from typing import TypeGuard, overload - -from ._base_type import ( - 
Base64FileInput, - FileContent, - FileTypes, - HttpxFileContent, - HttpxFileTypes, - HttpxRequestFiles, - RequestFiles, -) -from ._utils import is_mapping_t, is_sequence_t, is_tuple_t - - -def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]: - return isinstance(obj, io.IOBase | os.PathLike) - - -def is_file_content(obj: object) -> TypeGuard[FileContent]: - return isinstance(obj, bytes | tuple | io.IOBase | os.PathLike) - - -def assert_is_file_content(obj: object, *, key: str | None = None) -> None: - if not is_file_content(obj): - prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`" - raise RuntimeError( - f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads" - ) from None - - -@overload -def to_httpx_files(files: None) -> None: ... - - -@overload -def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ... - - -def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: - if files is None: - return None - - if is_mapping_t(files): - files = {key: _transform_file(file) for key, file in files.items()} - elif is_sequence_t(files): - files = [(key, _transform_file(file)) for key, file in files] - else: - raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence") - - return files - - -def _transform_file(file: FileTypes) -> HttpxFileTypes: - if is_file_content(file): - if isinstance(file, os.PathLike): - path = pathlib.Path(file) - return (path.name, path.read_bytes()) - - return file - - if is_tuple_t(file): - return (file[0], _read_file_content(file[1]), *file[2:]) - - raise TypeError("Expected file types input to be a FileContent type or to be a tuple") - - -def _read_file_content(file: FileContent) -> HttpxFileContent: - if isinstance(file, os.PathLike): - return pathlib.Path(file).read_bytes() - return file diff --git 
a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py deleted file mode 100644 index ffdafb85d581fe..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py +++ /dev/null @@ -1,910 +0,0 @@ -from __future__ import annotations - -import inspect -import logging -import time -import warnings -from collections.abc import Iterator, Mapping -from itertools import starmap -from random import random -from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, Union, cast, overload - -import httpx -import pydantic -from httpx import URL, Timeout - -from . import _errors, get_origin -from ._base_compat import model_copy -from ._base_models import GenericModel, construct_type, validate_type -from ._base_type import ( - NOT_GIVEN, - AnyMapping, - Body, - Data, - Headers, - HttpxSendArgs, - ModelBuilderProtocol, - NotGiven, - Omit, - PostParser, - Query, - RequestFiles, - ResponseT, -) -from ._constants import ( - INITIAL_RETRY_DELAY, - MAX_RETRY_DELAY, - RAW_RESPONSE_HEADER, - ZHIPUAI_DEFAULT_LIMITS, - ZHIPUAI_DEFAULT_MAX_RETRIES, - ZHIPUAI_DEFAULT_TIMEOUT, -) -from ._errors import APIConnectionError, APIResponseValidationError, APIStatusError, APITimeoutError -from ._files import to_httpx_files -from ._legacy_response import LegacyAPIResponse -from ._request_opt import FinalRequestOptions, UserRequestInput -from ._response import APIResponse, BaseAPIResponse, extract_response_type -from ._sse_client import StreamResponse -from ._utils import flatten, is_given, is_mapping - -log: logging.Logger = logging.getLogger(__name__) - -# TODO: make base page type vars covariant -SyncPageT = TypeVar("SyncPageT", bound="BaseSyncPage[Any]") -# AsyncPageT = TypeVar("AsyncPageT", bound="BaseAsyncPage[Any]") - -_T = TypeVar("_T") -_T_co = TypeVar("_T_co", covariant=True) - -if TYPE_CHECKING: - from httpx._config import 
DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT -else: - try: - from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT - except ImportError: - # taken from https://github.com/encode/httpx/blob/3ba5fe0d7ac70222590e759c31442b1cab263791/httpx/_config.py#L366 - HTTPX_DEFAULT_TIMEOUT = Timeout(5.0) - - -headers = { - "Accept": "application/json", - "Content-Type": "application/json; charset=UTF-8", -} - - -class PageInfo: - """Stores the necessary information to build the request to retrieve the next page. - - Either `url` or `params` must be set. - """ - - url: URL | NotGiven - params: Query | NotGiven - - @overload - def __init__( - self, - *, - url: URL, - ) -> None: ... - - @overload - def __init__( - self, - *, - params: Query, - ) -> None: ... - - def __init__( - self, - *, - url: URL | NotGiven = NOT_GIVEN, - params: Query | NotGiven = NOT_GIVEN, - ) -> None: - self.url = url - self.params = params - - -class BasePage(GenericModel, Generic[_T]): - """ - Defines the core interface for pagination. - - Type Args: - ModelT: The pydantic model that represents an item in the response. - - Methods: - has_next_page(): Check if there is another page available - next_page_info(): Get the necessary information to make a request for the next page - """ - - _options: FinalRequestOptions = pydantic.PrivateAttr() - _model: type[_T] = pydantic.PrivateAttr() - - def has_next_page(self) -> bool: - items = self._get_page_items() - if not items: - return False - return self.next_page_info() is not None - - def next_page_info(self) -> Optional[PageInfo]: ... - - def _get_page_items(self) -> Iterable[_T]: # type: ignore[empty-body] - ... - - def _params_from_url(self, url: URL) -> httpx.QueryParams: - # TODO: do we have to preprocess params here? 
- return httpx.QueryParams(cast(Any, self._options.params)).merge(url.params) - - def _info_to_options(self, info: PageInfo) -> FinalRequestOptions: - options = model_copy(self._options) - options._strip_raw_response_header() - - if not isinstance(info.params, NotGiven): - options.params = {**options.params, **info.params} - return options - - if not isinstance(info.url, NotGiven): - params = self._params_from_url(info.url) - url = info.url.copy_with(params=params) - options.params = dict(url.params) - options.url = str(url) - return options - - raise ValueError("Unexpected PageInfo state") - - -class BaseSyncPage(BasePage[_T], Generic[_T]): - _client: HttpClient = pydantic.PrivateAttr() - - def _set_private_attributes( - self, - client: HttpClient, - model: type[_T], - options: FinalRequestOptions, - ) -> None: - self._model = model - self._client = client - self._options = options - - # Pydantic uses a custom `__iter__` method to support casting BaseModels - # to dictionaries. e.g. dict(model). - # As we want to support `for item in page`, this is inherently incompatible - # with the default pydantic behavior. It is not possible to support both - # use cases at once. Fortunately, this is not a big deal as all other pydantic - # methods should continue to work as expected as there is an alternative method - # to cast a model to a dictionary, model.dict(), which is used internally - # by pydantic. - def __iter__(self) -> Iterator[_T]: # type: ignore - for page in self.iter_pages(): - yield from page._get_page_items() - - def iter_pages(self: SyncPageT) -> Iterator[SyncPageT]: - page = self - while True: - yield page - if page.has_next_page(): - page = page.get_next_page() - else: - return - - def get_next_page(self: SyncPageT) -> SyncPageT: - info = self.next_page_info() - if not info: - raise RuntimeError( - "No next page expected; please check `.has_next_page()` before calling `.get_next_page()`." 
- ) - - options = self._info_to_options(info) - return self._client._request_api_list(self._model, page=self.__class__, options=options) - - -class HttpClient: - _client: httpx.Client - _version: str - _base_url: URL - max_retries: int - timeout: Union[float, Timeout, None] - _limits: httpx.Limits - _has_custom_http_client: bool - _default_stream_cls: type[StreamResponse[Any]] | None = None - - _strict_response_validation: bool - - def __init__( - self, - *, - version: str, - base_url: URL, - _strict_response_validation: bool, - max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, - timeout: Union[float, Timeout, None], - limits: httpx.Limits | None = None, - custom_httpx_client: httpx.Client | None = None, - custom_headers: Mapping[str, str] | None = None, - ) -> None: - if limits is not None: - warnings.warn( - "The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead", # noqa: E501 - category=DeprecationWarning, - stacklevel=3, - ) - if custom_httpx_client is not None: - raise ValueError("The `http_client` argument is mutually exclusive with `connection_pool_limits`") - else: - limits = ZHIPUAI_DEFAULT_LIMITS - - if not is_given(timeout): - if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT: - timeout = custom_httpx_client.timeout - else: - timeout = ZHIPUAI_DEFAULT_TIMEOUT - self.max_retries = max_retries - self.timeout = timeout - self._limits = limits - self._has_custom_http_client = bool(custom_httpx_client) - self._client = custom_httpx_client or httpx.Client( - base_url=base_url, - timeout=self.timeout, - limits=limits, - ) - self._version = version - url = URL(url=base_url) - if not url.raw_path.endswith(b"/"): - url = url.copy_with(raw_path=url.raw_path + b"/") - self._base_url = url - self._custom_headers = custom_headers or {} - self._strict_response_validation = _strict_response_validation - - def _prepare_url(self, url: str) -> URL: - sub_url = URL(url) - if 
sub_url.is_relative_url: - request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/") - return self._base_url.copy_with(raw_path=request_raw_url) - - return sub_url - - @property - def _default_headers(self): - return { - "Accept": "application/json", - "Content-Type": "application/json; charset=UTF-8", - "ZhipuAI-SDK-Ver": self._version, - "source_type": "zhipu-sdk-python", - "x-request-sdk": "zhipu-sdk-python", - **self.auth_headers, - **self._custom_headers, - } - - @property - def custom_auth(self) -> httpx.Auth | None: - return None - - @property - def auth_headers(self): - return {} - - def _prepare_headers(self, options: FinalRequestOptions) -> httpx.Headers: - custom_headers = options.headers or {} - headers_dict = _merge_mappings(self._default_headers, custom_headers) - - httpx_headers = httpx.Headers(headers_dict) - - return httpx_headers - - def _remaining_retries( - self, - remaining_retries: Optional[int], - options: FinalRequestOptions, - ) -> int: - return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries) - - def _calculate_retry_timeout( - self, - remaining_retries: int, - options: FinalRequestOptions, - response_headers: Optional[httpx.Headers] = None, - ) -> float: - max_retries = options.get_max_retries(self.max_retries) - - # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says. - # retry_after = self._parse_retry_after_header(response_headers) - # if retry_after is not None and 0 < retry_after <= 60: - # return retry_after - - nb_retries = max_retries - remaining_retries - - # Apply exponential backoff, but not more than the max. - sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY) - - # Apply some jitter, plus-or-minus half a second. 
- jitter = 1 - 0.25 * random() - timeout = sleep_seconds * jitter - return max(timeout, 0) - - def _build_request(self, options: FinalRequestOptions) -> httpx.Request: - kwargs: dict[str, Any] = {} - headers = self._prepare_headers(options) - url = self._prepare_url(options.url) - json_data = options.json_data - if options.extra_json is not None: - if json_data is None: - json_data = cast(Body, options.extra_json) - elif is_mapping(json_data): - json_data = _merge_mappings(json_data, options.extra_json) - else: - raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`") - - content_type = headers.get("Content-Type") - # multipart/form-data; boundary=---abc-- - if headers.get("Content-Type") == "multipart/form-data": - if "boundary" not in content_type: - # only remove the header if the boundary hasn't been explicitly set - # as the caller doesn't want httpx to come up with their own boundary - headers.pop("Content-Type") - - if json_data: - kwargs["data"] = self._make_multipartform(json_data) - - return self._client.build_request( - headers=headers, - timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout, - method=options.method, - url=url, - json=json_data, - files=options.files, - params=options.params, - **kwargs, - ) - - def _object_to_formdata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]: - items = [] - - if isinstance(value, Mapping): - for k, v in value.items(): - items.extend(self._object_to_formdata(f"{key}[{k}]", v)) - return items - if isinstance(value, list | tuple): - for v in value: - items.extend(self._object_to_formdata(key + "[]", v)) - return items - - def _primitive_value_to_str(val) -> str: - # copied from httpx - if val is True: - return "true" - elif val is False: - return "false" - elif val is None: - return "" - return str(val) - - str_data = _primitive_value_to_str(value) - - if not str_data: - return [] - return [(key, str_data)] - - 
def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]: - items = flatten(list(starmap(self._object_to_formdata, data.items()))) - - serialized: dict[str, object] = {} - for key, value in items: - if key in serialized: - raise ValueError(f"存在重复的键: {key};") - serialized[key] = value - return serialized - - def _process_response_data( - self, - *, - data: object, - cast_type: type[ResponseT], - response: httpx.Response, - ) -> ResponseT: - if data is None: - return cast(ResponseT, None) - - if cast_type is object: - return cast(ResponseT, data) - - try: - if inspect.isclass(cast_type) and issubclass(cast_type, ModelBuilderProtocol): - return cast(ResponseT, cast_type.build(response=response, data=data)) - - if self._strict_response_validation: - return cast(ResponseT, validate_type(type_=cast_type, value=data)) - - return cast(ResponseT, construct_type(type_=cast_type, value=data)) - except pydantic.ValidationError as err: - raise APIResponseValidationError(response=response, json_data=data) from err - - def _should_stream_response_body(self, request: httpx.Request) -> bool: - return request.headers.get(RAW_RESPONSE_HEADER) == "stream" # type: ignore[no-any-return] - - def _should_retry(self, response: httpx.Response) -> bool: - # Note: this is not a standard header - should_retry_header = response.headers.get("x-should-retry") - - # If the server explicitly says whether or not to retry, obey. - if should_retry_header == "true": - log.debug("Retrying as header `x-should-retry` is set to `true`") - return True - if should_retry_header == "false": - log.debug("Not retrying as header `x-should-retry` is set to `false`") - return False - - # Retry on request timeouts. - if response.status_code == 408: - log.debug("Retrying due to status code %i", response.status_code) - return True - - # Retry on lock timeouts. 
- if response.status_code == 409: - log.debug("Retrying due to status code %i", response.status_code) - return True - - # Retry on rate limits. - if response.status_code == 429: - log.debug("Retrying due to status code %i", response.status_code) - return True - - # Retry internal errors. - if response.status_code >= 500: - log.debug("Retrying due to status code %i", response.status_code) - return True - - log.debug("Not retrying") - return False - - def is_closed(self) -> bool: - return self._client.is_closed - - def close(self): - self._client.close() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - def request( - self, - cast_type: type[ResponseT], - options: FinalRequestOptions, - remaining_retries: Optional[int] = None, - *, - stream: bool = False, - stream_cls: type[StreamResponse] | None = None, - ) -> ResponseT | StreamResponse: - return self._request( - cast_type=cast_type, - options=options, - stream=stream, - stream_cls=stream_cls, - remaining_retries=remaining_retries, - ) - - def _request( - self, - *, - cast_type: type[ResponseT], - options: FinalRequestOptions, - remaining_retries: int | None, - stream: bool, - stream_cls: type[StreamResponse] | None, - ) -> ResponseT | StreamResponse: - retries = self._remaining_retries(remaining_retries, options) - request = self._build_request(options) - - kwargs: HttpxSendArgs = {} - if self.custom_auth is not None: - kwargs["auth"] = self.custom_auth - try: - response = self._client.send( - request, - stream=stream or self._should_stream_response_body(request=request), - **kwargs, - ) - except httpx.TimeoutException as err: - log.debug("Encountered httpx.TimeoutException", exc_info=True) - - if retries > 0: - return self._retry_request( - options, - cast_type, - retries, - stream=stream, - stream_cls=stream_cls, - response_headers=None, - ) - - log.debug("Raising timeout error") - raise APITimeoutError(request=request) from err - except Exception as err: 
- log.debug("Encountered Exception", exc_info=True) - - if retries > 0: - return self._retry_request( - options, - cast_type, - retries, - stream=stream, - stream_cls=stream_cls, - response_headers=None, - ) - - log.debug("Raising connection error") - raise APIConnectionError(request=request) from err - - log.debug( - 'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase - ) - - try: - response.raise_for_status() - except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code - log.debug("Encountered httpx.HTTPStatusError", exc_info=True) - - if retries > 0 and self._should_retry(err.response): - err.response.close() - return self._retry_request( - options, - cast_type, - retries, - err.response.headers, - stream=stream, - stream_cls=stream_cls, - ) - - # If the response is streamed then we need to explicitly read the response - # to completion before attempting to access the response text. - if not err.response.is_closed: - err.response.read() - - log.debug("Re-raising status error") - raise self._make_status_error(err.response) from None - - # return self._parse_response( - # cast_type=cast_type, - # options=options, - # response=response, - # stream=stream, - # stream_cls=stream_cls, - # ) - return self._process_response( - cast_type=cast_type, - options=options, - response=response, - stream=stream, - stream_cls=stream_cls, - ) - - def _retry_request( - self, - options: FinalRequestOptions, - cast_type: type[ResponseT], - remaining_retries: int, - response_headers: httpx.Headers | None, - *, - stream: bool, - stream_cls: type[StreamResponse] | None, - ) -> ResponseT | StreamResponse: - remaining = remaining_retries - 1 - if remaining == 1: - log.debug("1 retry left") - else: - log.debug("%i retries left", remaining) - - timeout = self._calculate_retry_timeout(remaining, options, response_headers) - log.info("Retrying request to %s in %f seconds", options.url, timeout) - - # In a synchronous context 
we are blocking the entire thread. Up to the library user to run the client in a - # different thread if necessary. - time.sleep(timeout) - - return self._request( - options=options, - cast_type=cast_type, - remaining_retries=remaining, - stream=stream, - stream_cls=stream_cls, - ) - - def _process_response( - self, - *, - cast_type: type[ResponseT], - options: FinalRequestOptions, - response: httpx.Response, - stream: bool, - stream_cls: type[StreamResponse] | None, - ) -> ResponseT: - # _legacy_response with raw_response_header to parser method - if response.request.headers.get(RAW_RESPONSE_HEADER) == "true": - return cast( - ResponseT, - LegacyAPIResponse( - raw=response, - client=self, - cast_type=cast_type, - stream=stream, - stream_cls=stream_cls, - options=options, - ), - ) - - origin = get_origin(cast_type) or cast_type - - if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse): - if not issubclass(origin, APIResponse): - raise TypeError(f"API Response types must subclass {APIResponse}; Received {origin}") - - response_cls = cast("type[BaseAPIResponse[Any]]", cast_type) - return cast( - ResponseT, - response_cls( - raw=response, - client=self, - cast_type=extract_response_type(response_cls), - stream=stream, - stream_cls=stream_cls, - options=options, - ), - ) - - if cast_type == httpx.Response: - return cast(ResponseT, response) - - api_response = APIResponse( - raw=response, - client=self, - cast_type=cast("type[ResponseT]", cast_type), # pyright: ignore[reportUnnecessaryCast] - stream=stream, - stream_cls=stream_cls, - options=options, - ) - if bool(response.request.headers.get(RAW_RESPONSE_HEADER)): - return cast(ResponseT, api_response) - - return api_response.parse() - - def _request_api_list( - self, - model: type[object], - page: type[SyncPageT], - options: FinalRequestOptions, - ) -> SyncPageT: - def _parser(resp: SyncPageT) -> SyncPageT: - resp._set_private_attributes( - client=self, - model=model, - options=options, - ) - return resp 
- - options.post_parser = _parser - - return self.request(page, options, stream=False) - - @overload - def get( - self, - path: str, - *, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - stream: Literal[False] = False, - ) -> ResponseT: ... - - @overload - def get( - self, - path: str, - *, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - stream: Literal[True], - stream_cls: type[StreamResponse], - ) -> StreamResponse: ... - - @overload - def get( - self, - path: str, - *, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - stream: bool, - stream_cls: type[StreamResponse] | None = None, - ) -> ResponseT | StreamResponse: ... - - def get( - self, - path: str, - *, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - stream: bool = False, - stream_cls: type[StreamResponse] | None = None, - ) -> ResponseT: - opts = FinalRequestOptions.construct(method="get", url=path, **options) - return cast(ResponseT, self.request(cast_type, opts, stream=stream, stream_cls=stream_cls)) - - @overload - def post( - self, - path: str, - *, - cast_type: type[ResponseT], - body: Body | None = None, - options: UserRequestInput = {}, - files: RequestFiles | None = None, - stream: Literal[False] = False, - ) -> ResponseT: ... - - @overload - def post( - self, - path: str, - *, - cast_type: type[ResponseT], - body: Body | None = None, - options: UserRequestInput = {}, - files: RequestFiles | None = None, - stream: Literal[True], - stream_cls: type[StreamResponse], - ) -> StreamResponse: ... - - @overload - def post( - self, - path: str, - *, - cast_type: type[ResponseT], - body: Body | None = None, - options: UserRequestInput = {}, - files: RequestFiles | None = None, - stream: bool, - stream_cls: type[StreamResponse] | None = None, - ) -> ResponseT | StreamResponse: ... 
- - def post( - self, - path: str, - *, - cast_type: type[ResponseT], - body: Body | None = None, - options: UserRequestInput = {}, - files: RequestFiles | None = None, - stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, - ) -> ResponseT | StreamResponse: - opts = FinalRequestOptions.construct( - method="post", url=path, json_data=body, files=to_httpx_files(files), **options - ) - - return cast(ResponseT, self.request(cast_type, opts, stream=stream, stream_cls=stream_cls)) - - def patch( - self, - path: str, - *, - cast_type: type[ResponseT], - body: Body | None = None, - options: UserRequestInput = {}, - ) -> ResponseT: - opts = FinalRequestOptions.construct(method="patch", url=path, json_data=body, **options) - - return self.request( - cast_type=cast_type, - options=opts, - ) - - def put( - self, - path: str, - *, - cast_type: type[ResponseT], - body: Body | None = None, - options: UserRequestInput = {}, - files: RequestFiles | None = None, - ) -> ResponseT | StreamResponse: - opts = FinalRequestOptions.construct( - method="put", url=path, json_data=body, files=to_httpx_files(files), **options - ) - - return self.request( - cast_type=cast_type, - options=opts, - ) - - def delete( - self, - path: str, - *, - cast_type: type[ResponseT], - body: Body | None = None, - options: UserRequestInput = {}, - ) -> ResponseT | StreamResponse: - opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, **options) - - return self.request( - cast_type=cast_type, - options=opts, - ) - - def get_api_list( - self, - path: str, - *, - model: type[object], - page: type[SyncPageT], - body: Body | None = None, - options: UserRequestInput = {}, - method: str = "get", - ) -> SyncPageT: - opts = FinalRequestOptions.construct(method=method, url=path, json_data=body, **options) - return self._request_api_list(model, page, opts) - - def _make_status_error(self, response) -> APIStatusError: - response_text = response.text.strip() - status_code 
= response.status_code - error_msg = f"Error code: {status_code}, with error text {response_text}" - - if status_code == 400: - return _errors.APIRequestFailedError(message=error_msg, response=response) - elif status_code == 401: - return _errors.APIAuthenticationError(message=error_msg, response=response) - elif status_code == 429: - return _errors.APIReachLimitError(message=error_msg, response=response) - elif status_code == 500: - return _errors.APIInternalError(message=error_msg, response=response) - elif status_code == 503: - return _errors.APIServerFlowExceedError(message=error_msg, response=response) - return APIStatusError(message=error_msg, response=response) - - -def make_request_options( - *, - query: Query | None = None, - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - post_parser: PostParser | NotGiven = NOT_GIVEN, -) -> UserRequestInput: - """Create a dict of type RequestOptions without keys of NotGiven values.""" - options: UserRequestInput = {} - if extra_headers is not None: - options["headers"] = extra_headers - - if extra_body is not None: - options["extra_json"] = cast(AnyMapping, extra_body) - - if query is not None: - options["params"] = query - - if extra_query is not None: - options["params"] = {**options.get("params", {}), **extra_query} - - if not isinstance(timeout, NotGiven): - options["timeout"] = timeout - - if is_given(post_parser): - # internal - options["post_parser"] = post_parser # type: ignore - - return options - - -def _merge_mappings( - obj1: Mapping[_T_co, Union[_T, Omit]], - obj2: Mapping[_T_co, Union[_T, Omit]], -) -> dict[_T_co, _T]: - """Merge two mappings of the same type, removing any values that are instances of `Omit`. - - In cases with duplicate keys the second mapping takes precedence. 
- """ - merged = {**obj1, **obj2} - return {key: value for key, value in merged.items() if not isinstance(value, Omit)} diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py deleted file mode 100644 index 21f158a5f45251..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py +++ /dev/null @@ -1,31 +0,0 @@ -import time - -import cachetools.func -import jwt - -# 缓存时间 3分钟 -CACHE_TTL_SECONDS = 3 * 60 - -# token 有效期比缓存时间 多30秒 -API_TOKEN_TTL_SECONDS = CACHE_TTL_SECONDS + 30 - - -@cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS) -def generate_token(apikey: str): - try: - api_key, secret = apikey.split(".") - except Exception as e: - raise Exception("invalid api_key", e) - - payload = { - "api_key": api_key, - "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000, - "timestamp": int(round(time.time() * 1000)), - } - ret = jwt.encode( - payload, - secret, - algorithm="HS256", - headers={"alg": "HS256", "sign_type": "SIGN"}, - ) - return ret diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_legacy_binary_response.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_legacy_binary_response.py deleted file mode 100644 index 51623bd860951f..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_legacy_binary_response.py +++ /dev/null @@ -1,207 +0,0 @@ -from __future__ import annotations - -import os -from collections.abc import AsyncIterator, Iterator -from typing import Any - -import httpx - - -class HttpxResponseContent: - @property - def content(self) -> bytes: - raise NotImplementedError("This method is not implemented for this class.") - - @property - def text(self) -> str: - raise NotImplementedError("This method is not implemented for this class.") - - @property - def encoding(self) -> str | None: - raise 
NotImplementedError("This method is not implemented for this class.") - - @property - def charset_encoding(self) -> str | None: - raise NotImplementedError("This method is not implemented for this class.") - - def json(self, **kwargs: Any) -> Any: - raise NotImplementedError("This method is not implemented for this class.") - - def read(self) -> bytes: - raise NotImplementedError("This method is not implemented for this class.") - - def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]: - raise NotImplementedError("This method is not implemented for this class.") - - def iter_text(self, chunk_size: int | None = None) -> Iterator[str]: - raise NotImplementedError("This method is not implemented for this class.") - - def iter_lines(self) -> Iterator[str]: - raise NotImplementedError("This method is not implemented for this class.") - - def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]: - raise NotImplementedError("This method is not implemented for this class.") - - def write_to_file( - self, - file: str | os.PathLike[str], - ) -> None: - raise NotImplementedError("This method is not implemented for this class.") - - def stream_to_file( - self, - file: str | os.PathLike[str], - *, - chunk_size: int | None = None, - ) -> None: - raise NotImplementedError("This method is not implemented for this class.") - - def close(self) -> None: - raise NotImplementedError("This method is not implemented for this class.") - - async def aread(self) -> bytes: - raise NotImplementedError("This method is not implemented for this class.") - - async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]: - raise NotImplementedError("This method is not implemented for this class.") - - async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]: - raise NotImplementedError("This method is not implemented for this class.") - - async def aiter_lines(self) -> AsyncIterator[str]: - raise NotImplementedError("This 
method is not implemented for this class.") - - async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]: - raise NotImplementedError("This method is not implemented for this class.") - - async def astream_to_file( - self, - file: str | os.PathLike[str], - *, - chunk_size: int | None = None, - ) -> None: - raise NotImplementedError("This method is not implemented for this class.") - - async def aclose(self) -> None: - raise NotImplementedError("This method is not implemented for this class.") - - -class HttpxBinaryResponseContent(HttpxResponseContent): - response: httpx.Response - - def __init__(self, response: httpx.Response) -> None: - self.response = response - - @property - def content(self) -> bytes: - return self.response.content - - @property - def encoding(self) -> str | None: - return self.response.encoding - - @property - def charset_encoding(self) -> str | None: - return self.response.charset_encoding - - def read(self) -> bytes: - return self.response.read() - - def text(self) -> str: - raise NotImplementedError("Not implemented for binary response content") - - def json(self, **kwargs: Any) -> Any: - raise NotImplementedError("Not implemented for binary response content") - - def iter_text(self, chunk_size: int | None = None) -> Iterator[str]: - raise NotImplementedError("Not implemented for binary response content") - - def iter_lines(self) -> Iterator[str]: - raise NotImplementedError("Not implemented for binary response content") - - async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]: - raise NotImplementedError("Not implemented for binary response content") - - async def aiter_lines(self) -> AsyncIterator[str]: - raise NotImplementedError("Not implemented for binary response content") - - def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]: - return self.response.iter_bytes(chunk_size) - - def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]: - return 
self.response.iter_raw(chunk_size) - - def write_to_file( - self, - file: str | os.PathLike[str], - ) -> None: - """Write the output to the given file. - - Accepts a filename or any path-like object, e.g. pathlib.Path - - Note: if you want to stream the data to the file instead of writing - all at once then you should use `.with_streaming_response` when making - the API request, e.g. `client.with_streaming_response.foo().stream_to_file('my_filename.txt')` - """ - with open(file, mode="wb") as f: - for data in self.response.iter_bytes(): - f.write(data) - - def stream_to_file( - self, - file: str | os.PathLike[str], - *, - chunk_size: int | None = None, - ) -> None: - with open(file, mode="wb") as f: - for data in self.response.iter_bytes(chunk_size): - f.write(data) - - def close(self) -> None: - return self.response.close() - - async def aread(self) -> bytes: - return await self.response.aread() - - async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]: - return self.response.aiter_bytes(chunk_size) - - async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]: - return self.response.aiter_raw(chunk_size) - - async def astream_to_file( - self, - file: str | os.PathLike[str], - *, - chunk_size: int | None = None, - ) -> None: - path = anyio.Path(file) - async with await path.open(mode="wb") as f: - async for data in self.response.aiter_bytes(chunk_size): - await f.write(data) - - async def aclose(self) -> None: - return await self.response.aclose() - - -class HttpxTextBinaryResponseContent(HttpxBinaryResponseContent): - response: httpx.Response - - @property - def text(self) -> str: - return self.response.text - - def json(self, **kwargs: Any) -> Any: - return self.response.json(**kwargs) - - def iter_text(self, chunk_size: int | None = None) -> Iterator[str]: - return self.response.iter_text(chunk_size) - - def iter_lines(self) -> Iterator[str]: - return self.response.iter_lines() - - async def aiter_text(self, 
chunk_size: int | None = None) -> AsyncIterator[str]: - return self.response.aiter_text(chunk_size) - - async def aiter_lines(self) -> AsyncIterator[str]: - return self.response.aiter_lines() diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_legacy_response.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_legacy_response.py deleted file mode 100644 index 51bf21bcdc17a8..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_legacy_response.py +++ /dev/null @@ -1,341 +0,0 @@ -from __future__ import annotations - -import datetime -import functools -import inspect -import logging -from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, get_origin, overload - -import httpx -import pydantic -from typing_extensions import ParamSpec, override - -from ._base_models import BaseModel, is_basemodel -from ._base_type import NoneType -from ._constants import RAW_RESPONSE_HEADER -from ._errors import APIResponseValidationError -from ._legacy_binary_response import HttpxResponseContent, HttpxTextBinaryResponseContent -from ._sse_client import StreamResponse, extract_stream_chunk_type, is_stream_class_type -from ._utils import extract_type_arg, is_annotated_type, is_given - -if TYPE_CHECKING: - from ._http_client import HttpClient - from ._request_opt import FinalRequestOptions - -P = ParamSpec("P") -R = TypeVar("R") -_T = TypeVar("_T") - -log: logging.Logger = logging.getLogger(__name__) - - -class LegacyAPIResponse(Generic[R]): - """This is a legacy class as it will be replaced by `APIResponse` - and `AsyncAPIResponse` in the `_response.py` file in the next major - release. - - For the sync client this will mostly be the same with the exception - of `content` & `text` will be methods instead of properties. In the - async client, all methods will be async. - - A migration script will be provided & the migration in general should - be smooth. 
- """ - - _cast_type: type[R] - _client: HttpClient - _parsed_by_type: dict[type[Any], Any] - _stream: bool - _stream_cls: type[StreamResponse[Any]] | None - _options: FinalRequestOptions - - http_response: httpx.Response - - def __init__( - self, - *, - raw: httpx.Response, - cast_type: type[R], - client: HttpClient, - stream: bool, - stream_cls: type[StreamResponse[Any]] | None, - options: FinalRequestOptions, - ) -> None: - self._cast_type = cast_type - self._client = client - self._parsed_by_type = {} - self._stream = stream - self._stream_cls = stream_cls - self._options = options - self.http_response = raw - - @property - def request_id(self) -> str | None: - return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return] - - @overload - def parse(self, *, to: type[_T]) -> _T: ... - - @overload - def parse(self) -> R: ... - - def parse(self, *, to: type[_T] | None = None) -> R | _T: - """Returns the rich python representation of this response's data. - - NOTE: For the async client: this will become a coroutine in the next major version. - - For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. - - You can customize the type that the response is parsed into through - the `to` argument, e.g. 
- - ```py - from zhipuai import BaseModel - - - class MyModel(BaseModel): - foo: str - - - obj = response.parse(to=MyModel) - print(obj.foo) - ``` - - We support parsing: - - `BaseModel` - - `dict` - - `list` - - `Union` - - `str` - - `int` - - `float` - - `httpx.Response` - """ - cache_key = to if to is not None else self._cast_type - cached = self._parsed_by_type.get(cache_key) - if cached is not None: - return cached # type: ignore[no-any-return] - - parsed = self._parse(to=to) - if is_given(self._options.post_parser): - parsed = self._options.post_parser(parsed) - - self._parsed_by_type[cache_key] = parsed - return parsed - - @property - def headers(self) -> httpx.Headers: - return self.http_response.headers - - @property - def http_request(self) -> httpx.Request: - return self.http_response.request - - @property - def status_code(self) -> int: - return self.http_response.status_code - - @property - def url(self) -> httpx.URL: - return self.http_response.url - - @property - def method(self) -> str: - return self.http_request.method - - @property - def content(self) -> bytes: - """Return the binary response content. - - NOTE: this will be removed in favour of `.read()` in the - next major version. - """ - return self.http_response.content - - @property - def text(self) -> str: - """Return the decoded response content. - - NOTE: this will be turned into a method in the next major version. 
- """ - return self.http_response.text - - @property - def http_version(self) -> str: - return self.http_response.http_version - - @property - def is_closed(self) -> bool: - return self.http_response.is_closed - - @property - def elapsed(self) -> datetime.timedelta: - """The time taken for the complete request/response cycle to complete.""" - return self.http_response.elapsed - - def _parse(self, *, to: type[_T] | None = None) -> R | _T: - # unwrap `Annotated[T, ...]` -> `T` - if to and is_annotated_type(to): - to = extract_type_arg(to, 0) - - if self._stream: - if to: - if not is_stream_class_type(to): - raise TypeError(f"Expected custom parse type to be a subclass of {StreamResponse}") - - return cast( - _T, - to( - cast_type=extract_stream_chunk_type( - to, - failure_message="Expected custom stream type to be passed with a type argument, e.g. StreamResponse[ChunkType]", # noqa: E501 - ), - response=self.http_response, - client=cast(Any, self._client), - ), - ) - - if self._stream_cls: - return cast( - R, - self._stream_cls( - cast_type=extract_stream_chunk_type(self._stream_cls), - response=self.http_response, - client=cast(Any, self._client), - ), - ) - - stream_cls = cast("type[StreamResponse[Any]] | None", self._client._default_stream_cls) - if stream_cls is None: - raise MissingStreamClassError() - - return cast( - R, - stream_cls( - cast_type=self._cast_type, - response=self.http_response, - client=cast(Any, self._client), - ), - ) - - cast_type = to if to is not None else self._cast_type - - # unwrap `Annotated[T, ...]` -> `T` - if is_annotated_type(cast_type): - cast_type = extract_type_arg(cast_type, 0) - - if cast_type is NoneType: - return cast(R, None) - - response = self.http_response - if cast_type == str: - return cast(R, response.text) - - if cast_type == int: - return cast(R, int(response.text)) - - if cast_type == float: - return cast(R, float(response.text)) - - origin = get_origin(cast_type) or cast_type - - if inspect.isclass(origin) and 
issubclass(origin, HttpxResponseContent): - # in the response, e.g. mime file - *_, filename = response.headers.get("content-disposition", "").split("filename=") - # 判断文件类型是jsonl类型的使用HttpxTextBinaryResponseContent - if filename and filename.endswith(".jsonl") or filename and filename.endswith(".xlsx"): - return cast(R, HttpxTextBinaryResponseContent(response)) - else: - return cast(R, cast_type(response)) # type: ignore - - if origin == LegacyAPIResponse: - raise RuntimeError("Unexpected state - cast_type is `APIResponse`") - - if inspect.isclass(origin) and issubclass(origin, httpx.Response): - # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response - # and pass that class to our request functions. We cannot change the variance to be either - # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct - # the response class ourselves but that is something that should be supported directly in httpx - # as it would be easy to incorrectly construct the Response object due to the multitude of arguments. - if cast_type != httpx.Response: - raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_type`") - return cast(R, response) - - if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel): - raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`") - - if ( - cast_type is not object - and origin is not list - and origin is not dict - and origin is not Union - and not issubclass(origin, BaseModel) - ): - raise RuntimeError( - f"Unsupported type, expected {cast_type} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}." # noqa: E501 - ) - - # split is required to handle cases where additional information is included - # in the response, e.g. 
application/json; charset=utf-8 - content_type, *_ = response.headers.get("content-type", "*").split(";") - if content_type != "application/json": - if is_basemodel(cast_type): - try: - data = response.json() - except Exception as exc: - log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc) - else: - return self._client._process_response_data( - data=data, - cast_type=cast_type, # type: ignore - response=response, - ) - - if self._client._strict_response_validation: - raise APIResponseValidationError( - response=response, - message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.", # noqa: E501 - json_data=response.text, - ) - - # If the API responds with content that isn't JSON then we just return - # the (decoded) text without performing any parsing so that you can still - # handle the response however you need to. - return response.text # type: ignore - - data = response.json() - - return self._client._process_response_data( - data=data, - cast_type=cast_type, # type: ignore - response=response, - ) - - @override - def __repr__(self) -> str: - return f"" - - -class MissingStreamClassError(TypeError): - def __init__(self) -> None: - super().__init__( - "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference", # noqa: E501 - ) - - -def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]: - """Higher order function that takes one of our bound API methods and wraps it - to support returning the raw `APIResponse` object directly. 
- """ - - @functools.wraps(func) - def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]: - extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})} - extra_headers[RAW_RESPONSE_HEADER] = "true" - - kwargs["extra_headers"] = extra_headers - - return cast(LegacyAPIResponse[R], func(*args, **kwargs)) - - return wrapped diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py deleted file mode 100644 index c3b894b3a3d88f..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py +++ /dev/null @@ -1,97 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable -from typing import TYPE_CHECKING, Any, ClassVar, Union, cast - -import pydantic.generics -from httpx import Timeout -from typing_extensions import Required, TypedDict, Unpack, final - -from ._base_compat import PYDANTIC_V2, ConfigDict -from ._base_type import AnyMapping, Body, Headers, HttpxRequestFiles, NotGiven, Query -from ._constants import RAW_RESPONSE_HEADER -from ._utils import is_given, strip_not_given - - -class UserRequestInput(TypedDict, total=False): - headers: Headers - max_retries: int - timeout: float | Timeout | None - params: Query - extra_json: AnyMapping - - -class FinalRequestOptionsInput(TypedDict, total=False): - method: Required[str] - url: Required[str] - params: Query - headers: Headers - max_retries: int - timeout: float | Timeout | None - files: HttpxRequestFiles | None - json_data: Body - extra_json: AnyMapping - - -@final -class FinalRequestOptions(pydantic.BaseModel): - method: str - url: str - params: Query = {} - headers: Union[Headers, NotGiven] = NotGiven() - max_retries: Union[int, NotGiven] = NotGiven() - timeout: Union[float, Timeout, None, NotGiven] = NotGiven() - files: Union[HttpxRequestFiles, None] = None - idempotency_key: Union[str, None] = None 
- post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven() - - # It should be noted that we cannot use `json` here as that would override - # a BaseModel method in an incompatible fashion. - json_data: Union[Body, None] = None - extra_json: Union[AnyMapping, None] = None - - if PYDANTIC_V2: - model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) - else: - - class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] - arbitrary_types_allowed: bool = True - - def get_max_retries(self, max_retries: int) -> int: - if isinstance(self.max_retries, NotGiven): - return max_retries - return self.max_retries - - def _strip_raw_response_header(self) -> None: - if not is_given(self.headers): - return - - if self.headers.get(RAW_RESPONSE_HEADER): - self.headers = {**self.headers} - self.headers.pop(RAW_RESPONSE_HEADER) - - # override the `construct` method so that we can run custom transformations. - # this is necessary as we don't want to do any actual runtime type checking - # (which means we can't use validators) but we do want to ensure that `NotGiven` - # values are not present - # - # type ignore required because we're adding explicit types to `**values` - @classmethod - def construct( # type: ignore - cls, - _fields_set: set[str] | None = None, - **values: Unpack[UserRequestInput], - ) -> FinalRequestOptions: - kwargs: dict[str, Any] = { - # we unconditionally call `strip_not_given` on any value - # as it will just ignore any non-mapping types - key: strip_not_given(value) - for key, value in values.items() - } - if PYDANTIC_V2: - return super().model_construct(_fields_set, **kwargs) - return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated] - - if not TYPE_CHECKING: - # type checkers incorrectly complain about this assignment - model_construct = construct diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py 
b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py deleted file mode 100644 index 92e601805569f3..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py +++ /dev/null @@ -1,398 +0,0 @@ -from __future__ import annotations - -import datetime -import inspect -import logging -from collections.abc import Iterator -from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, get_origin, overload - -import httpx -import pydantic -from typing_extensions import ParamSpec, override - -from ._base_models import BaseModel, is_basemodel -from ._base_type import NoneType -from ._errors import APIResponseValidationError, ZhipuAIError -from ._sse_client import StreamResponse, extract_stream_chunk_type, is_stream_class_type -from ._utils import extract_type_arg, extract_type_var_from_base, is_annotated_type, is_given - -if TYPE_CHECKING: - from ._http_client import HttpClient - from ._request_opt import FinalRequestOptions - -P = ParamSpec("P") -R = TypeVar("R") -_T = TypeVar("_T") -_APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]") -log: logging.Logger = logging.getLogger(__name__) - - -class BaseAPIResponse(Generic[R]): - _cast_type: type[R] - _client: HttpClient - _parsed_by_type: dict[type[Any], Any] - _is_sse_stream: bool - _stream_cls: type[StreamResponse[Any]] - _options: FinalRequestOptions - http_response: httpx.Response - - def __init__( - self, - *, - raw: httpx.Response, - cast_type: type[R], - client: HttpClient, - stream: bool, - stream_cls: type[StreamResponse[Any]] | None = None, - options: FinalRequestOptions, - ) -> None: - self._cast_type = cast_type - self._client = client - self._parsed_by_type = {} - self._is_sse_stream = stream - self._stream_cls = stream_cls - self._options = options - self.http_response = raw - - def _parse(self, *, to: type[_T] | None = None) -> R | _T: - # unwrap `Annotated[T, ...]` -> `T` - if to and is_annotated_type(to): - to = 
extract_type_arg(to, 0) - - if self._is_sse_stream: - if to: - if not is_stream_class_type(to): - raise TypeError(f"Expected custom parse type to be a subclass of {StreamResponse}") - - return cast( - _T, - to( - cast_type=extract_stream_chunk_type( - to, - failure_message="Expected custom stream type to be passed with a type argument, e.g. StreamResponse[ChunkType]", # noqa: E501 - ), - response=self.http_response, - client=cast(Any, self._client), - ), - ) - - if self._stream_cls: - return cast( - R, - self._stream_cls( - cast_type=extract_stream_chunk_type(self._stream_cls), - response=self.http_response, - client=cast(Any, self._client), - ), - ) - - stream_cls = cast("type[Stream[Any]] | None", self._client._default_stream_cls) - if stream_cls is None: - raise MissingStreamClassError() - - return cast( - R, - stream_cls( - cast_type=self._cast_type, - response=self.http_response, - client=cast(Any, self._client), - ), - ) - - cast_type = to if to is not None else self._cast_type - - # unwrap `Annotated[T, ...]` -> `T` - if is_annotated_type(cast_type): - cast_type = extract_type_arg(cast_type, 0) - - if cast_type is NoneType: - return cast(R, None) - - response = self.http_response - if cast_type == str: - return cast(R, response.text) - - if cast_type == bytes: - return cast(R, response.content) - - if cast_type == int: - return cast(R, int(response.text)) - - if cast_type == float: - return cast(R, float(response.text)) - - origin = get_origin(cast_type) or cast_type - - # handle the legacy binary response case - if inspect.isclass(cast_type) and cast_type.__name__ == "HttpxBinaryResponseContent": - return cast(R, cast_type(response)) # type: ignore - - if origin == APIResponse: - raise RuntimeError("Unexpected state - cast_type is `APIResponse`") - - if inspect.isclass(origin) and issubclass(origin, httpx.Response): - # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response - # and pass that class to our request functions. 
We cannot change the variance to be either - # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct - # the response class ourselves but that is something that should be supported directly in httpx - # as it would be easy to incorrectly construct the Response object due to the multitude of arguments. - if cast_type != httpx.Response: - raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_type`") - return cast(R, response) - - if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel): - raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`") - - if ( - cast_type is not object - and origin is not list - and origin is not dict - and origin is not Union - and not issubclass(origin, BaseModel) - ): - raise RuntimeError( - f"Unsupported type, expected {cast_type} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}." # noqa: E501 - ) - - # split is required to handle cases where additional information is included - # in the response, e.g. 
application/json; charset=utf-8 - content_type, *_ = response.headers.get("content-type", "*").split(";") - if content_type != "application/json": - if is_basemodel(cast_type): - try: - data = response.json() - except Exception as exc: - log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc) - else: - return self._client._process_response_data( - data=data, - cast_type=cast_type, # type: ignore - response=response, - ) - - if self._client._strict_response_validation: - raise APIResponseValidationError( - response=response, - message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.", # noqa: E501 - json_data=response.text, - ) - - # If the API responds with content that isn't JSON then we just return - # the (decoded) text without performing any parsing so that you can still - # handle the response however you need to. - return response.text # type: ignore - - data = response.json() - - return self._client._process_response_data( - data=data, - cast_type=cast_type, # type: ignore - response=response, - ) - - @property - def headers(self) -> httpx.Headers: - return self.http_response.headers - - @property - def http_request(self) -> httpx.Request: - """Returns the httpx Request instance associated with the current response.""" - return self.http_response.request - - @property - def status_code(self) -> int: - return self.http_response.status_code - - @property - def url(self) -> httpx.URL: - """Returns the URL for which the request was made.""" - return self.http_response.url - - @property - def method(self) -> str: - return self.http_request.method - - @property - def http_version(self) -> str: - return self.http_response.http_version - - @property - def elapsed(self) -> datetime.timedelta: - """The time taken for the complete request/response cycle to complete.""" - return self.http_response.elapsed - - @property - def is_closed(self) -> bool: - """Whether or not the response body 
has been closed. - - If this is False then there is response data that has not been read yet. - You must either fully consume the response body or call `.close()` - before discarding the response to prevent resource leaks. - """ - return self.http_response.is_closed - - @override - def __repr__(self) -> str: - return f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_type}>" # noqa: E501 - - -class APIResponse(BaseAPIResponse[R]): - @property - def request_id(self) -> str | None: - return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return] - - @overload - def parse(self, *, to: type[_T]) -> _T: ... - - @overload - def parse(self) -> R: ... - - def parse(self, *, to: type[_T] | None = None) -> R | _T: - """Returns the rich python representation of this response's data. - - For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. - - You can customize the type that the response is parsed into through - the `to` argument, e.g. - - ```py - from openai import BaseModel - - - class MyModel(BaseModel): - foo: str - - - obj = response.parse(to=MyModel) - print(obj.foo) - ``` - - We support parsing: - - `BaseModel` - - `dict` - - `list` - - `Union` - - `str` - - `int` - - `float` - - `httpx.Response` - """ - cache_key = to if to is not None else self._cast_type - cached = self._parsed_by_type.get(cache_key) - if cached is not None: - return cached # type: ignore[no-any-return] - - if not self._is_sse_stream: - self.read() - - parsed = self._parse(to=to) - if is_given(self._options.post_parser): - parsed = self._options.post_parser(parsed) - - self._parsed_by_type[cache_key] = parsed - return parsed - - def read(self) -> bytes: - """Read and return the binary response content.""" - try: - return self.http_response.read() - except httpx.StreamConsumed as exc: - # The default error raised by httpx isn't very - # helpful in our case so we re-raise it with - # a different error message. 
- raise StreamAlreadyConsumed() from exc - - def text(self) -> str: - """Read and decode the response content into a string.""" - self.read() - return self.http_response.text - - def json(self) -> object: - """Read and decode the JSON response content.""" - self.read() - return self.http_response.json() - - def close(self) -> None: - """Close the response and release the connection. - - Automatically called if the response body is read to completion. - """ - self.http_response.close() - - def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]: - """ - A byte-iterator over the decoded response content. - - This automatically handles gzip, deflate and brotli encoded responses. - """ - yield from self.http_response.iter_bytes(chunk_size) - - def iter_text(self, chunk_size: int | None = None) -> Iterator[str]: - """A str-iterator over the decoded response content - that handles both gzip, deflate, etc but also detects the content's - string encoding. - """ - yield from self.http_response.iter_text(chunk_size) - - def iter_lines(self) -> Iterator[str]: - """Like `iter_text()` but will only yield chunks for each line""" - yield from self.http_response.iter_lines() - - -class MissingStreamClassError(TypeError): - def __init__(self) -> None: - super().__init__( - "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference", # noqa: E501 - ) - - -class StreamAlreadyConsumed(ZhipuAIError): # noqa: N818 - """ - Attempted to read or stream content, but the content has already - been streamed. - - This can happen if you use a method like `.iter_lines()` and then attempt - to read th entire response body afterwards, e.g. - - ```py - response = await client.post(...) - async for line in response.iter_lines(): - ... 
# do something with `line` - - content = await response.read() - # ^ error - ``` - - If you want this behavior you'll need to either manually accumulate the response - content or call `await response.read()` before iterating over the stream. - """ - - def __init__(self) -> None: - message = ( - "Attempted to read or stream some content, but the content has " - "already been streamed. " - "This could be due to attempting to stream the response " - "content more than once." - "\n\n" - "You can fix this by manually accumulating the response content while streaming " - "or by calling `.read()` before starting to stream." - ) - super().__init__(message) - - -def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type: - """Given a type like `APIResponse[T]`, returns the generic type variable `T`. - - This also handles the case where a concrete subclass is given, e.g. - ```py - class MyResponse(APIResponse[bytes]): - ... - - extract_response_type(MyResponse) -> bytes - ``` - """ - return extract_type_var_from_base( - typ, - generic_bases=cast("tuple[type, ...]", (BaseAPIResponse, APIResponse)), - index=0, - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py deleted file mode 100644 index cbc449d24421d0..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py +++ /dev/null @@ -1,206 +0,0 @@ -from __future__ import annotations - -import inspect -import json -from collections.abc import Iterator, Mapping -from typing import TYPE_CHECKING, Generic, TypeGuard, cast - -import httpx - -from . 
import get_origin -from ._base_type import ResponseT -from ._errors import APIResponseError -from ._utils import extract_type_var_from_base, is_mapping - -_FIELD_SEPARATOR = ":" - -if TYPE_CHECKING: - from ._http_client import HttpClient - - -class StreamResponse(Generic[ResponseT]): - response: httpx.Response - _cast_type: type[ResponseT] - - def __init__( - self, - *, - cast_type: type[ResponseT], - response: httpx.Response, - client: HttpClient, - ) -> None: - self.response = response - self._cast_type = cast_type - self._data_process_func = client._process_response_data - self._stream_chunks = self.__stream__() - - def __next__(self) -> ResponseT: - return self._stream_chunks.__next__() - - def __iter__(self) -> Iterator[ResponseT]: - yield from self._stream_chunks - - def __stream__(self) -> Iterator[ResponseT]: - sse_line_parser = SSELineParser() - iterator = sse_line_parser.iter_lines(self.response.iter_lines()) - - for sse in iterator: - if sse.data.startswith("[DONE]"): - break - - if sse.event is None: - data = sse.json_data() - if isinstance(data, Mapping) and data.get("error"): - raise APIResponseError( - message="An error occurred during streaming", - request=self.response.request, - json_data=data["error"], - ) - if sse.event is None: - data = sse.json_data() - if is_mapping(data) and data.get("error"): - message = None - error = data.get("error") - if is_mapping(error): - message = error.get("message") - if not message or not isinstance(message, str): - message = "An error occurred during streaming" - - raise APIResponseError( - message=message, - request=self.response.request, - json_data=data["error"], - ) - yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response) - - else: - data = sse.json_data() - - if sse.event == "error" and is_mapping(data) and data.get("error"): - message = None - error = data.get("error") - if is_mapping(error): - message = error.get("message") - if not message or not isinstance(message, 
str): - message = "An error occurred during streaming" - - raise APIResponseError( - message=message, - request=self.response.request, - json_data=data["error"], - ) - yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response) - - for sse in iterator: - pass - - -class Event: - def __init__( - self, event: str | None = None, data: str | None = None, id: str | None = None, retry: int | None = None - ): - self._event = event - self._data = data - self._id = id - self._retry = retry - - def __repr__(self): - data_len = len(self._data) if self._data else 0 - return ( - f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}" - ) - - @property - def event(self): - return self._event - - @property - def data(self): - return self._data - - def json_data(self): - return json.loads(self._data) - - @property - def id(self): - return self._id - - @property - def retry(self): - return self._retry - - -class SSELineParser: - _data: list[str] - _event: str | None - _retry: int | None - _id: str | None - - def __init__(self): - self._event = None - self._data = [] - self._id = None - self._retry = None - - def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]: - for line in lines: - line = line.rstrip("\n") - if not line: - if self._event is None and not self._data and self._id is None and self._retry is None: - continue - sse_event = Event(event=self._event, data="\n".join(self._data), id=self._id, retry=self._retry) - self._event = None - self._data = [] - self._id = None - self._retry = None - - yield sse_event - self.decode_line(line) - - def decode_line(self, line: str): - if line.startswith(":") or not line: - return - - field, _p, value = line.partition(":") - - value = value.removeprefix(" ") - if field == "data": - self._data.append(value) - elif field == "event": - self._event = value - elif field == "retry": - try: - self._retry = int(value) - except (TypeError, ValueError): - 
pass - return - - -def is_stream_class_type(typ: type) -> TypeGuard[type[StreamResponse[object]]]: - """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`""" - origin = get_origin(typ) or typ - return inspect.isclass(origin) and issubclass(origin, StreamResponse) - - -def extract_stream_chunk_type( - stream_cls: type, - *, - failure_message: str | None = None, -) -> type: - """Given a type like `StreamResponse[T]`, returns the generic type variable `T`. - - This also handles the case where a concrete subclass is given, e.g. - ```py - class MyStream(StreamResponse[bytes]): - ... - - extract_stream_chunk_type(MyStream) -> bytes - ``` - """ - - return extract_type_var_from_base( - stream_cls, - index=0, - generic_bases=cast("tuple[type, ...]", (StreamResponse,)), - failure_message=failure_message, - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/__init__.py deleted file mode 100644 index a66b095816b8b0..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/__init__.py +++ /dev/null @@ -1,52 +0,0 @@ -from ._utils import ( # noqa: I001 - remove_notgiven_indict as remove_notgiven_indict, # noqa: PLC0414 - flatten as flatten, # noqa: PLC0414 - is_dict as is_dict, # noqa: PLC0414 - is_list as is_list, # noqa: PLC0414 - is_given as is_given, # noqa: PLC0414 - is_tuple as is_tuple, # noqa: PLC0414 - is_mapping as is_mapping, # noqa: PLC0414 - is_tuple_t as is_tuple_t, # noqa: PLC0414 - parse_date as parse_date, # noqa: PLC0414 - is_iterable as is_iterable, # noqa: PLC0414 - is_sequence as is_sequence, # noqa: PLC0414 - coerce_float as coerce_float, # noqa: PLC0414 - is_mapping_t as is_mapping_t, # noqa: PLC0414 - removeprefix as removeprefix, # noqa: PLC0414 - removesuffix as removesuffix, # noqa: PLC0414 - extract_files as extract_files, # noqa: PLC0414 - is_sequence_t 
as is_sequence_t, # noqa: PLC0414 - required_args as required_args, # noqa: PLC0414 - coerce_boolean as coerce_boolean, # noqa: PLC0414 - coerce_integer as coerce_integer, # noqa: PLC0414 - file_from_path as file_from_path, # noqa: PLC0414 - parse_datetime as parse_datetime, # noqa: PLC0414 - strip_not_given as strip_not_given, # noqa: PLC0414 - deepcopy_minimal as deepcopy_minimal, # noqa: PLC0414 - get_async_library as get_async_library, # noqa: PLC0414 - maybe_coerce_float as maybe_coerce_float, # noqa: PLC0414 - get_required_header as get_required_header, # noqa: PLC0414 - maybe_coerce_boolean as maybe_coerce_boolean, # noqa: PLC0414 - maybe_coerce_integer as maybe_coerce_integer, # noqa: PLC0414 - drop_prefix_image_data as drop_prefix_image_data, # noqa: PLC0414 -) - - -from ._typing import ( - is_list_type as is_list_type, # noqa: PLC0414 - is_union_type as is_union_type, # noqa: PLC0414 - extract_type_arg as extract_type_arg, # noqa: PLC0414 - is_iterable_type as is_iterable_type, # noqa: PLC0414 - is_required_type as is_required_type, # noqa: PLC0414 - is_annotated_type as is_annotated_type, # noqa: PLC0414 - strip_annotated_type as strip_annotated_type, # noqa: PLC0414 - extract_type_var_from_base as extract_type_var_from_base, # noqa: PLC0414 -) - -from ._transform import ( - PropertyInfo as PropertyInfo, # noqa: PLC0414 - transform as transform, # noqa: PLC0414 - async_transform as async_transform, # noqa: PLC0414 - maybe_transform as maybe_transform, # noqa: PLC0414 - async_maybe_transform as async_maybe_transform, # noqa: PLC0414 -) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_transform.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_transform.py deleted file mode 100644 index e8ef1f79358a96..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_transform.py +++ /dev/null @@ -1,383 +0,0 @@ -from __future__ import annotations - -import base64 
-import io -import pathlib -from collections.abc import Mapping -from datetime import date, datetime -from typing import Any, Literal, TypeVar, cast, get_args, get_type_hints - -import anyio -import pydantic -from typing_extensions import override - -from .._base_compat import is_typeddict, model_dump -from .._files import is_base64_file_input -from ._typing import ( - extract_type_arg, - is_annotated_type, - is_iterable_type, - is_list_type, - is_required_type, - is_union_type, - strip_annotated_type, -) -from ._utils import ( - is_iterable, - is_list, - is_mapping, -) - -_T = TypeVar("_T") - - -# TODO: support for drilling globals() and locals() -# TODO: ensure works correctly with forward references in all cases - - -PropertyFormat = Literal["iso8601", "base64", "custom"] - - -class PropertyInfo: - """Metadata class to be used in Annotated types to provide information about a given type. - - For example: - - class MyParams(TypedDict): - account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')] - - This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API. - """ # noqa: E501 - - alias: str | None - format: PropertyFormat | None - format_template: str | None - discriminator: str | None - - def __init__( - self, - *, - alias: str | None = None, - format: PropertyFormat | None = None, - format_template: str | None = None, - discriminator: str | None = None, - ) -> None: - self.alias = alias - self.format = format - self.format_template = format_template - self.discriminator = discriminator - - @override - def __repr__(self) -> str: - return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')" # noqa: E501 - - -def maybe_transform( - data: object, - expected_type: object, -) -> Any | None: - """Wrapper over `transform()` that allows `None` to be passed. 
- - See `transform()` for more details. - """ - if data is None: - return None - return transform(data, expected_type) - - -# Wrapper over _transform_recursive providing fake types -def transform( - data: _T, - expected_type: object, -) -> _T: - """Transform dictionaries based off of type information from the given type, for example: - - ```py - class Params(TypedDict, total=False): - card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]] - - - transformed = transform({"card_id": ""}, Params) - # {'cardID': ''} - ``` - - Any keys / data that does not have type information given will be included as is. - - It should be noted that the transformations that this function does are not represented in the type system. - """ - transformed = _transform_recursive(data, annotation=cast(type, expected_type)) - return cast(_T, transformed) - - -def _get_annotated_type(type_: type) -> type | None: - """If the given type is an `Annotated` type then it is returned, if not `None` is returned. - - This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]` - """ - if is_required_type(type_): - # Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]` - type_ = get_args(type_)[0] - - if is_annotated_type(type_): - return type_ - - return None - - -def _maybe_transform_key(key: str, type_: type) -> str: - """Transform the given `data` based on the annotations provided in `type_`. - - Note: this function only looks at `Annotated` types that contain `PropertInfo` metadata. 
- """ - annotated_type = _get_annotated_type(type_) - if annotated_type is None: - # no `Annotated` definition for this type, no transformation needed - return key - - # ignore the first argument as it is the actual type - annotations = get_args(annotated_type)[1:] - for annotation in annotations: - if isinstance(annotation, PropertyInfo) and annotation.alias is not None: - return annotation.alias - - return key - - -def _transform_recursive( - data: object, - *, - annotation: type, - inner_type: type | None = None, -) -> object: - """Transform the given data against the expected type. - - Args: - annotation: The direct type annotation given to the particular piece of data. - This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc - - inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type - is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in - the list can be transformed using the metadata from the container type. - - Defaults to the same value as the `annotation` argument. - """ - if inner_type is None: - inner_type = annotation - - stripped_type = strip_annotated_type(inner_type) - if is_typeddict(stripped_type) and is_mapping(data): - return _transform_typeddict(data, stripped_type) - - if ( - # List[T] - (is_list_type(stripped_type) and is_list(data)) - # Iterable[T] - or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) - ): - inner_type = extract_type_arg(stripped_type, 0) - return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] - - if is_union_type(stripped_type): - # For union types we run the transformation against all subtypes to ensure that everything is transformed. - # - # TODO: there may be edge cases where the same normalized field name will transform to two different names - # in different subtypes. 
- for subtype in get_args(stripped_type): - data = _transform_recursive(data, annotation=annotation, inner_type=subtype) - return data - - if isinstance(data, pydantic.BaseModel): - return model_dump(data, exclude_unset=True) - - annotated_type = _get_annotated_type(annotation) - if annotated_type is None: - return data - - # ignore the first argument as it is the actual type - annotations = get_args(annotated_type)[1:] - for annotation in annotations: - if isinstance(annotation, PropertyInfo) and annotation.format is not None: - return _format_data(data, annotation.format, annotation.format_template) - - return data - - -def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object: - if isinstance(data, date | datetime): - if format_ == "iso8601": - return data.isoformat() - - if format_ == "custom" and format_template is not None: - return data.strftime(format_template) - - if format_ == "base64" and is_base64_file_input(data): - binary: str | bytes | None = None - - if isinstance(data, pathlib.Path): - binary = data.read_bytes() - elif isinstance(data, io.IOBase): - binary = data.read() - - if isinstance(binary, str): # type: ignore[unreachable] - binary = binary.encode() - - if not isinstance(binary, bytes): - raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") - - return base64.b64encode(binary).decode("ascii") - - return data - - -def _transform_typeddict( - data: Mapping[str, object], - expected_type: type, -) -> Mapping[str, object]: - result: dict[str, object] = {} - annotations = get_type_hints(expected_type, include_extras=True) - for key, value in data.items(): - type_ = annotations.get(key) - if type_ is None: - # we do not have a type annotation for this field, leave it as is - result[key] = value - else: - result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_) - return result - - -async def async_maybe_transform( - data: object, - expected_type: object, 
-) -> Any | None: - """Wrapper over `async_transform()` that allows `None` to be passed. - - See `async_transform()` for more details. - """ - if data is None: - return None - return await async_transform(data, expected_type) - - -async def async_transform( - data: _T, - expected_type: object, -) -> _T: - """Transform dictionaries based off of type information from the given type, for example: - - ```py - class Params(TypedDict, total=False): - card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]] - - - transformed = transform({"card_id": ""}, Params) - # {'cardID': ''} - ``` - - Any keys / data that does not have type information given will be included as is. - - It should be noted that the transformations that this function does are not represented in the type system. - """ - transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type)) - return cast(_T, transformed) - - -async def _async_transform_recursive( - data: object, - *, - annotation: type, - inner_type: type | None = None, -) -> object: - """Transform the given data against the expected type. - - Args: - annotation: The direct type annotation given to the particular piece of data. - This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc - - inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type - is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in - the list can be transformed using the metadata from the container type. - - Defaults to the same value as the `annotation` argument. 
- """ - if inner_type is None: - inner_type = annotation - - stripped_type = strip_annotated_type(inner_type) - if is_typeddict(stripped_type) and is_mapping(data): - return await _async_transform_typeddict(data, stripped_type) - - if ( - # List[T] - (is_list_type(stripped_type) and is_list(data)) - # Iterable[T] - or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) - ): - inner_type = extract_type_arg(stripped_type, 0) - return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] - - if is_union_type(stripped_type): - # For union types we run the transformation against all subtypes to ensure that everything is transformed. - # - # TODO: there may be edge cases where the same normalized field name will transform to two different names - # in different subtypes. - for subtype in get_args(stripped_type): - data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype) - return data - - if isinstance(data, pydantic.BaseModel): - return model_dump(data, exclude_unset=True) - - annotated_type = _get_annotated_type(annotation) - if annotated_type is None: - return data - - # ignore the first argument as it is the actual type - annotations = get_args(annotated_type)[1:] - for annotation in annotations: - if isinstance(annotation, PropertyInfo) and annotation.format is not None: - return await _async_format_data(data, annotation.format, annotation.format_template) - - return data - - -async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object: - if isinstance(data, date | datetime): - if format_ == "iso8601": - return data.isoformat() - - if format_ == "custom" and format_template is not None: - return data.strftime(format_template) - - if format_ == "base64" and is_base64_file_input(data): - binary: str | bytes | None = None - - if isinstance(data, pathlib.Path): - binary = await anyio.Path(data).read_bytes() - elif 
isinstance(data, io.IOBase): - binary = data.read() - - if isinstance(binary, str): # type: ignore[unreachable] - binary = binary.encode() - - if not isinstance(binary, bytes): - raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") - - return base64.b64encode(binary).decode("ascii") - - return data - - -async def _async_transform_typeddict( - data: Mapping[str, object], - expected_type: type, -) -> Mapping[str, object]: - result: dict[str, object] = {} - annotations = get_type_hints(expected_type, include_extras=True) - for key, value in data.items(): - type_ = annotations.get(key) - if type_ is None: - # we do not have a type annotation for this field, leave it as is - result[key] = value - else: - result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_) - return result diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_typing.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_typing.py deleted file mode 100644 index c7c54dcc37458d..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_typing.py +++ /dev/null @@ -1,122 +0,0 @@ -from __future__ import annotations - -from collections import abc as _c_abc -from collections.abc import Iterable -from typing import Annotated, Any, TypeVar, cast, get_args, get_origin - -from typing_extensions import Required - -from .._base_compat import is_union as _is_union -from .._base_type import InheritsGeneric - - -def is_annotated_type(typ: type) -> bool: - return get_origin(typ) == Annotated - - -def is_list_type(typ: type) -> bool: - return (get_origin(typ) or typ) == list - - -def is_iterable_type(typ: type) -> bool: - """If the given type is `typing.Iterable[T]`""" - origin = get_origin(typ) or typ - return origin in {Iterable, _c_abc.Iterable} - - -def is_union_type(typ: type) -> bool: - return _is_union(get_origin(typ)) - - -def is_required_type(typ: type) 
-> bool: - return get_origin(typ) == Required - - -def is_typevar(typ: type) -> bool: - # type ignore is required because type checkers - # think this expression will always return False - return type(typ) == TypeVar # type: ignore - - -# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]] -def strip_annotated_type(typ: type) -> type: - if is_required_type(typ) or is_annotated_type(typ): - return strip_annotated_type(cast(type, get_args(typ)[0])) - - return typ - - -def extract_type_arg(typ: type, index: int) -> type: - args = get_args(typ) - try: - return cast(type, args[index]) - except IndexError as err: - raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err - - -def extract_type_var_from_base( - typ: type, - *, - generic_bases: tuple[type, ...], - index: int, - failure_message: str | None = None, -) -> type: - """Given a type like `Foo[T]`, returns the generic type variable `T`. - - This also handles the case where a concrete subclass is given, e.g. - ```py - class MyResponse(Foo[bytes]): - ... - - extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes - ``` - - And where a generic subclass is given: - ```py - _T = TypeVar('_T') - class MyResponse(Foo[_T]): - ... - - extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes - ``` - """ - cls = cast(object, get_origin(typ) or typ) - if cls in generic_bases: - # we're given the class directly - return extract_type_arg(typ, index) - - # if a subclass is given - # --- - # this is needed as __orig_bases__ is not present in the typeshed stubs - # because it is intended to be for internal use only, however there does - # not seem to be a way to resolve generic TypeVars for inherited subclasses - # without using it. 
- if isinstance(cls, InheritsGeneric): - target_base_class: Any | None = None - for base in cls.__orig_bases__: - if base.__origin__ in generic_bases: - target_base_class = base - break - - if target_base_class is None: - raise RuntimeError( - "Could not find the generic base class;\n" - "This should never happen;\n" - f"Does {cls} inherit from one of {generic_bases} ?" - ) - - extracted = extract_type_arg(target_base_class, index) - if is_typevar(extracted): - # If the extracted type argument is itself a type variable - # then that means the subclass itself is generic, so we have - # to resolve the type argument from the class itself, not - # the base class. - # - # Note: if there is more than 1 type argument, the subclass could - # change the ordering of the type arguments, this is not currently - # supported. - return extract_type_arg(typ, index) - - return extracted - - raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}") diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_utils.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_utils.py deleted file mode 100644 index 3a7b234ab0c067..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_utils.py +++ /dev/null @@ -1,409 +0,0 @@ -from __future__ import annotations - -import functools -import inspect -import os -import re -from collections.abc import Callable, Iterable, Mapping, Sequence -from pathlib import Path -from typing import ( - Any, - TypeGuard, - TypeVar, - Union, - cast, - overload, -) - -import sniffio - -from .._base_compat import parse_date as parse_date # noqa: PLC0414 -from .._base_compat import parse_datetime as parse_datetime # noqa: PLC0414 -from .._base_type import FileTypes, Headers, HeadersLike, NotGiven, NotGivenOr - - -def remove_notgiven_indict(obj): - if obj is None or (not isinstance(obj, Mapping)): - return obj - return {key: value for 
key, value in obj.items() if not isinstance(value, NotGiven)} - - -_T = TypeVar("_T") -_TupleT = TypeVar("_TupleT", bound=tuple[object, ...]) -_MappingT = TypeVar("_MappingT", bound=Mapping[str, object]) -_SequenceT = TypeVar("_SequenceT", bound=Sequence[object]) -CallableT = TypeVar("CallableT", bound=Callable[..., Any]) - - -def flatten(t: Iterable[Iterable[_T]]) -> list[_T]: - return [item for sublist in t for item in sublist] - - -def extract_files( - # TODO: this needs to take Dict but variance issues..... - # create protocol type ? - query: Mapping[str, object], - *, - paths: Sequence[Sequence[str]], -) -> list[tuple[str, FileTypes]]: - """Recursively extract files from the given dictionary based on specified paths. - - A path may look like this ['foo', 'files', '', 'data']. - - Note: this mutates the given dictionary. - """ - files: list[tuple[str, FileTypes]] = [] - for path in paths: - files.extend(_extract_items(query, path, index=0, flattened_key=None)) - return files - - -def _extract_items( - obj: object, - path: Sequence[str], - *, - index: int, - flattened_key: str | None, -) -> list[tuple[str, FileTypes]]: - try: - key = path[index] - except IndexError: - if isinstance(obj, NotGiven): - # no value was provided - we can safely ignore - return [] - - # cyclical import - from .._files import assert_is_file_content - - # We have exhausted the path, return the entry we found. - assert_is_file_content(obj, key=flattened_key) - assert flattened_key is not None - return [(flattened_key, cast(FileTypes, obj))] - - index += 1 - if is_dict(obj): - try: - # We are at the last entry in the path so we must remove the field - if (len(path)) == index: - item = obj.pop(key) - else: - item = obj[key] - except KeyError: - # Key was not present in the dictionary, this is not indicative of an error - # as the given path may not point to a required field. We also do not want - # to enforce required fields as the API may differ from the spec in some cases. 
- return [] - if flattened_key is None: - flattened_key = key - else: - flattened_key += f"[{key}]" - return _extract_items( - item, - path, - index=index, - flattened_key=flattened_key, - ) - elif is_list(obj): - if key != "": - return [] - - return flatten( - [ - _extract_items( - item, - path, - index=index, - flattened_key=flattened_key + "[]" if flattened_key is not None else "[]", - ) - for item in obj - ] - ) - - # Something unexpected was passed, just ignore it. - return [] - - -def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]: - return not isinstance(obj, NotGiven) - - -# Type safe methods for narrowing types with TypeVars. -# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown], -# however this cause Pyright to rightfully report errors. As we know we don't -# care about the contained types we can safely use `object` in it's place. -# -# There are two separate functions defined, `is_*` and `is_*_t` for different use cases. -# `is_*` is for when you're dealing with an unknown input -# `is_*_t` is for when you're narrowing a known union type to a specific subset - - -def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]: - return isinstance(obj, tuple) - - -def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]: - return isinstance(obj, tuple) - - -def is_sequence(obj: object) -> TypeGuard[Sequence[object]]: - return isinstance(obj, Sequence) - - -def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]: - return isinstance(obj, Sequence) - - -def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]: - return isinstance(obj, Mapping) - - -def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]: - return isinstance(obj, Mapping) - - -def is_dict(obj: object) -> TypeGuard[dict[object, object]]: - return isinstance(obj, dict) - - -def is_list(obj: object) -> TypeGuard[list[object]]: - return isinstance(obj, list) - - -def is_iterable(obj: object) -> TypeGuard[Iterable[object]]: - return 
isinstance(obj, Iterable) - - -def deepcopy_minimal(item: _T) -> _T: - """Minimal reimplementation of copy.deepcopy() that will only copy certain object types: - - - mappings, e.g. `dict` - - list - - This is done for performance reasons. - """ - if is_mapping(item): - return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()}) - if is_list(item): - return cast(_T, [deepcopy_minimal(entry) for entry in item]) - return item - - -# copied from https://github.com/Rapptz/RoboDanny -def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str: - size = len(seq) - if size == 0: - return "" - - if size == 1: - return seq[0] - - if size == 2: - return f"{seq[0]} {final} {seq[1]}" - - return delim.join(seq[:-1]) + f" {final} {seq[-1]}" - - -def quote(string: str) -> str: - """Add single quotation marks around the given string. Does *not* do any escaping.""" - return f"'{string}'" - - -def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]: - """Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function. - - Useful for enforcing runtime validation of overloaded functions. - - Example usage: - ```py - @overload - def foo(*, a: str) -> str: - ... - - - @overload - def foo(*, b: bool) -> str: - ... - - - # This enforces the same constraints that a static type checker would - # i.e. that either a or b must be passed to the function - @required_args(["a"], ["b"]) - def foo(*, a: str | None = None, b: bool | None = None) -> str: - ... 
- ``` - """ - - def inner(func: CallableT) -> CallableT: - params = inspect.signature(func).parameters - positional = [ - name - for name, param in params.items() - if param.kind - in { - param.POSITIONAL_ONLY, - param.POSITIONAL_OR_KEYWORD, - } - ] - - @functools.wraps(func) - def wrapper(*args: object, **kwargs: object) -> object: - given_params: set[str] = set() - for i in range(len(args)): - try: - given_params.add(positional[i]) - except IndexError: - raise TypeError( - f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given" - ) from None - - given_params.update(kwargs.keys()) - - for variant in variants: - matches = all(param in given_params for param in variant) - if matches: - break - else: # no break - if len(variants) > 1: - variations = human_join( - ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants] - ) - msg = f"Missing required arguments; Expected either {variations} arguments to be given" - else: - # TODO: this error message is not deterministic - missing = list(set(variants[0]) - given_params) - if len(missing) > 1: - msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}" - else: - msg = f"Missing required argument: {quote(missing[0])}" - raise TypeError(msg) - return func(*args, **kwargs) - - return wrapper # type: ignore - - return inner - - -_K = TypeVar("_K") -_V = TypeVar("_V") - - -@overload -def strip_not_given(obj: None) -> None: ... - - -@overload -def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ... - - -@overload -def strip_not_given(obj: object) -> object: ... 
- - -def strip_not_given(obj: object | None) -> object: - """Remove all top-level keys where their values are instances of `NotGiven`""" - if obj is None: - return None - - if not is_mapping(obj): - return obj - - return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)} - - -def coerce_integer(val: str) -> int: - return int(val, base=10) - - -def coerce_float(val: str) -> float: - return float(val) - - -def coerce_boolean(val: str) -> bool: - return val in {"true", "1", "on"} - - -def maybe_coerce_integer(val: str | None) -> int | None: - if val is None: - return None - return coerce_integer(val) - - -def maybe_coerce_float(val: str | None) -> float | None: - if val is None: - return None - return coerce_float(val) - - -def maybe_coerce_boolean(val: str | None) -> bool | None: - if val is None: - return None - return coerce_boolean(val) - - -def removeprefix(string: str, prefix: str) -> str: - """Remove a prefix from a string. - - Backport of `str.removeprefix` for Python < 3.9 - """ - if string.startswith(prefix): - return string[len(prefix) :] - return string - - -def removesuffix(string: str, suffix: str) -> str: - """Remove a suffix from a string. 
- - Backport of `str.removesuffix` for Python < 3.9 - """ - if string.endswith(suffix): - return string[: -len(suffix)] - return string - - -def file_from_path(path: str) -> FileTypes: - contents = Path(path).read_bytes() - file_name = os.path.basename(path) - return (file_name, contents) - - -def get_required_header(headers: HeadersLike, header: str) -> str: - lower_header = header.lower() - if isinstance(headers, Mapping): - headers = cast(Headers, headers) - for k, v in headers.items(): - if k.lower() == lower_header and isinstance(v, str): - return v - - """ to deal with the case where the header looks like Stainless-Event-Id """ - intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize()) - - for normalized_header in [header, lower_header, header.upper(), intercaps_header]: - value = headers.get(normalized_header) - if value: - return value - - raise ValueError(f"Could not find {header} header") - - -def get_async_library() -> str: - try: - return sniffio.current_async_library() - except Exception: - return "false" - - -def drop_prefix_image_data(content: Union[str, list[dict]]) -> Union[str, list[dict]]: - """ - 删除 ;base64, 前缀 - :param image_data: - :return: - """ - if isinstance(content, list): - for data in content: - if data.get("type") == "image_url": - image_data = data.get("image_url").get("url") - if image_data.startswith("data:image/"): - image_data = image_data.split("base64,")[-1] - data["image_url"]["url"] = image_data - - return content diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/logs.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/logs.py deleted file mode 100644 index e5fce94c00e9e0..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/logs.py +++ /dev/null @@ -1,78 +0,0 @@ -import logging -import os -import time - -logger = logging.getLogger(__name__) - - -class LoggerNameFilter(logging.Filter): - def 
filter(self, record): - # return record.name.startswith("loom_core") or record.name in "ERROR" or ( - # record.name.startswith("uvicorn.error") - # and record.getMessage().startswith("Uvicorn running on") - # ) - return True - - -def get_log_file(log_path: str, sub_dir: str): - """ - sub_dir should contain a timestamp. - """ - log_dir = os.path.join(log_path, sub_dir) - # Here should be creating a new directory each time, so `exist_ok=False` - os.makedirs(log_dir, exist_ok=False) - return os.path.join(log_dir, "zhipuai.log") - - -def get_config_dict(log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int) -> dict: - # for windows, the path should be a raw string. - log_file_path = log_file_path.encode("unicode-escape").decode() if os.name == "nt" else log_file_path - log_level = log_level.upper() - config_dict = { - "version": 1, - "disable_existing_loggers": False, - "formatters": { - "formatter": {"format": ("%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s")}, - }, - "filters": { - "logger_name_filter": { - "()": __name__ + ".LoggerNameFilter", - }, - }, - "handlers": { - "stream_handler": { - "class": "logging.StreamHandler", - "formatter": "formatter", - "level": log_level, - # "stream": "ext://sys.stdout", - # "filters": ["logger_name_filter"], - }, - "file_handler": { - "class": "logging.handlers.RotatingFileHandler", - "formatter": "formatter", - "level": log_level, - "filename": log_file_path, - "mode": "a", - "maxBytes": log_max_bytes, - "backupCount": log_backup_count, - "encoding": "utf8", - }, - }, - "loggers": { - "loom_core": { - "handlers": ["stream_handler", "file_handler"], - "level": log_level, - "propagate": False, - } - }, - "root": { - "level": log_level, - "handlers": ["stream_handler", "file_handler"], - }, - } - return config_dict - - -def get_timestamp_ms(): - t = time.time() - return int(round(t * 1000)) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/pagination.py 
b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/pagination.py deleted file mode 100644 index 7f0b1b91d98556..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/pagination.py +++ /dev/null @@ -1,62 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from typing import Any, Generic, Optional, TypeVar, cast - -from typing_extensions import Protocol, override, runtime_checkable - -from ._http_client import BasePage, BaseSyncPage, PageInfo - -__all__ = ["SyncPage", "SyncCursorPage"] - -_T = TypeVar("_T") - - -@runtime_checkable -class CursorPageItem(Protocol): - id: Optional[str] - - -class SyncPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]): - """Note: no pagination actually occurs yet, this is for forwards-compatibility.""" - - data: list[_T] - object: str - - @override - def _get_page_items(self) -> list[_T]: - data = self.data - if not data: - return [] - return data - - @override - def next_page_info(self) -> None: - """ - This page represents a response that isn't actually paginated at the API level - so there will never be a next page. 
- """ - return None - - -class SyncCursorPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]): - data: list[_T] - - @override - def _get_page_items(self) -> list[_T]: - data = self.data - if not data: - return [] - return data - - @override - def next_page_info(self) -> Optional[PageInfo]: - data = self.data - if not data: - return None - - item = cast(Any, data[-1]) - if not isinstance(item, CursorPageItem) or item.id is None: - # TODO emit warning log - return None - - return PageInfo(params={"after": item.id}) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/__init__.py deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/__init__.py deleted file mode 100644 index 9f941fb91c8776..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .assistant_completion import AssistantCompletion - -__all__ = [ - "AssistantCompletion", -] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_completion.py deleted file mode 100644 index cbfb6edaeb1f19..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_completion.py +++ /dev/null @@ -1,40 +0,0 @@ -from typing import Any, Optional - -from ...core import BaseModel -from .message import MessageContent - -__all__ = ["AssistantCompletion", "CompletionUsage"] - - -class ErrorInfo(BaseModel): - code: str # 错误码 - message: str # 错误信息 - - -class AssistantChoice(BaseModel): - index: int # 结果下标 - delta: MessageContent # 当前会话输出消息体 - finish_reason: str - """ - # 推理结束原因 
stop代表推理自然结束或触发停止词。 sensitive 代表模型推理内容被安全审核接口拦截。请注意,针对此类内容,请用户自行判断并决定是否撤回已公开的内容。 - # network_error 代表模型推理服务异常。 - """ # noqa: E501 - metadata: dict # 元信息,拓展字段 - - -class CompletionUsage(BaseModel): - prompt_tokens: int # 输入的 tokens 数量 - completion_tokens: int # 输出的 tokens 数量 - total_tokens: int # 总 tokens 数量 - - -class AssistantCompletion(BaseModel): - id: str # 请求 ID - conversation_id: str # 会话 ID - assistant_id: str # 智能体 ID - created: int # 请求创建时间,Unix 时间戳 - status: str # 返回状态,包括:`completed` 表示生成结束`in_progress`表示生成中 `failed` 表示生成异常 - last_error: Optional[ErrorInfo] # 异常信息 - choices: list[AssistantChoice] # 增量返回的信息 - metadata: Optional[dict[str, Any]] # 元信息,拓展字段 - usage: Optional[CompletionUsage] # tokens 数量统计 diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_conversation_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_conversation_params.py deleted file mode 100644 index 03f14f4238f37f..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_conversation_params.py +++ /dev/null @@ -1,7 +0,0 @@ -from typing import TypedDict - - -class ConversationParameters(TypedDict, total=False): - assistant_id: str # 智能体 ID - page: int # 当前分页 - page_size: int # 分页数量 diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_conversation_resp.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_conversation_resp.py deleted file mode 100644 index d1833d220a2e3b..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_conversation_resp.py +++ /dev/null @@ -1,29 +0,0 @@ -from ...core import BaseModel - -__all__ = ["ConversationUsageListResp"] - - -class Usage(BaseModel): - prompt_tokens: int # 用户输入的 tokens 数量 - completion_tokens: int # 模型输入的 tokens 数量 - total_tokens: int # 总 tokens 数量 - - -class 
ConversationUsage(BaseModel): - id: str # 会话 id - assistant_id: str # 智能体Assistant id - create_time: int # 创建时间 - update_time: int # 更新时间 - usage: Usage # 会话中 tokens 数量统计 - - -class ConversationUsageList(BaseModel): - assistant_id: str # 智能体id - has_more: bool # 是否还有更多页 - conversation_list: list[ConversationUsage] # 返回的 - - -class ConversationUsageListResp(BaseModel): - code: int - msg: str - data: ConversationUsageList diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_create_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_create_params.py deleted file mode 100644 index 2def1025cd2b33..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_create_params.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing import Optional, TypedDict, Union - - -class AssistantAttachments: - file_id: str - - -class MessageTextContent: - type: str # 目前支持 type = text - text: str - - -MessageContent = Union[MessageTextContent] - - -class ConversationMessage(TypedDict): - """会话消息体""" - - role: str # 用户的输入角色,例如 'user' - content: list[MessageContent] # 会话消息体的内容 - - -class AssistantParameters(TypedDict, total=False): - """智能体参数类""" - - assistant_id: str # 智能体 ID - conversation_id: Optional[str] # 会话 ID,不传则创建新会话 - model: str # 模型名称,默认为 'GLM-4-Assistant' - stream: bool # 是否支持流式 SSE,需要传入 True - messages: list[ConversationMessage] # 会话消息体 - attachments: Optional[list[AssistantAttachments]] # 会话指定的文件,非必填 - metadata: Optional[dict] # 元信息,拓展字段,非必填 diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_support_resp.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_support_resp.py deleted file mode 100644 index 0709cdbcad25e1..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_support_resp.py +++ /dev/null @@ -1,21 +0,0 @@ 
-from ...core import BaseModel - -__all__ = ["AssistantSupportResp"] - - -class AssistantSupport(BaseModel): - assistant_id: str # 智能体的 Assistant id,用于智能体会话 - created_at: int # 创建时间 - updated_at: int # 更新时间 - name: str # 智能体名称 - avatar: str # 智能体头像 - description: str # 智能体描述 - status: str # 智能体状态,目前只有 publish - tools: list[str] # 智能体支持的工具名 - starter_prompts: list[str] # 智能体启动推荐的 prompt - - -class AssistantSupportResp(BaseModel): - code: int - msg: str - data: list[AssistantSupport] # 智能体列表 diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/__init__.py deleted file mode 100644 index 562e0151e53b48..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .message_content import MessageContent - -__all__ = ["MessageContent"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/message_content.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/message_content.py deleted file mode 100644 index 6a1a438a6fe03d..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/message_content.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Annotated, TypeAlias, Union - -from ....core._utils import PropertyInfo -from .text_content_block import TextContentBlock -from .tools_delta_block import ToolsDeltaBlock - -__all__ = ["MessageContent"] - - -MessageContent: TypeAlias = Annotated[ - Union[ToolsDeltaBlock, TextContentBlock], - PropertyInfo(discriminator="type"), -] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/text_content_block.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/text_content_block.py deleted file mode 100644 index 
865fb1139e2f75..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/text_content_block.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import Literal - -from ....core import BaseModel - -__all__ = ["TextContentBlock"] - - -class TextContentBlock(BaseModel): - content: str - - role: str = "assistant" - - type: Literal["content"] = "content" - """Always `content`.""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/code_interpreter_delta_block.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/code_interpreter_delta_block.py deleted file mode 100644 index 9d569b282ef9f7..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/code_interpreter_delta_block.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Literal - -__all__ = ["CodeInterpreterToolBlock"] - -from .....core import BaseModel - - -class CodeInterpreterToolOutput(BaseModel): - """代码工具输出结果""" - - type: str # 代码执行日志,目前只有 logs - logs: str # 代码执行的日志结果 - error_msg: str # 错误信息 - - -class CodeInterpreter(BaseModel): - """代码解释器""" - - input: str # 生成的代码片段,输入给代码沙盒 - outputs: list[CodeInterpreterToolOutput] # 代码执行后的输出结果 - - -class CodeInterpreterToolBlock(BaseModel): - """代码工具块""" - - code_interpreter: CodeInterpreter # 代码解释器对象 - type: Literal["code_interpreter"] # 调用工具的类型,始终为 `code_interpreter` diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/drawing_tool_delta_block.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/drawing_tool_delta_block.py deleted file mode 100644 index 0b6895556b6164..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/drawing_tool_delta_block.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Literal - -from .....core import BaseModel - 
-__all__ = ["DrawingToolBlock"] - - -class DrawingToolOutput(BaseModel): - image: str - - -class DrawingTool(BaseModel): - input: str - outputs: list[DrawingToolOutput] - - -class DrawingToolBlock(BaseModel): - drawing_tool: DrawingTool - - type: Literal["drawing_tool"] - """Always `drawing_tool`.""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/function_delta_block.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/function_delta_block.py deleted file mode 100644 index c439bc4b3fbbb8..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/function_delta_block.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Literal, Union - -__all__ = ["FunctionToolBlock"] - -from .....core import BaseModel - - -class FunctionToolOutput(BaseModel): - content: str - - -class FunctionTool(BaseModel): - name: str - arguments: Union[str, dict] - outputs: list[FunctionToolOutput] - - -class FunctionToolBlock(BaseModel): - function: FunctionTool - - type: Literal["function"] - """Always `drawing_tool`.""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/retrieval_delta_black.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/retrieval_delta_black.py deleted file mode 100644 index 4789e9378a8a39..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/retrieval_delta_black.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Literal - -from .....core import BaseModel - - -class RetrievalToolOutput(BaseModel): - """ - This class represents the output of a retrieval tool. - - Attributes: - - text (str): The text snippet retrieved from the knowledge base. - - document (str): The name of the document from which the text snippet was retrieved, returned only in intelligent configuration. 
- """ # noqa: E501 - - text: str - document: str - - -class RetrievalTool(BaseModel): - """ - This class represents the outputs of a retrieval tool. - - Attributes: - - outputs (List[RetrievalToolOutput]): A list of text snippets and their respective document names retrieved from the knowledge base. - """ # noqa: E501 - - outputs: list[RetrievalToolOutput] - - -class RetrievalToolBlock(BaseModel): - """ - This class represents a block for invoking the retrieval tool. - - Attributes: - - retrieval (RetrievalTool): An instance of the RetrievalTool class containing the retrieval outputs. - - type (Literal["retrieval"]): The type of tool being used, always set to "retrieval". - """ - - retrieval: RetrievalTool - type: Literal["retrieval"] - """Always `retrieval`.""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/tools_type.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/tools_type.py deleted file mode 100644 index 98544053d4c83a..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/tools_type.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Annotated, TypeAlias, Union - -from .....core._utils import PropertyInfo -from .code_interpreter_delta_block import CodeInterpreterToolBlock -from .drawing_tool_delta_block import DrawingToolBlock -from .function_delta_block import FunctionToolBlock -from .retrieval_delta_black import RetrievalToolBlock -from .web_browser_delta_block import WebBrowserToolBlock - -__all__ = ["ToolsType"] - - -ToolsType: TypeAlias = Annotated[ - Union[DrawingToolBlock, CodeInterpreterToolBlock, WebBrowserToolBlock, RetrievalToolBlock, FunctionToolBlock], - PropertyInfo(discriminator="type"), -] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/web_browser_delta_block.py 
b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/web_browser_delta_block.py deleted file mode 100644 index 966e6fe0c84fef..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/web_browser_delta_block.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Literal - -from .....core import BaseModel - -__all__ = ["WebBrowserToolBlock"] - - -class WebBrowserOutput(BaseModel): - """ - This class represents the output of a web browser search result. - - Attributes: - - title (str): The title of the search result. - - link (str): The URL link to the search result's webpage. - - content (str): The textual content extracted from the search result. - - error_msg (str): Any error message encountered during the search or retrieval process. - """ - - title: str - link: str - content: str - error_msg: str - - -class WebBrowser(BaseModel): - """ - This class represents the input and outputs of a web browser search. - - Attributes: - - input (str): The input query for the web browser search. - - outputs (List[WebBrowserOutput]): A list of search results returned by the web browser. - """ - - input: str - outputs: list[WebBrowserOutput] - - -class WebBrowserToolBlock(BaseModel): - """ - This class represents a block for invoking the web browser tool. - - Attributes: - - web_browser (WebBrowser): An instance of the WebBrowser class containing the search input and outputs. - - type (Literal["web_browser"]): The type of tool being used, always set to "web_browser". 
- """ - - web_browser: WebBrowser - type: Literal["web_browser"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools_delta_block.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools_delta_block.py deleted file mode 100644 index 781a1ab819c286..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools_delta_block.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Literal - -from ....core import BaseModel -from .tools.tools_type import ToolsType - -__all__ = ["ToolsDeltaBlock"] - - -class ToolsDeltaBlock(BaseModel): - tool_calls: list[ToolsType] - """The index of the content part in the message.""" - - role: str = "tool" - - type: Literal["tool_calls"] = "tool_calls" - """Always `tool_calls`.""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch.py deleted file mode 100644 index 560562915c9d32..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch.py +++ /dev/null @@ -1,82 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
- -import builtins -from typing import Literal, Optional - -from ..core import BaseModel -from .batch_error import BatchError -from .batch_request_counts import BatchRequestCounts - -__all__ = ["Batch", "Errors"] - - -class Errors(BaseModel): - data: Optional[list[BatchError]] = None - - object: Optional[str] = None - """这个类型,一直是`list`。""" - - -class Batch(BaseModel): - id: str - - completion_window: str - """用于执行请求的地址信息。""" - - created_at: int - """这是 Unix timestamp (in seconds) 表示的创建时间。""" - - endpoint: str - """这是ZhipuAI endpoint的地址。""" - - input_file_id: str - """标记为batch的输入文件的ID。""" - - object: Literal["batch"] - """这个类型,一直是`batch`.""" - - status: Literal[ - "validating", "failed", "in_progress", "finalizing", "completed", "expired", "cancelling", "cancelled" - ] - """batch 的状态。""" - - cancelled_at: Optional[int] = None - """Unix timestamp (in seconds) 表示的取消时间。""" - - cancelling_at: Optional[int] = None - """Unix timestamp (in seconds) 表示发起取消的请求时间 """ - - completed_at: Optional[int] = None - """Unix timestamp (in seconds) 表示的完成时间。""" - - error_file_id: Optional[str] = None - """这个文件id包含了执行请求失败的请求的输出。""" - - errors: Optional[Errors] = None - - expired_at: Optional[int] = None - """Unix timestamp (in seconds) 表示的将在过期时间。""" - - expires_at: Optional[int] = None - """Unix timestamp (in seconds) 触发过期""" - - failed_at: Optional[int] = None - """Unix timestamp (in seconds) 表示的失败时间。""" - - finalizing_at: Optional[int] = None - """Unix timestamp (in seconds) 表示的最终时间。""" - - in_progress_at: Optional[int] = None - """Unix timestamp (in seconds) 表示的开始处理时间。""" - - metadata: Optional[builtins.object] = None - """ - key:value形式的元数据,以便将信息存储 - 结构化格式。键的长度是64个字符,值最长512个字符 - """ - - output_file_id: Optional[str] = None - """完成请求的输出文件的ID。""" - - request_counts: Optional[BatchRequestCounts] = None - """批次中不同状态的请求计数""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_create_params.py 
b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_create_params.py deleted file mode 100644 index 3dae65ea46fcbe..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_create_params.py +++ /dev/null @@ -1,37 +0,0 @@ -from __future__ import annotations - -from typing import Literal, Optional - -from typing_extensions import Required, TypedDict - -__all__ = ["BatchCreateParams"] - - -class BatchCreateParams(TypedDict, total=False): - completion_window: Required[str] - """The time frame within which the batch should be processed. - - Currently only `24h` is supported. - """ - - endpoint: Required[Literal["/v1/chat/completions", "/v1/embeddings"]] - """The endpoint to be used for all requests in the batch. - - Currently `/v1/chat/completions` and `/v1/embeddings` are supported. - """ - - input_file_id: Required[str] - """The ID of an uploaded file that contains requests for the new batch. - - See [upload file](https://platform.openai.com/docs/api-reference/files/create) - for how to upload a file. - - Your input file must be formatted as a - [JSONL file](https://platform.openai.com/docs/api-reference/batch/requestInput), - and must be uploaded with the purpose `batch`. - """ - - metadata: Optional[dict[str, str]] - """Optional custom metadata for the batch.""" - - auto_delete_input_file: Optional[bool] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_error.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_error.py deleted file mode 100644 index f934db19781e41..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_error.py +++ /dev/null @@ -1,21 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
- -from typing import Optional - -from ..core import BaseModel - -__all__ = ["BatchError"] - - -class BatchError(BaseModel): - code: Optional[str] = None - """定义的业务错误码""" - - line: Optional[int] = None - """文件中的行号""" - - message: Optional[str] = None - """关于对话文件中的错误的描述""" - - param: Optional[str] = None - """参数名称,如果有的话""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_list_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_list_params.py deleted file mode 100644 index 1a681671320eca..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_list_params.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -from typing_extensions import TypedDict - -__all__ = ["BatchListParams"] - - -class BatchListParams(TypedDict, total=False): - after: str - """分页的游标,用于获取下一页的数据。 - - `after` 是一个指向当前页面的游标,用于获取下一页的数据。如果没有提供 `after`,则返回第一页的数据。 - list. - """ - - limit: int - """这个参数用于限制返回的结果数量。 - - Limit 用于限制返回的结果数量。默认值为 10 - """ diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_request_counts.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_request_counts.py deleted file mode 100644 index ca3ccae625052b..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_request_counts.py +++ /dev/null @@ -1,14 +0,0 @@ -from ..core import BaseModel - -__all__ = ["BatchRequestCounts"] - - -class BatchRequestCounts(BaseModel): - completed: int - """这个数字表示已经完成的请求。""" - - failed: int - """这个数字表示失败的请求。""" - - total: int - """这个数字表示总的请求。""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/__init__.py deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py 
b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py deleted file mode 100644 index c1eed070f32d9f..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Optional - -from ...core import BaseModel -from .chat_completion import CompletionChoice, CompletionUsage - -__all__ = ["AsyncTaskStatus", "AsyncCompletion"] - - -class AsyncTaskStatus(BaseModel): - id: Optional[str] = None - request_id: Optional[str] = None - model: Optional[str] = None - task_status: Optional[str] = None - - -class AsyncCompletion(BaseModel): - id: Optional[str] = None - request_id: Optional[str] = None - model: Optional[str] = None - task_status: str - choices: list[CompletionChoice] - usage: CompletionUsage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py deleted file mode 100644 index 1945a826cda2d0..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import Optional - -from ...core import BaseModel - -__all__ = ["Completion", "CompletionUsage"] - - -class Function(BaseModel): - arguments: str - name: str - - -class CompletionMessageToolCall(BaseModel): - id: str - function: Function - type: str - - -class CompletionMessage(BaseModel): - content: Optional[str] = None - role: str - tool_calls: Optional[list[CompletionMessageToolCall]] = None - - -class CompletionUsage(BaseModel): - prompt_tokens: int - completion_tokens: int - total_tokens: int - - -class CompletionChoice(BaseModel): - index: int - finish_reason: str - message: CompletionMessage - - -class Completion(BaseModel): - model: Optional[str] = None - created: Optional[int] = None - choices: list[CompletionChoice] - request_id: 
Optional[str] = None - id: Optional[str] = None - usage: CompletionUsage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py deleted file mode 100644 index 27fad0008a1dd4..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Any, Optional - -from ...core import BaseModel - -__all__ = [ - "CompletionUsage", - "ChatCompletionChunk", - "Choice", - "ChoiceDelta", - "ChoiceDeltaFunctionCall", - "ChoiceDeltaToolCall", - "ChoiceDeltaToolCallFunction", -] - - -class ChoiceDeltaFunctionCall(BaseModel): - arguments: Optional[str] = None - name: Optional[str] = None - - -class ChoiceDeltaToolCallFunction(BaseModel): - arguments: Optional[str] = None - name: Optional[str] = None - - -class ChoiceDeltaToolCall(BaseModel): - index: int - id: Optional[str] = None - function: Optional[ChoiceDeltaToolCallFunction] = None - type: Optional[str] = None - - -class ChoiceDelta(BaseModel): - content: Optional[str] = None - role: Optional[str] = None - tool_calls: Optional[list[ChoiceDeltaToolCall]] = None - - -class Choice(BaseModel): - delta: ChoiceDelta - finish_reason: Optional[str] = None - index: int - - -class CompletionUsage(BaseModel): - prompt_tokens: int - completion_tokens: int - total_tokens: int - - -class ChatCompletionChunk(BaseModel): - id: Optional[str] = None - choices: list[Choice] - created: Optional[int] = None - model: Optional[str] = None - usage: Optional[CompletionUsage] = None - extra_json: dict[str, Any] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completions_create_param.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completions_create_param.py deleted file mode 100644 index 6ee4dc4794b201..00000000000000 --- 
a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completions_create_param.py +++ /dev/null @@ -1,8 +0,0 @@ -from typing import Optional - -from typing_extensions import TypedDict - - -class Reference(TypedDict, total=False): - enable: Optional[bool] - search_query: Optional[str] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/code_geex/code_geex_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/code_geex/code_geex_params.py deleted file mode 100644 index 666b38855cd637..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/code_geex/code_geex_params.py +++ /dev/null @@ -1,146 +0,0 @@ -from typing import Literal, Optional - -from typing_extensions import Required, TypedDict - -__all__ = [ - "CodeGeexTarget", - "CodeGeexContext", - "CodeGeexExtra", -] - - -class CodeGeexTarget(TypedDict, total=False): - """补全的内容参数""" - - path: Optional[str] - """文件路径""" - language: Required[ - Literal[ - "c", - "c++", - "cpp", - "c#", - "csharp", - "c-sharp", - "css", - "cuda", - "dart", - "lua", - "objectivec", - "objective-c", - "objective-c++", - "python", - "perl", - "prolog", - "swift", - "lisp", - "java", - "scala", - "tex", - "jsx", - "tsx", - "vue", - "markdown", - "html", - "php", - "js", - "javascript", - "typescript", - "go", - "shell", - "rust", - "sql", - "kotlin", - "vb", - "ruby", - "pascal", - "r", - "fortran", - "lean", - "matlab", - "delphi", - "scheme", - "basic", - "assembly", - "groovy", - "abap", - "gdscript", - "haskell", - "julia", - "elixir", - "excel", - "clojure", - "actionscript", - "solidity", - "powershell", - "erlang", - "cobol", - "alloy", - "awk", - "thrift", - "sparql", - "augeas", - "cmake", - "f-sharp", - "stan", - "isabelle", - "dockerfile", - "rmarkdown", - "literate-agda", - "tcl", - "glsl", - "antlr", - "verilog", - "racket", - "standard-ml", - "elm", - "yaml", - "smalltalk", - "ocaml", - "idris", - "visual-basic", 
- "protocol-buffer", - "bluespec", - "applescript", - "makefile", - "tcsh", - "maple", - "systemverilog", - "literate-coffeescript", - "vhdl", - "restructuredtext", - "sas", - "literate-haskell", - "java-server-pages", - "coffeescript", - "emacs-lisp", - "mathematica", - "xslt", - "zig", - "common-lisp", - "stata", - "agda", - "ada", - ] - ] - """代码语言类型,如python""" - code_prefix: Required[str] - """补全位置的前文""" - code_suffix: Required[str] - """补全位置的后文""" - - -class CodeGeexContext(TypedDict, total=False): - """附加代码""" - - path: Required[str] - """附加代码文件的路径""" - code: Required[str] - """附加的代码内容""" - - -class CodeGeexExtra(TypedDict, total=False): - target: Required[CodeGeexTarget] - """补全的内容参数""" - contexts: Optional[list[CodeGeexContext]] - """附加代码""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py deleted file mode 100644 index 8425b5c86688dd..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py +++ /dev/null @@ -1,21 +0,0 @@ -from __future__ import annotations - -from typing import Optional - -from ..core import BaseModel -from .chat.chat_completion import CompletionUsage - -__all__ = ["Embedding", "EmbeddingsResponded"] - - -class Embedding(BaseModel): - object: str - index: Optional[int] = None - embedding: list[float] - - -class EmbeddingsResponded(BaseModel): - object: str - data: list[Embedding] - model: str - usage: CompletionUsage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/__init__.py deleted file mode 100644 index bbaf59e4d7d17a..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .file_deleted import FileDeleted -from .file_object import FileObject, ListOfFileObject -from 
.upload_detail import UploadDetail - -__all__ = ["FileObject", "ListOfFileObject", "UploadDetail", "FileDeleted"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_create_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_create_params.py deleted file mode 100644 index 4ef93b1c05acae..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_create_params.py +++ /dev/null @@ -1,38 +0,0 @@ -from __future__ import annotations - -from typing import Literal, Optional - -from typing_extensions import Required, TypedDict - -__all__ = ["FileCreateParams"] - -from ...core import FileTypes -from . import UploadDetail - - -class FileCreateParams(TypedDict, total=False): - file: FileTypes - """file和 upload_detail二选一必填""" - - upload_detail: list[UploadDetail] - """file和 upload_detail二选一必填""" - - purpose: Required[Literal["fine-tune", "retrieval", "batch"]] - """ - 上传文件的用途,支持 "fine-tune和 "retrieval" - retrieval支持上传Doc、Docx、PDF、Xlsx、URL类型文件,且单个文件的大小不超过 5MB。 - fine-tune支持上传.jsonl文件且当前单个文件的大小最大可为 100 MB ,文件中语料格式需满足微调指南中所描述的格式。 - """ - custom_separator: Optional[list[str]] - """ - 当 purpose 为 retrieval 且文件类型为 pdf, url, docx 时上传,切片规则默认为 `\n`。 - """ - knowledge_id: str - """ - 当文件上传目的为 retrieval 时,需要指定知识库ID进行上传。 - """ - - sentence_size: int - """ - 当文件上传目的为 retrieval 时,需要指定知识库ID进行上传。 - """ diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_deleted.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_deleted.py deleted file mode 100644 index a384b1a69a5735..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_deleted.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Literal - -from ...core import BaseModel - -__all__ = ["FileDeleted"] - - -class FileDeleted(BaseModel): - id: str - - deleted: bool - - object: Literal["file"] diff --git 
a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_object.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_object.py deleted file mode 100644 index 8f9d0fbb8e6ce3..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_object.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Optional - -from ...core import BaseModel - -__all__ = ["FileObject", "ListOfFileObject"] - - -class FileObject(BaseModel): - id: Optional[str] = None - bytes: Optional[int] = None - created_at: Optional[int] = None - filename: Optional[str] = None - object: Optional[str] = None - purpose: Optional[str] = None - status: Optional[str] = None - status_details: Optional[str] = None - - -class ListOfFileObject(BaseModel): - object: Optional[str] = None - data: list[FileObject] - has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/upload_detail.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/upload_detail.py deleted file mode 100644 index 8f1ca5ce5756aa..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/upload_detail.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Optional - -from ...core import BaseModel - - -class UploadDetail(BaseModel): - url: str - knowledge_type: int - file_name: Optional[str] = None - sentence_size: Optional[int] = None - custom_separator: Optional[list[str]] = None - callback_url: Optional[str] = None - callback_header: Optional[dict[str, str]] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py deleted file mode 100644 index 416f516ef7bf1c..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from 
__future__ import annotations - -from .fine_tuning_job import FineTuningJob, ListOfFineTuningJob -from .fine_tuning_job_event import FineTuningJobEvent diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py deleted file mode 100644 index 75c7553dbe35c6..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Optional, Union - -from ...core import BaseModel - -__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob"] - - -class Error(BaseModel): - code: str - message: str - param: Optional[str] = None - - -class Hyperparameters(BaseModel): - n_epochs: Union[str, int, None] = None - - -class FineTuningJob(BaseModel): - id: Optional[str] = None - - request_id: Optional[str] = None - - created_at: Optional[int] = None - - error: Optional[Error] = None - - fine_tuned_model: Optional[str] = None - - finished_at: Optional[int] = None - - hyperparameters: Optional[Hyperparameters] = None - - model: Optional[str] = None - - object: Optional[str] = None - - result_files: list[str] - - status: str - - trained_tokens: Optional[int] = None - - training_file: str - - validation_file: Optional[str] = None - - -class ListOfFineTuningJob(BaseModel): - object: Optional[str] = None - data: list[FineTuningJob] - has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py deleted file mode 100644 index f996cff11430b0..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Optional, Union - -from 
...core import BaseModel - -__all__ = ["FineTuningJobEvent", "Metric", "JobEvent"] - - -class Metric(BaseModel): - epoch: Optional[Union[str, int, float]] = None - current_steps: Optional[int] = None - total_steps: Optional[int] = None - elapsed_time: Optional[str] = None - remaining_time: Optional[str] = None - trained_tokens: Optional[int] = None - loss: Optional[Union[str, int, float]] = None - eval_loss: Optional[Union[str, int, float]] = None - acc: Optional[Union[str, int, float]] = None - eval_acc: Optional[Union[str, int, float]] = None - learning_rate: Optional[Union[str, int, float]] = None - - -class JobEvent(BaseModel): - object: Optional[str] = None - id: Optional[str] = None - type: Optional[str] = None - created_at: Optional[int] = None - level: Optional[str] = None - message: Optional[str] = None - data: Optional[Metric] = None - - -class FineTuningJobEvent(BaseModel): - object: Optional[str] = None - data: list[JobEvent] - has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py deleted file mode 100644 index e1ebc352bc97fd..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations - -from typing import Literal, Union - -from typing_extensions import TypedDict - -__all__ = ["Hyperparameters"] - - -class Hyperparameters(TypedDict, total=False): - batch_size: Union[Literal["auto"], int] - - learning_rate_multiplier: Union[Literal["auto"], float] - - n_epochs: Union[Literal["auto"], int] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/models/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/models/__init__.py deleted file mode 100644 index 
57d0d2511dbc14..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .fine_tuned_models import FineTunedModelsStatus diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/models/fine_tuned_models.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/models/fine_tuned_models.py deleted file mode 100644 index b286a5b5774d3d..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/models/fine_tuned_models.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import ClassVar - -from ....core import PYDANTIC_V2, BaseModel, ConfigDict - -__all__ = ["FineTunedModelsStatus"] - - -class FineTunedModelsStatus(BaseModel): - if PYDANTIC_V2: - model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow", protected_namespaces=()) - request_id: str # 请求id - model_name: str # 模型名称 - delete_status: str # 删除状态 deleting(删除中), deleted (已删除) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py deleted file mode 100644 index 3bcad0acabd215..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from typing import Optional - -from ..core import BaseModel - -__all__ = ["GeneratedImage", "ImagesResponded"] - - -class GeneratedImage(BaseModel): - b64_json: Optional[str] = None - url: Optional[str] = None - revised_prompt: Optional[str] = None - - -class ImagesResponded(BaseModel): - created: int - data: list[GeneratedImage] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/__init__.py deleted file mode 100644 index 8c81d703e214a3..00000000000000 --- 
a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from .knowledge import KnowledgeInfo -from .knowledge_used import KnowledgeStatistics, KnowledgeUsed - -__all__ = [ - "KnowledgeInfo", - "KnowledgeStatistics", - "KnowledgeUsed", -] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/__init__.py deleted file mode 100644 index 59cb41d7124a7f..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from .document import DocumentData, DocumentFailedInfo, DocumentObject, DocumentSuccessInfo - -__all__ = [ - "DocumentData", - "DocumentObject", - "DocumentSuccessInfo", - "DocumentFailedInfo", -] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document.py deleted file mode 100644 index 980bc6f4a7c40d..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Optional - -from ....core import BaseModel - -__all__ = ["DocumentData", "DocumentObject", "DocumentSuccessInfo", "DocumentFailedInfo"] - - -class DocumentSuccessInfo(BaseModel): - documentId: Optional[str] = None - """文件id""" - filename: Optional[str] = None - """文件名称""" - - -class DocumentFailedInfo(BaseModel): - failReason: Optional[str] = None - """上传失败的原因,包括:文件格式不支持、文件大小超出限制、知识库容量已满、容量上限为 50 万字。""" - filename: Optional[str] = None - """文件名称""" - documentId: Optional[str] = None - """知识库id""" - - -class DocumentObject(BaseModel): - """文档信息""" - - successInfos: Optional[list[DocumentSuccessInfo]] = None - """上传成功的文件信息""" - failedInfos: 
Optional[list[DocumentFailedInfo]] = None - """上传失败的文件信息""" - - -class DocumentDataFailInfo(BaseModel): - """失败原因""" - - embedding_code: Optional[int] = ( - None # 失败码 10001:知识不可用,知识库空间已达上限 10002:知识不可用,知识库空间已达上限(字数超出限制) - ) - embedding_msg: Optional[str] = None # 失败原因 - - -class DocumentData(BaseModel): - id: str = None # 知识唯一id - custom_separator: list[str] = None # 切片规则 - sentence_size: str = None # 切片大小 - length: int = None # 文件大小(字节) - word_num: int = None # 文件字数 - name: str = None # 文件名 - url: str = None # 文件下载链接 - embedding_stat: int = None # 0:向量化中 1:向量化完成 2:向量化失败 - failInfo: Optional[DocumentDataFailInfo] = None # 失败原因 向量化失败embedding_stat=2的时候 会有此值 diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_edit_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_edit_params.py deleted file mode 100644 index 509cb3a451af5f..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_edit_params.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Optional, TypedDict - -__all__ = ["DocumentEditParams"] - - -class DocumentEditParams(TypedDict): - """ - 知识参数类型定义 - - Attributes: - id (str): 知识ID - knowledge_type (int): 知识类型: - 1:文章知识: 支持pdf,url,docx - 2.问答知识-文档: 支持pdf,url,docx - 3.问答知识-表格: 支持xlsx - 4.商品库-表格: 支持xlsx - 5.自定义: 支持pdf,url,docx - custom_separator (Optional[List[str]]): 当前知识类型为自定义(knowledge_type=5)时的切片规则,默认\n - sentence_size (Optional[int]): 当前知识类型为自定义(knowledge_type=5)时的切片字数,取值范围: 20-2000,默认300 - callback_url (Optional[str]): 回调地址 - callback_header (Optional[dict]): 回调时携带的header - """ - - id: str - knowledge_type: int - custom_separator: Optional[list[str]] - sentence_size: Optional[int] - callback_url: Optional[str] - callback_header: Optional[dict[str, str]] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_list_params.py 
b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_list_params.py deleted file mode 100644 index 910c8c045e1b97..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_list_params.py +++ /dev/null @@ -1,26 +0,0 @@ -from __future__ import annotations - -from typing import Optional - -from typing_extensions import TypedDict - - -class DocumentListParams(TypedDict, total=False): - """ - 文件查询参数类型定义 - - Attributes: - purpose (Optional[str]): 文件用途 - knowledge_id (Optional[str]): 当文件用途为 retrieval 时,需要提供查询的知识库ID - page (Optional[int]): 页,默认1 - limit (Optional[int]): 查询文件列表数,默认10 - after (Optional[str]): 查询指定fileID之后的文件列表(当文件用途为 fine-tune 时需要) - order (Optional[str]): 排序规则,可选值['desc', 'asc'],默认desc(当文件用途为 fine-tune 时需要) - """ - - purpose: Optional[str] - knowledge_id: Optional[str] - page: Optional[int] - limit: Optional[int] - after: Optional[str] - order: Optional[str] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_list_resp.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_list_resp.py deleted file mode 100644 index acae4fad9ff36b..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_list_resp.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import annotations - -from ....core import BaseModel -from . 
import DocumentData - -__all__ = ["DocumentPage"] - - -class DocumentPage(BaseModel): - list: list[DocumentData] - object: str diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge.py deleted file mode 100644 index bc6f159eb211e5..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Optional - -from ...core import BaseModel - -__all__ = ["KnowledgeInfo"] - - -class KnowledgeInfo(BaseModel): - id: Optional[str] = None - """知识库唯一 id""" - embedding_id: Optional[str] = ( - None # 知识库绑定的向量化模型 见模型列表 [内部服务开放接口文档](https://lslfd0slxc.feishu.cn/docx/YauWdbBiMopV0FxB7KncPWCEn8f#H15NduiQZo3ugmxnWQFcfAHpnQ4) - ) - name: Optional[str] = None # 知识库名称 100字限制 - customer_identifier: Optional[str] = None # 用户标识 长度32位以内 - description: Optional[str] = None # 知识库描述 500字限制 - background: Optional[str] = None # 背景颜色(给枚举)'blue', 'red', 'orange', 'purple', 'sky' - icon: Optional[str] = ( - None # 知识库图标(给枚举) question: 问号、book: 书籍、seal: 印章、wrench: 扳手、tag: 标签、horn: 喇叭、house: 房子 # noqa: E501 - ) - bucket_id: Optional[str] = None # 桶id 限制32位 diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_create_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_create_params.py deleted file mode 100644 index c3da201727c34a..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_create_params.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -from typing import Literal, Optional - -from typing_extensions import TypedDict - -__all__ = ["KnowledgeBaseParams"] - - -class KnowledgeBaseParams(TypedDict): - """ - 知识库参数类型定义 - - Attributes: - embedding_id (int): 知识库绑定的向量化模型ID - name (str): 知识库名称,限制100字 - 
customer_identifier (Optional[str]): 用户标识,长度32位以内 - description (Optional[str]): 知识库描述,限制500字 - background (Optional[Literal['blue', 'red', 'orange', 'purple', 'sky']]): 背景颜色 - icon (Optional[Literal['question', 'book', 'seal', 'wrench', 'tag', 'horn', 'house']]): 知识库图标 - bucket_id (Optional[str]): 桶ID,限制32位 - """ - - embedding_id: int - name: str - customer_identifier: Optional[str] - description: Optional[str] - background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None - icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None - bucket_id: Optional[str] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_list_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_list_params.py deleted file mode 100644 index a221b28e4603be..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_list_params.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations - -from typing_extensions import TypedDict - -__all__ = ["KnowledgeListParams"] - - -class KnowledgeListParams(TypedDict, total=False): - page: int = 1 - """ 页码,默认 1,第一页 - """ - - size: int = 10 - """每页数量 默认10 - """ diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_list_resp.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_list_resp.py deleted file mode 100644 index e462eddc550d61..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_list_resp.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import annotations - -from ...core import BaseModel -from . 
import KnowledgeInfo - -__all__ = ["KnowledgePage"] - - -class KnowledgePage(BaseModel): - list: list[KnowledgeInfo] - object: str diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_used.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_used.py deleted file mode 100644 index cfda7097026c59..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_used.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Optional - -from ...core import BaseModel - -__all__ = ["KnowledgeStatistics", "KnowledgeUsed"] - - -class KnowledgeStatistics(BaseModel): - """ - 使用量统计 - """ - - word_num: Optional[int] = None - length: Optional[int] = None - - -class KnowledgeUsed(BaseModel): - used: Optional[KnowledgeStatistics] = None - """已使用量""" - total: Optional[KnowledgeStatistics] = None - """知识库总量""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/sensitive_word_check/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/sensitive_word_check/__init__.py deleted file mode 100644 index c9bd60419ce606..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/sensitive_word_check/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .sensitive_word_check import SensitiveWordCheckRequest - -__all__ = ["SensitiveWordCheckRequest"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/sensitive_word_check/sensitive_word_check.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/sensitive_word_check/sensitive_word_check.py deleted file mode 100644 index 0c37d99e653292..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/sensitive_word_check/sensitive_word_check.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import Optional - -from typing_extensions import TypedDict - - -class SensitiveWordCheckRequest(TypedDict, 
total=False): - type: Optional[str] - """敏感词类型,当前仅支持ALL""" - status: Optional[str] - """敏感词启用禁用状态 - 启用:ENABLE - 禁用:DISABLE - 备注:默认开启敏感词校验,如果要关闭敏感词校验,需联系商务获取对应权限,否则敏感词禁用不生效。 - """ diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/__init__.py deleted file mode 100644 index 62f77344eee56b..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .web_search import ( - SearchIntent, - SearchRecommend, - SearchResult, - WebSearch, -) -from .web_search_chunk import WebSearchChunk - -__all__ = ["WebSearch", "SearchIntent", "SearchResult", "SearchRecommend", "WebSearchChunk"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/tools_web_search_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/tools_web_search_params.py deleted file mode 100644 index b3a3b26f07ee58..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/tools_web_search_params.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import annotations - -from typing import Optional, Union - -from typing_extensions import TypedDict - -__all__ = ["WebSearchParams"] - - -class WebSearchParams(TypedDict): - """ - 工具名:web-search-pro参数类型定义 - - Attributes: - :param model: str, 模型名称 - :param request_id: Optional[str], 请求ID - :param stream: Optional[bool], 是否流式 - :param messages: Union[str, List[str], List[int], object, None], - 包含历史对话上下文的内容,按照 {"role": "user", "content": "你好"} 的json 数组形式进行传参 - 当前版本仅支持 User Message 单轮对话,工具会理解User Message并进行搜索, - 请尽可能传入不带指令格式的用户原始提问,以提高搜索准确率。 - :param scope: Optional[str], 指定搜索范围,全网、学术等,默认全网 - :param location: Optional[str], 指定搜索用户地区 location 提高相关性 - :param recent_days: Optional[int],支持指定返回 N 天(1-30)更新的搜索结果 - - - """ - - model: str - request_id: Optional[str] - stream: Optional[bool] - 
messages: Union[str, list[str], list[int], object, None] - scope: Optional[str] = None - location: Optional[str] = None - recent_days: Optional[int] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/web_search.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/web_search.py deleted file mode 100644 index ac9fa3821e979b..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/web_search.py +++ /dev/null @@ -1,71 +0,0 @@ -from typing import Optional - -from ...core import BaseModel - -__all__ = [ - "WebSearch", - "SearchIntent", - "SearchResult", - "SearchRecommend", -] - - -class SearchIntent(BaseModel): - index: int - # 搜索轮次,默认为 0 - query: str - # 搜索优化 query - intent: str - # 判断的意图类型 - keywords: str - # 搜索关键词 - - -class SearchResult(BaseModel): - index: int - # 搜索轮次,默认为 0 - title: str - # 标题 - link: str - # 链接 - content: str - # 内容 - icon: str - # 图标 - media: str - # 来源媒体 - refer: str - # 角标序号 [ref_1] - - -class SearchRecommend(BaseModel): - index: int - # 搜索轮次,默认为 0 - query: str - # 推荐query - - -class WebSearchMessageToolCall(BaseModel): - id: str - search_intent: Optional[SearchIntent] - search_result: Optional[SearchResult] - search_recommend: Optional[SearchRecommend] - type: str - - -class WebSearchMessage(BaseModel): - role: str - tool_calls: Optional[list[WebSearchMessageToolCall]] = None - - -class WebSearchChoice(BaseModel): - index: int - finish_reason: str - message: WebSearchMessage - - -class WebSearch(BaseModel): - created: Optional[int] = None - choices: list[WebSearchChoice] - request_id: Optional[str] = None - id: Optional[str] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/web_search_chunk.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/web_search_chunk.py deleted file mode 100644 index 7fb0e02bb58719..00000000000000 --- 
a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/web_search_chunk.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Optional - -from ...core import BaseModel -from .web_search import SearchIntent, SearchRecommend, SearchResult - -__all__ = ["WebSearchChunk"] - - -class ChoiceDeltaToolCall(BaseModel): - index: int - id: Optional[str] = None - - search_intent: Optional[SearchIntent] = None - search_result: Optional[SearchResult] = None - search_recommend: Optional[SearchRecommend] = None - type: Optional[str] = None - - -class ChoiceDelta(BaseModel): - role: Optional[str] = None - tool_calls: Optional[list[ChoiceDeltaToolCall]] = None - - -class Choice(BaseModel): - delta: ChoiceDelta - finish_reason: Optional[str] = None - index: int - - -class WebSearchChunk(BaseModel): - id: Optional[str] = None - choices: list[Choice] - created: Optional[int] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/__init__.py deleted file mode 100644 index b14072b1a771af..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .video_object import VideoObject, VideoResult - -__all__ = ["VideoObject", "VideoResult"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/video_create_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/video_create_params.py deleted file mode 100644 index f5489d708e7227..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/video_create_params.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from typing import Optional - -from typing_extensions import TypedDict - -__all__ = ["VideoCreateParams"] - -from ..sensitive_word_check import SensitiveWordCheckRequest - - -class VideoCreateParams(TypedDict, 
total=False): - model: str - """模型编码""" - prompt: str - """所需视频的文本描述""" - image_url: str - """所需视频的文本描述""" - sensitive_word_check: Optional[SensitiveWordCheckRequest] - """支持 URL 或者 Base64、传入 image 奖进行图生视频 - * 图片格式: - * 图片大小:""" - request_id: str - """由用户端传参,需保证唯一性;用于区分每次请求的唯一标识,用户端不传时平台会默认生成。""" - - user_id: str - """用户端。""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/video_object.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/video_object.py deleted file mode 100644 index 85c3844d8a791c..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/video_object.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Optional - -from ...core import BaseModel - -__all__ = ["VideoObject", "VideoResult"] - - -class VideoResult(BaseModel): - url: str - """视频url""" - cover_image_url: str - """预览图""" - - -class VideoObject(BaseModel): - id: Optional[str] = None - """智谱 AI 开放平台生成的任务订单号,调用请求结果接口时请使用此订单号""" - - model: str - """模型名称""" - - video_result: list[VideoResult] - """视频生成结果""" - - task_status: str - """处理状态,PROCESSING(处理中),SUCCESS(成功),FAIL(失败) - 注:处理中状态需通过查询获取结果""" - - request_id: str - """用户在客户端请求时提交的任务编号或者平台生成的任务编号""" diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.py b/api/core/tools/provider/builtin/cogview/tools/cogview3.py index 085084ca383552..12b4173fa40270 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogview3.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.py @@ -1,7 +1,8 @@ import random from typing import Any, Union -from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI +from zhipuai import ZhipuAI + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/poetry.lock b/api/poetry.lock index efefedfb21cd2a..f1c5d949fee737 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -1995,17 +1995,6 @@ dev = 
["Sphinx (==5.3.0)", "bump2version (==1.0.1)", "coverage (>=6.2)", "datacl timedelta = ["pytimeparse (>=1.1.7)"] yaml = ["PyYAML (>=5.3)"] -[[package]] -name = "dataclasses" -version = "0.6" -description = "A backport of the dataclasses module for Python 3.6" -optional = false -python-versions = "*" -files = [ - {file = "dataclasses-0.6-py3-none-any.whl", hash = "sha256:454a69d788c7fda44efd71e259be79577822f5e3f53f029a22d08004e951dc9f"}, - {file = "dataclasses-0.6.tar.gz", hash = "sha256:6988bd2b895eef432d562370bb707d540f32f7360ab13da45340101bc2307d84"}, -] - [[package]] name = "dataclasses-json" version = "0.6.7" @@ -4965,7 +4954,6 @@ python-versions = ">=3.7" files = [ {file = "milvus_lite-2.4.10-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:fc4246d3ed7d1910847afce0c9ba18212e93a6e9b8406048436940578dfad5cb"}, {file = "milvus_lite-2.4.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:74a8e07c5e3b057df17fbb46913388e84df1dc403a200f4e423799a58184c800"}, - {file = "milvus_lite-2.4.10-py3-none-manylinux2014_aarch64.whl", hash = "sha256:240c7386b747bad696ecb5bd1f58d491e86b9d4b92dccee3315ed7256256eddc"}, {file = "milvus_lite-2.4.10-py3-none-manylinux2014_x86_64.whl", hash = "sha256:211d2e334a043f9282bdd9755f76b9b2d93b23bffa7af240919ffce6a8dfe325"}, ] @@ -8484,11 +8472,6 @@ files = [ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, - {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, - {file = 
"scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, - {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, - {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, - {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, @@ -10486,20 +10469,21 @@ repair = ["scipy (>=1.6.3)"] [[package]] name = "zhipuai" -version = "1.0.7" +version = "2.1.5.20230904" description = "A SDK library for accessing big model apis from ZhipuAI" optional = false -python-versions = ">=3.6" +python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" files = [ - {file = "zhipuai-1.0.7-py3-none-any.whl", hash = "sha256:360c01b8c2698f366061452e86d5a36a5ff68a576ea33940da98e4806f232530"}, - {file = "zhipuai-1.0.7.tar.gz", hash = "sha256:b80f699543d83cce8648acf1ce32bc2725d1c1c443baffa5882abc2cc704d581"}, + {file = "zhipuai-2.1.5.20230904-py3-none-any.whl", hash = "sha256:8485ca452c2f07fea476fb0666abc8fbbdf1b2e4feeee46a3bb3c1a2b51efccd"}, + {file = "zhipuai-2.1.5.20230904.tar.gz", hash = "sha256:2c19dd796b12e2f19b93d8f9be6fd01e85d3320737a187ebf3c75a9806a7c2b5"}, ] [package.dependencies] -cachetools = "*" 
-dataclasses = "*" -PyJWT = "*" -requests = "*" +cachetools = ">=4.2.2" +httpx = ">=0.23.0" +pydantic = ">=1.9.0,<3.0" +pydantic-core = ">=2.14.6" +pyjwt = ">=2.8.0,<2.9.0" [[package]] name = "zipp" @@ -10707,4 +10691,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "edb5e3b0d50e84a239224cc77f3f615fdbdd6b504bce5b1075b29363f3054957" +content-hash = "75a7e7eab36b9386c11a3e9808da28102ad20a43a0e8ae08c37594ecf50da02b" diff --git a/api/pyproject.toml b/api/pyproject.toml index dff74750f0f558..cc85ec3af6e9a4 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -182,7 +182,7 @@ websocket-client = "~1.7.0" werkzeug = "~3.0.1" xinference-client = "0.15.2" yarl = "~1.9.4" -zhipuai = "1.0.7" +zhipuai = "~2.1.5" # Before adding new dependency, consider place it in alphabet order (a-z) and suitable group. ############################################################ From 7a405b86c92fdb70d7210234b42a61c7e3db6e0f Mon Sep 17 00:00:00 2001 From: zhuhao <37029601+hwzhuhao@users.noreply.github.com> Date: Mon, 14 Oct 2024 13:26:21 +0800 Subject: [PATCH 21/25] refactor: Refactor the service of retrieval the recommend app (#9302) --- api/services/recommend_app/__init__.py | 0 .../recommend_app/buildin/__init__.py | 0 .../buildin/buildin_retrieval.py | 64 +++++ .../recommend_app/database/__init__.py | 0 .../database/database_retrieval.py | 111 +++++++++ .../recommend_app/recommend_app_base.py | 17 ++ .../recommend_app/recommend_app_factory.py | 23 ++ .../recommend_app/recommend_app_type.py | 7 + api/services/recommend_app/remote/__init__.py | 0 .../recommend_app/remote/remote_retrieval.py | 71 ++++++ api/services/recommended_app_service.py | 235 +----------------- 11 files changed, 302 insertions(+), 226 deletions(-) create mode 100644 api/services/recommend_app/__init__.py create mode 100644 api/services/recommend_app/buildin/__init__.py create mode 100644 api/services/recommend_app/buildin/buildin_retrieval.py create 
mode 100644 api/services/recommend_app/database/__init__.py create mode 100644 api/services/recommend_app/database/database_retrieval.py create mode 100644 api/services/recommend_app/recommend_app_base.py create mode 100644 api/services/recommend_app/recommend_app_factory.py create mode 100644 api/services/recommend_app/recommend_app_type.py create mode 100644 api/services/recommend_app/remote/__init__.py create mode 100644 api/services/recommend_app/remote/remote_retrieval.py diff --git a/api/services/recommend_app/__init__.py b/api/services/recommend_app/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/services/recommend_app/buildin/__init__.py b/api/services/recommend_app/buildin/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/services/recommend_app/buildin/buildin_retrieval.py b/api/services/recommend_app/buildin/buildin_retrieval.py new file mode 100644 index 00000000000000..4704d533a950ed --- /dev/null +++ b/api/services/recommend_app/buildin/buildin_retrieval.py @@ -0,0 +1,64 @@ +import json +from os import path +from pathlib import Path +from typing import Optional + +from flask import current_app + +from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase +from services.recommend_app.recommend_app_type import RecommendAppType + + +class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase): + """ + Retrieval recommended app from buildin, the location is constants/recommended_apps.json + """ + + builtin_data: Optional[dict] = None + + def get_type(self) -> str: + return RecommendAppType.BUILDIN + + def get_recommended_apps_and_categories(self, language: str) -> dict: + result = self.fetch_recommended_apps_from_builtin(language) + return result + + def get_recommend_app_detail(self, app_id: str): + result = self.fetch_recommended_app_detail_from_builtin(app_id) + return result + + @classmethod + def _get_builtin_data(cls) -> dict: + """ + Get builtin data. 
+ :return: + """ + if cls.builtin_data: + return cls.builtin_data + + root_path = current_app.root_path + cls.builtin_data = json.loads( + Path(path.join(root_path, "constants", "recommended_apps.json")).read_text(encoding="utf-8") + ) + + return cls.builtin_data + + @classmethod + def fetch_recommended_apps_from_builtin(cls, language: str) -> dict: + """ + Fetch recommended apps from builtin. + :param language: language + :return: + """ + builtin_data = cls._get_builtin_data() + return builtin_data.get("recommended_apps", {}).get(language) + + @classmethod + def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]: + """ + Fetch recommended app detail from builtin. + :param app_id: App ID + :return: + """ + builtin_data = cls._get_builtin_data() + return builtin_data.get("app_details", {}).get(app_id) diff --git a/api/services/recommend_app/database/__init__.py b/api/services/recommend_app/database/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py new file mode 100644 index 00000000000000..995d3755bb5b10 --- /dev/null +++ b/api/services/recommend_app/database/database_retrieval.py @@ -0,0 +1,111 @@ +from typing import Optional + +from constants.languages import languages +from extensions.ext_database import db +from models.model import App, RecommendedApp +from services.app_dsl_service import AppDslService +from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase +from services.recommend_app.recommend_app_type import RecommendAppType + + +class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): + """ + Retrieval recommended app from database + """ + + def get_recommended_apps_and_categories(self, language: str) -> dict: + result = self.fetch_recommended_apps_from_db(language) + return result + + def get_recommend_app_detail(self, app_id: str): + result = 
self.fetch_recommended_app_detail_from_db(app_id) + return result + + def get_type(self) -> str: + return RecommendAppType.DATABASE + + @classmethod + def fetch_recommended_apps_from_db(cls, language: str) -> dict: + """ + Fetch recommended apps from db. + :param language: language + :return: + """ + recommended_apps = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.language == language) + .all() + ) + + if len(recommended_apps) == 0: + recommended_apps = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) + .all() + ) + + categories = set() + recommended_apps_result = [] + for recommended_app in recommended_apps: + app = recommended_app.app + if not app or not app.is_public: + continue + + site = app.site + if not site: + continue + + recommended_app_result = { + "id": recommended_app.id, + "app": { + "id": app.id, + "name": app.name, + "mode": app.mode, + "icon": app.icon, + "icon_background": app.icon_background, + }, + "app_id": recommended_app.app_id, + "description": site.description, + "copyright": site.copyright, + "privacy_policy": site.privacy_policy, + "custom_disclaimer": site.custom_disclaimer, + "category": recommended_app.category, + "position": recommended_app.position, + "is_listed": recommended_app.is_listed, + } + recommended_apps_result.append(recommended_app_result) + + categories.add(recommended_app.category) + + return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} + + @classmethod + def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]: + """ + Fetch recommended app detail from db. 
+ :param app_id: App ID + :return: + """ + # is in public recommended list + recommended_app = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) + .first() + ) + + if not recommended_app: + return None + + # get app detail + app_model = db.session.query(App).filter(App.id == app_id).first() + if not app_model or not app_model.is_public: + return None + + return { + "id": app_model.id, + "name": app_model.name, + "icon": app_model.icon, + "icon_background": app_model.icon_background, + "mode": app_model.mode, + "export_data": AppDslService.export_dsl(app_model=app_model), + } diff --git a/api/services/recommend_app/recommend_app_base.py b/api/services/recommend_app/recommend_app_base.py new file mode 100644 index 00000000000000..00c037710e869c --- /dev/null +++ b/api/services/recommend_app/recommend_app_base.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod + + +class RecommendAppRetrievalBase(ABC): + """Interface for recommend app retrieval.""" + + @abstractmethod + def get_recommended_apps_and_categories(self, language: str) -> dict: + raise NotImplementedError + + @abstractmethod + def get_recommend_app_detail(self, app_id: str): + raise NotImplementedError + + @abstractmethod + def get_type(self) -> str: + raise NotImplementedError diff --git a/api/services/recommend_app/recommend_app_factory.py b/api/services/recommend_app/recommend_app_factory.py new file mode 100644 index 00000000000000..e53667c0b06dd6 --- /dev/null +++ b/api/services/recommend_app/recommend_app_factory.py @@ -0,0 +1,23 @@ +from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval +from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval +from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase +from services.recommend_app.recommend_app_type import RecommendAppType +from services.recommend_app.remote.remote_retrieval import 
RemoteRecommendAppRetrieval + + +class RecommendAppRetrievalFactory: + @staticmethod + def get_recommend_app_factory(mode: str) -> type[RecommendAppRetrievalBase]: + match mode: + case RecommendAppType.REMOTE: + return RemoteRecommendAppRetrieval + case RecommendAppType.DATABASE: + return DatabaseRecommendAppRetrieval + case RecommendAppType.BUILDIN: + return BuildInRecommendAppRetrieval + case _: + raise ValueError(f"invalid fetch recommended apps mode: {mode}") + + @staticmethod + def get_buildin_recommend_app_retrieval(): + return BuildInRecommendAppRetrieval diff --git a/api/services/recommend_app/recommend_app_type.py b/api/services/recommend_app/recommend_app_type.py new file mode 100644 index 00000000000000..7ea93b3f64b1d4 --- /dev/null +++ b/api/services/recommend_app/recommend_app_type.py @@ -0,0 +1,7 @@ +from enum import Enum + + +class RecommendAppType(str, Enum): + REMOTE = "remote" + BUILDIN = "builtin" + DATABASE = "db" diff --git a/api/services/recommend_app/remote/__init__.py b/api/services/recommend_app/remote/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py new file mode 100644 index 00000000000000..b0607a21323acb --- /dev/null +++ b/api/services/recommend_app/remote/remote_retrieval.py @@ -0,0 +1,71 @@ +import logging +from typing import Optional + +import requests + +from configs import dify_config +from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval +from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase +from services.recommend_app.recommend_app_type import RecommendAppType + +logger = logging.getLogger(__name__) + + +class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): + """ + Retrieval recommended app from dify official + """ + + def get_recommend_app_detail(self, app_id: str): + try: + result = 
self.fetch_recommended_app_detail_from_dify_official(app_id) + except Exception as e: + logger.warning(f"fetch recommended app detail from dify official failed: {e}, switch to built-in.") + result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(app_id) + return result + + def get_recommended_apps_and_categories(self, language: str) -> dict: + try: + result = self.fetch_recommended_apps_from_dify_official(language) + except Exception as e: + logger.warning(f"fetch recommended apps from dify official failed: {e}, switch to built-in.") + result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin(language) + return result + + def get_type(self) -> str: + return RecommendAppType.REMOTE + + @classmethod + def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]: + """ + Fetch recommended app detail from dify official. + :param app_id: App ID + :return: + """ + domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN + url = f"{domain}/apps/{app_id}" + response = requests.get(url, timeout=(3, 10)) + if response.status_code != 200: + return None + + return response.json() + + @classmethod + def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: + """ + Fetch recommended apps from dify official. 
+ :param language: language + :return: + """ + domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN + url = f"{domain}/apps?language={language}" + response = requests.get(url, timeout=(3, 10)) + if response.status_code != 200: + raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") + + result = response.json() + + if "categories" in result: + result["categories"] = sorted(result["categories"]) + + return result diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index daec8393d092e5..4660316fcfcf71 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -1,24 +1,10 @@ -import json -import logging -from os import path -from pathlib import Path from typing import Optional -import requests -from flask import current_app - from configs import dify_config -from constants.languages import languages -from extensions.ext_database import db -from models.model import App, RecommendedApp -from services.app_dsl_service import AppDslService - -logger = logging.getLogger(__name__) +from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory class RecommendedAppService: - builtin_data: Optional[dict] = None - @classmethod def get_recommended_apps_and_categories(cls, language: str) -> dict: """ @@ -27,109 +13,17 @@ def get_recommended_apps_and_categories(cls, language: str) -> dict: :return: """ mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE - if mode == "remote": - try: - result = cls._fetch_recommended_apps_from_dify_official(language) - except Exception as e: - logger.warning(f"fetch recommended apps from dify official failed: {e}, switch to built-in.") - result = cls._fetch_recommended_apps_from_builtin(language) - elif mode == "db": - result = cls._fetch_recommended_apps_from_db(language) - elif mode == "builtin": - result = cls._fetch_recommended_apps_from_builtin(language) - else: - raise ValueError(f"invalid fetch 
recommended apps mode: {mode}") - + retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)() + result = retrieval_instance.get_recommended_apps_and_categories(language) if not result.get("recommended_apps") and language != "en-US": - result = cls._fetch_recommended_apps_from_builtin("en-US") - - return result - - @classmethod - def _fetch_recommended_apps_from_db(cls, language: str) -> dict: - """ - Fetch recommended apps from db. - :param language: language - :return: - """ - recommended_apps = ( - db.session.query(RecommendedApp) - .filter(RecommendedApp.is_listed == True, RecommendedApp.language == language) - .all() - ) - - if len(recommended_apps) == 0: - recommended_apps = ( - db.session.query(RecommendedApp) - .filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) - .all() + result = ( + RecommendAppRetrievalFactory.get_buildin_recommend_app_retrieval().fetch_recommended_apps_from_builtin( + "en-US" + ) ) - categories = set() - recommended_apps_result = [] - for recommended_app in recommended_apps: - app = recommended_app.app - if not app or not app.is_public: - continue - - site = app.site - if not site: - continue - - recommended_app_result = { - "id": recommended_app.id, - "app": { - "id": app.id, - "name": app.name, - "mode": app.mode, - "icon": app.icon, - "icon_background": app.icon_background, - }, - "app_id": recommended_app.app_id, - "description": site.description, - "copyright": site.copyright, - "privacy_policy": site.privacy_policy, - "custom_disclaimer": site.custom_disclaimer, - "category": recommended_app.category, - "position": recommended_app.position, - "is_listed": recommended_app.is_listed, - } - recommended_apps_result.append(recommended_app_result) - - categories.add(recommended_app.category) # add category to categories - - return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} - - @classmethod - def _fetch_recommended_apps_from_dify_official(cls, 
language: str) -> dict: - """ - Fetch recommended apps from dify official. - :param language: language - :return: - """ - domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN - url = f"{domain}/apps?language={language}" - response = requests.get(url, timeout=(3, 10)) - if response.status_code != 200: - raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") - - result = response.json() - - if "categories" in result: - result["categories"] = sorted(result["categories"]) - return result - @classmethod - def _fetch_recommended_apps_from_builtin(cls, language: str) -> dict: - """ - Fetch recommended apps from builtin. - :param language: language - :return: - """ - builtin_data = cls._get_builtin_data() - return builtin_data.get("recommended_apps", {}).get(language) - @classmethod def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: """ @@ -138,117 +32,6 @@ def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: :return: """ mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE - if mode == "remote": - try: - result = cls._fetch_recommended_app_detail_from_dify_official(app_id) - except Exception as e: - logger.warning(f"fetch recommended app detail from dify official failed: {e}, switch to built-in.") - result = cls._fetch_recommended_app_detail_from_builtin(app_id) - elif mode == "db": - result = cls._fetch_recommended_app_detail_from_db(app_id) - elif mode == "builtin": - result = cls._fetch_recommended_app_detail_from_builtin(app_id) - else: - raise ValueError(f"invalid fetch recommended app detail mode: {mode}") - + retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)() + result = retrieval_instance.get_recommend_app_detail(app_id) return result - - @classmethod - def _fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]: - """ - Fetch recommended app detail from dify official. 
- :param app_id: App ID - :return: - """ - domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN - url = f"{domain}/apps/{app_id}" - response = requests.get(url, timeout=(3, 10)) - if response.status_code != 200: - return None - - return response.json() - - @classmethod - def _fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]: - """ - Fetch recommended app detail from db. - :param app_id: App ID - :return: - """ - # is in public recommended list - recommended_app = ( - db.session.query(RecommendedApp) - .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) - .first() - ) - - if not recommended_app: - return None - - # get app detail - app_model = db.session.query(App).filter(App.id == app_id).first() - if not app_model or not app_model.is_public: - return None - - return { - "id": app_model.id, - "name": app_model.name, - "icon": app_model.icon, - "icon_background": app_model.icon_background, - "mode": app_model.mode, - "export_data": AppDslService.export_dsl(app_model=app_model), - } - - @classmethod - def _fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]: - """ - Fetch recommended app detail from builtin. - :param app_id: App ID - :return: - """ - builtin_data = cls._get_builtin_data() - return builtin_data.get("app_details", {}).get(app_id) - - @classmethod - def _get_builtin_data(cls) -> dict: - """ - Get builtin data. 
- :return: - """ - if cls.builtin_data: - return cls.builtin_data - - root_path = current_app.root_path - cls.builtin_data = json.loads( - Path(path.join(root_path, "constants", "recommended_apps.json")).read_text(encoding="utf-8") - ) - - return cls.builtin_data - - @classmethod - def fetch_all_recommended_apps_and_export_datas(cls): - """ - Fetch all recommended apps and export datas - :return: - """ - templates = {"recommended_apps": {}, "app_details": {}} - for language in languages: - try: - result = cls._fetch_recommended_apps_from_dify_official(language) - except Exception as e: - logger.warning(f"fetch recommended apps from dify official failed: {e}, skip.") - continue - - templates["recommended_apps"][language] = result - - for recommended_app in result.get("recommended_apps"): - app_id = recommended_app.get("app_id") - - # get app detail - app_detail = cls._fetch_recommended_app_detail_from_dify_official(app_id) - if not app_detail: - continue - - templates["app_details"][app_id] = app_detail - - return templates From 5ee7e03c1b163b8c4bbdcc051fdc2c426810e421 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=91=E6=9C=89=E4=B8=80=E6=8A=8A=E5=A6=96=E5=88=80?= Date: Mon, 14 Oct 2024 13:32:13 +0800 Subject: [PATCH 22/25] chore: Optimize operations in Q&A mode (#9274) Co-authored-by: billsyli --- web/app/components/base/popover/index.tsx | 3 ++ .../datasets/create/step-two/index.tsx | 44 ++++++++++++------- .../create/step-two/language-select/index.tsx | 3 ++ 3 files changed, 33 insertions(+), 17 deletions(-) diff --git a/web/app/components/base/popover/index.tsx b/web/app/components/base/popover/index.tsx index 141ac8ff70d621..1e7ba76269d756 100644 --- a/web/app/components/base/popover/index.tsx +++ b/web/app/components/base/popover/index.tsx @@ -17,6 +17,7 @@ type IPopover = { btnElement?: string | React.ReactNode btnClassName?: string | ((open: boolean) => string) manualClose?: boolean + disabled?: boolean } const timeoutDuration = 100 @@ -30,6 +31,7 @@ export 
default function CustomPopover({ className, btnClassName, manualClose, + disabled = false, }: IPopover) { const buttonRef = useRef(null) const timeOutRef = useRef(null) @@ -60,6 +62,7 @@ export default function CustomPopover({ > ( (datasetId && documentDetail) ? documentDetail.doc_form : DocForm.TEXT, ) @@ -200,9 +201,9 @@ const StepTwo = ({ } } - const fetchFileIndexingEstimate = async (docForm = DocForm.TEXT) => { + const fetchFileIndexingEstimate = async (docForm = DocForm.TEXT, language?: string) => { // eslint-disable-next-line @typescript-eslint/no-use-before-define - const res = await didFetchFileIndexingEstimate(getFileIndexingEstimateParams(docForm)!) + const res = await didFetchFileIndexingEstimate(getFileIndexingEstimateParams(docForm, language)!) if (segmentationType === SegmentType.CUSTOM) setCustomFileIndexingEstimate(res) else @@ -270,7 +271,7 @@ const StepTwo = ({ } } - const getFileIndexingEstimateParams = (docForm: DocForm): IndexingEstimateParams | undefined => { + const getFileIndexingEstimateParams = (docForm: DocForm, language?: string): IndexingEstimateParams | undefined => { if (dataSourceType === DataSourceType.FILE) { return { info_list: { @@ -282,7 +283,7 @@ const StepTwo = ({ indexing_technique: getIndexing_technique() as string, process_rule: getProcessRule(), doc_form: docForm, - doc_language: docLanguage, + doc_language: language || docLanguage, dataset_id: datasetId as string, } } @@ -295,7 +296,7 @@ const StepTwo = ({ indexing_technique: getIndexing_technique() as string, process_rule: getProcessRule(), doc_form: docForm, - doc_language: docLanguage, + doc_language: language || docLanguage, dataset_id: datasetId as string, } } @@ -308,7 +309,7 @@ const StepTwo = ({ indexing_technique: getIndexing_technique() as string, process_rule: getProcessRule(), doc_form: docForm, - doc_language: docLanguage, + doc_language: language || docLanguage, dataset_id: datasetId as string, } } @@ -483,8 +484,26 @@ const StepTwo = ({ 
setDocForm(DocForm.TEXT) } + const previewSwitch = async (language?: string) => { + setPreviewSwitched(true) + setIsLanguageSelectDisabled(true) + if (segmentationType === SegmentType.AUTO) + setAutomaticFileIndexingEstimate(null) + else + setCustomFileIndexingEstimate(null) + try { + await fetchFileIndexingEstimate(DocForm.QA, language) + } + finally { + setIsLanguageSelectDisabled(false) + } + } + const handleSelect = (language: string) => { setDocLanguage(language) + // Switch language, re-cutter + if (docForm === DocForm.QA && previewSwitched) + previewSwitch(language) } const changeToEconomicalType = () => { @@ -494,15 +513,6 @@ const StepTwo = ({ } } - const previewSwitch = async () => { - setPreviewSwitched(true) - if (segmentationType === SegmentType.AUTO) - setAutomaticFileIndexingEstimate(null) - else - setCustomFileIndexingEstimate(null) - await fetchFileIndexingEstimate(DocForm.QA) - } - useEffect(() => { // fetch rules if (!isSetting) { @@ -777,7 +787,7 @@ const StepTwo = ({
{t('datasetCreation.stepTwo.QATitle')}
{t('datasetCreation.stepTwo.QALanguage')} - +
@@ -948,7 +958,7 @@ const StepTwo = ({
{t('datasetCreation.stepTwo.previewTitle')}
{docForm === DocForm.QA && !previewSwitched && ( - + )}
diff --git a/web/app/components/datasets/create/step-two/language-select/index.tsx b/web/app/components/datasets/create/step-two/language-select/index.tsx index f8709c89f3a6bb..fab2bb1c71389d 100644 --- a/web/app/components/datasets/create/step-two/language-select/index.tsx +++ b/web/app/components/datasets/create/step-two/language-select/index.tsx @@ -9,16 +9,19 @@ import { languages } from '@/i18n/language' export type ILanguageSelectProps = { currentLanguage: string onSelect: (language: string) => void + disabled?: boolean } const LanguageSelect: FC = ({ currentLanguage, onSelect, + disabled, }) => { return ( {languages.filter(language => language.supported).map(({ prompt_name, name }) => ( From de3c5751db8bca0c0795abbc7f6bddfb31d7779d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=91=E6=9C=89=E4=B8=80=E6=8A=8A=E5=A6=96=E5=88=80?= Date: Mon, 14 Oct 2024 13:32:52 +0800 Subject: [PATCH 23/25] chore: add reopen preview btn (#9279) Co-authored-by: billsyli --- web/app/components/datasets/create/step-two/index.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index c4c80053b06e57..5d92e30deb8cca 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -585,7 +585,7 @@ const StepTwo = ({
{t('datasetCreation.steps.two')} - {isMobile && ( + {(isMobile || !showPreview) && (